#!/usr/bin/env python3
"""Validate IWSLT voice-cloning submission WAV filename conventions.

Required naming format:
    {language}_{lineNumber}_{originalName}.wav

Example:
    ar_023_2023.acl-long.3.wav

Python usage example:
    python3 verify_submission_naming.py submissions/ar/v01 --language ar --source-file /path/to/arabic.txt --reference-dir /path/to/reference_wavs

This validator checks:
1. All discovered audio files are .wav files.
2. Filenames follow the expected pattern.
3. Optional line-number width constraints.
4. Filenames match a required language code.
5. Expected file-count verification from a required source text file.
6. Coverage checks (missing/duplicate/out-of-range line IDs).
7. Per-reference coverage: each reference audio must include all source lines.
"""

from __future__ import annotations

import argparse
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple


FILENAME_REGEX = re.compile(
    r"^(?P<lang>[A-Za-z]{2,3})_(?P<line>\d+)_(?P<original>.+)\.wav$"
)


@dataclass
class ValidationIssue:
    path: Path
    message: str


@dataclass
class ParsedFile:
    path: Path
    language: str
    line_id: int
    line_str: str
    original_name: str


def parse_args() -> argparse.Namespace:
    examples = """Examples:
    # Python example (recommended)
    python3 verify_submission_naming.py submissions/ar/v01 --language ar --source-file /path/to/arabic.txt --reference-dir /path/to/reference_wavs

    # French folder using source file line count as expected count
    python verify_submission_naming.py submissions/fr/v02 --language fr --source-file <path-to-source-file>.txt --reference-dir <path-to-reference-wavs>

    # Chinese folder with strict 3-digit line IDs
    python verify_submission_naming.py submissions/zh/v01 --language zh --source-file <path-to-source-file>.txt --reference-dir <path-to-reference-wavs> --strict-width --line-width 3

    # Source file is required and drives expected count/coverage checks
    python verify_submission_naming.py submissions/ar/v01 --language ar --source-file <path-to-source-file>.txt --reference-dir <path-to-reference-wavs>
"""

    parser = argparse.ArgumentParser(
        description="Validate submission filenames for IWSLT voice cloning.",
        epilog=examples,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "submission_dir",
        nargs="?",
        default="output",
        help="Directory containing submission WAV files (recursive scan). Default: output",
    )
    parser.add_argument(
        "--language",
        type=str,
        required=True,
        help="Required language code expected in every filename (e.g., ar, fr, zh).",
    )
    parser.add_argument(
        "--line-width",
        type=int,
        default=3,
        help="Expected minimum digit width for line number. Default: 3",
    )
    parser.add_argument(
        "--source-file",
        type=str,
        required=True,
        help="Source text file used to derive expected file count and coverage (1-based lines).",
    )
    parser.add_argument(
        "--reference-dir",
        type=str,
        required=True,
        help="Directory containing reference WAV files whose stems must match output originalName.",
    )
    parser.add_argument(
        "--strict-width",
        action="store_true",
        help="Require line numbers to be exactly --line-width digits.",
    )
    return parser.parse_args()


def count_nonempty_lines(source_file: Path) -> int:
    with source_file.open("r", encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


def collect_audio_files(submission_dir: Path) -> List[Path]:
    return sorted([p for p in submission_dir.rglob("*") if p.is_file()])


def collect_reference_names(reference_dir: Path) -> List[str]:
    reference_wavs = sorted([p for p in reference_dir.rglob("*.wav") if p.is_file()])
    return [p.stem for p in reference_wavs]


def validate_filename(
    path: Path, line_width: int, strict_width: bool
) -> Tuple[Optional[ParsedFile], List[ValidationIssue]]:
    issues: List[ValidationIssue] = []

    if path.suffix.lower() != ".wav":
        issues.append(
            ValidationIssue(path=path, message="File extension is not .wav")
        )
        return None, issues

    match = FILENAME_REGEX.match(path.name)
    if not match:
        issues.append(
            ValidationIssue(
                path=path,
                message=(
                    "Filename does not match {language}_{lineNumber}_{originalName}.wav"
                ),
            )
        )
        return None, issues

    language = match.group("lang").lower()
    line_str = match.group("line")
    original_name = match.group("original")

    if strict_width and len(line_str) != line_width:
        issues.append(
            ValidationIssue(
                path=path,
                message=(
                    f"Line number width is {len(line_str)}, expected exactly {line_width}"
                ),
            )
        )
    elif not strict_width and len(line_str) < line_width:
        issues.append(
            ValidationIssue(
                path=path,
                message=(
                    f"Line number width is {len(line_str)}, expected at least {line_width}"
                ),
            )
        )

    if original_name.endswith(".wav"):
        issues.append(
            ValidationIssue(
                path=path,
                message="originalName must not include a trailing .wav",
            )
        )

    try:
        line_id = int(line_str)
    except ValueError:
        issues.append(
            ValidationIssue(path=path, message="Line number is not an integer")
        )
        return None, issues

    if line_id <= 0:
        issues.append(
            ValidationIssue(path=path, message="Line number must be >= 1")
        )

    parsed = ParsedFile(
        path=path,
        language=language,
        line_id=line_id,
        line_str=line_str,
        original_name=original_name,
    )
    return parsed, issues


def validate_coverage(
    parsed_files: List[ParsedFile], expected_lines: int
) -> List[ValidationIssue]:
    issues: List[ValidationIssue] = []
    by_language: Dict[str, List[ParsedFile]] = {}

    for item in parsed_files:
        by_language.setdefault(item.language, []).append(item)

    for language, items in sorted(by_language.items()):
        line_ids = [x.line_id for x in items]
        unique_line_ids = set(line_ids)

        if len(unique_line_ids) != len(line_ids):
            seen = set()
            duplicates = sorted({x for x in line_ids if x in seen or seen.add(x)})
            for dup in duplicates:
                issues.append(
                    ValidationIssue(
                        path=items[0].path,
                        message=f"Language '{language}' has duplicate line_id {dup}",
                    )
                )

        missing = [i for i in range(1, expected_lines + 1) if i not in unique_line_ids]
        for miss in missing[:20]:
            issues.append(
                ValidationIssue(
                    path=items[0].path,
                    message=f"Language '{language}' missing line_id {miss}",
                )
            )
        if len(missing) > 20:
            issues.append(
                ValidationIssue(
                    path=items[0].path,
                    message=(
                        f"Language '{language}' has {len(missing)} missing lines "
                        "(showing first 20)"
                    ),
                )
            )

        out_of_range = sorted(x for x in unique_line_ids if x > expected_lines)
        for bad in out_of_range[:20]:
            issues.append(
                ValidationIssue(
                    path=items[0].path,
                    message=(
                        f"Language '{language}' has line_id {bad} beyond source length {expected_lines}"
                    ),
                )
            )
        if len(out_of_range) > 20:
            issues.append(
                ValidationIssue(
                    path=items[0].path,
                    message=(
                        f"Language '{language}' has {len(out_of_range)} out-of-range line IDs "
                        "(showing first 20)"
                    ),
                )
            )

    return issues


def validate_per_reference_coverage(
    parsed_files: List[ParsedFile],
    expected_lines: int,
    language: str,
    reference_names: List[str],
) -> List[ValidationIssue]:
    issues: List[ValidationIssue] = []
    by_original: Dict[str, List[ParsedFile]] = {}

    for item in parsed_files:
        by_original.setdefault(item.original_name, []).append(item)

    expected_reference_set = set(reference_names)

    unexpected_references = sorted(set(by_original.keys()) - expected_reference_set)
    for original_name in unexpected_references:
        items = by_original[original_name]
        issues.append(
            ValidationIssue(
                path=items[0].path,
                message=(
                    f"Reference '{original_name}' not found in provided reference directory"
                ),
            )
        )

    for original_name in sorted(reference_names):
        items = by_original.get(original_name, [])
        if not items:
            issues.append(
                ValidationIssue(
                    path=Path("."),
                    message=(
                        f"Reference '{original_name}' is missing all {expected_lines} lines "
                        f"for language '{language}'"
                    ),
                )
            )
            continue

        line_ids = [x.line_id for x in items]
        unique_line_ids = set(line_ids)

        if len(unique_line_ids) != len(line_ids):
            seen = set()
            duplicates = sorted({x for x in line_ids if x in seen or seen.add(x)})
            for dup in duplicates:
                issues.append(
                    ValidationIssue(
                        path=items[0].path,
                        message=(
                            f"Reference '{original_name}' has duplicate line_id {dup} "
                            f"for language '{language}'"
                        ),
                    )
                )

        missing = [i for i in range(1, expected_lines + 1) if i not in unique_line_ids]
        if missing:
            for miss in missing[:20]:
                issues.append(
                    ValidationIssue(
                        path=items[0].path,
                        message=(
                            f"Reference '{original_name}' missing line_id {miss} "
                            f"for language '{language}'"
                        ),
                    )
                )
            if len(missing) > 20:
                issues.append(
                    ValidationIssue(
                        path=items[0].path,
                        message=(
                            f"Reference '{original_name}' has {len(missing)} missing lines "
                            f"for language '{language}' (showing first 20)"
                        ),
                    )
                )

        out_of_range = sorted(x for x in unique_line_ids if x > expected_lines)
        if out_of_range:
            for bad in out_of_range[:20]:
                issues.append(
                    ValidationIssue(
                        path=items[0].path,
                        message=(
                            f"Reference '{original_name}' has line_id {bad} beyond source length "
                            f"{expected_lines} for language '{language}'"
                        ),
                    )
                )
            if len(out_of_range) > 20:
                issues.append(
                    ValidationIssue(
                        path=items[0].path,
                        message=(
                            f"Reference '{original_name}' has {len(out_of_range)} out-of-range line IDs "
                            f"for language '{language}' (showing first 20)"
                        ),
                    )
                )

    return issues


def main() -> int:
    args = parse_args()
    submission_dir = Path(args.submission_dir)
    required_language = args.language.strip().lower()

    if not re.fullmatch(r"[a-z]{2,3}", required_language):
        print(
            "ERROR: --language must be 2-3 lowercase letters (for example: ar, fr, zh)"
        )
        return 2

    if not submission_dir.exists() or not submission_dir.is_dir():
        print(f"ERROR: Submission directory not found: {submission_dir}")
        return 2

    audio_files = collect_audio_files(submission_dir)
    if not audio_files:
        print(f"ERROR: No files found under {submission_dir}")
        return 2

    parsed_files: List[ParsedFile] = []
    issues: List[ValidationIssue] = []
    non_wav_count = 0

    for path in audio_files:
        parsed, file_issues = validate_filename(
            path=path,
            line_width=args.line_width,
            strict_width=args.strict_width,
        )
        issues.extend(file_issues)
        if path.suffix.lower() != ".wav":
            non_wav_count += 1
        if parsed is not None:
            parsed_files.append(parsed)

    language_mismatch = [p for p in parsed_files if p.language != required_language]
    for item in language_mismatch:
        issues.append(
            ValidationIssue(
                path=item.path,
                message=(
                    f"Language code mismatch: expected '{required_language}', found '{item.language}'"
                ),
            )
        )

    selected_language_files = [
        p for p in parsed_files if p.language == required_language
    ]

    reference_dir = Path(args.reference_dir)
    if not reference_dir.exists() or not reference_dir.is_dir():
        print(f"ERROR: Reference directory not found: {reference_dir}")
        return 2

    reference_names = collect_reference_names(reference_dir)
    if not reference_names:
        print(f"ERROR: No .wav files found in reference directory: {reference_dir}")
        return 2

    source_path = Path(args.source_file)
    if not source_path.exists() or not source_path.is_file():
        print(f"ERROR: Source file not found: {source_path}")
        return 2

    expected_lines: int = count_nonempty_lines(source_path)
    if expected_lines <= 0:
        print(f"ERROR: Source file has no non-empty lines: {source_path}")
        return 2

    expected_count = expected_lines
    actual_count = len(selected_language_files)
    reference_count = len(reference_names)
    required_total = expected_count * reference_count

    if actual_count != required_total:
        issues.append(
            ValidationIssue(
                path=submission_dir,
                message=(
                    f"File count mismatch for language '{required_language}': "
                    f"expected {required_total} ({expected_count} lines x {reference_count} reference audios), "
                    f"found {actual_count}"
                ),
            )
        )

    issues.extend(
        validate_per_reference_coverage(
            selected_language_files,
            expected_lines,
            required_language,
            reference_names,
        )
    )

    print("=" * 72)
    print("IWSLT Submission Filename Validation")
    print("=" * 72)
    print(f"Submission dir: {submission_dir}")
    print(f"Language code : {required_language}")
    print(f"Files scanned  : {len(audio_files)}")
    print(f"Non-WAV files : {non_wav_count}")
    print(f"Valid {required_language} files: {len(selected_language_files)}")
    print(f"Reference audios: {reference_count}")
    print(f"Expected count: {required_total} ({expected_count} lines x {reference_count} references)")
    print(f"Source lines   : {expected_lines}")
    print(f"Source file    : {source_path}")
    print(f"Reference dir  : {reference_dir}")
    print("-" * 72)

    if issues:
        print(f"FAILED: {len(issues)} issue(s) found")
        print("How to fix:")
        print(
            "  - Use filename format {language}_{lineNumber}_{originalName}.wav"
        )
        print(
            f"  - Use language '{required_language}' in every filename prefix"
        )
        print("  - Keep line numbers zero-padded to at least --line-width digits")
        print("  - For each reference audio, include all source line IDs")
        for issue in issues[:100]:
            print(f"  - {issue.path}: {issue.message}")
        if len(issues) > 100:
            print(f"  ... and {len(issues) - 100} more issue(s)")
        return 1

    print("PASSED: filenames, language code, and file count are valid")
    return 0


if __name__ == "__main__":
    sys.exit(main())
