import os
import sys
import csv
import argparse
import hashlib
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict

class UserDeclinedError(Exception):
    """Raised when the user declines a confirmation prompt."""
    pass

def ask_yes_no(prompt: str) -> bool:
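    """Prompt until the user answers 'y' or 'n'; return True for yes."""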
    while True:
        answer = input(f"{prompt}? [y/n]: ").strip().lower()
        if answer in ("y", "n"):
            return answer == "y"
        print("Please answer with y or n.")

def print_header(title: str) -> None:
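    """Print a title framed by '=' rules sized to the title."""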
    print()
    print("=" * len(title))
    print(title)
    print("=" * len(title))

def count_files(root: Path, report_file: Path | None = None) -> int:
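    """Pre-count eligible files so the scan's progress bar has an accurate total.

    Applies the same skip rules as the scan: symlinks, non-regular files,
    and the report file itself are excluded.
    """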
    total_files = 0
    for current_root, _dirs, files in os.walk(root):
        for name in files:
            path = Path(current_root) / name
            try:
                if report_file is not None and path.resolve() == report_file.resolve():
                    continue
            except Exception:
                pass
            try:
                if path.is_symlink():
                    continue
                if not path.is_file():
                    continue
                total_files += 1
            except Exception:
                continue
    return total_files

def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
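    """Hash a file in fixed-size chunks so large files never load fully into memory.

    (On Python 3.11+, hashlib.file_digest(f, "sha256") is an equivalent built-in.)
    """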
    h = hashlib.sha256()
    with path.open("rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            h.update(chunk)
    return h.hexdigest()

def find_duplicate_groups(
    root: Path,
    total_files: int,
    report_file: Path | None = None,
) -> tuple[dict[str, list[Path]], int, int]:
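    """Find duplicates in two passes: group files by size, then hash only
    same-size candidates, since files of different sizes cannot be identical.

    Returns (duplicate groups keyed by sha256 digest, files scanned, bytes scanned).
    """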
    size_groups: dict[int, list[Path]] = defaultdict(list)
    hash_groups: dict[str, list[Path]] = defaultdict(list)

    total_files_scanned = 0
    total_bytes_scanned = 0

    # Step 1: group files by size
    with tqdm(total=total_files, desc="Scanning files", unit="file") as pbar:
        for current_root, _dirs, files in os.walk(root):
            for name in files:
                path = Path(current_root) / name
                try:
                    if report_file is not None and path.resolve() == report_file.resolve():
                        continue
                except Exception:
                    pass
                try:
                    if path.is_symlink():
                        continue
                    if not path.is_file():
                        continue
                    stat = path.stat()
                    size_groups[stat.st_size].append(path)
                    total_files_scanned += 1
                    total_bytes_scanned += stat.st_size
                except Exception as e:
                    print(f"Skipped: {path} ({e})")
                finally:
                    pbar.update(1)

    # Step 2: hash only files with identical sizes
    for size, paths in size_groups.items():
        if len(paths) < 2:
            continue
        for path in paths:
            try:
                digest = sha256_file(path)
                hash_groups[digest].append(path)
            except Exception as e:
                print(f"Hash failed: {path} ({e})")

    # Step 3: build duplicate groups
    duplicate_groups: dict[str, list[Path]] = {}
    for digest, paths in hash_groups.items():
        if len(paths) >= 2:
            duplicate_groups[digest] = sorted(paths)
    return duplicate_groups, total_files_scanned, total_bytes_scanned

def print_duplicate_groups(duplicate_groups: dict[str, list[Path]]) -> None:
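    """Print each duplicate group, marking the first (sorted) path as the keeper."""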
    print_header("Duplicate groups found")
    if not duplicate_groups:
        print("No duplicate files found.")
        return
    for index, (digest, paths) in enumerate(sorted(duplicate_groups.items()), start=1):
        print(f"Group {index}")
        for i, path in enumerate(paths):
            marker = " [keep]" if i == 0 else ""
            print(f"{path}{marker}")
        print()

def delete_duplicates(
    duplicate_groups: dict[str, list[Path]],
    report_file: Path | None = None,
) -> tuple[list[list[str]], int, int, int]:
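    """Delete every file in each group except the first, recording one CSV row
    per file describing whether it was kept, deleted, skipped, or failed.

    Returns (csv_rows, deleted_count, failed_count, reclaimed_bytes).
    """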
    csv_rows: list[list[str]] = []
    deleted_count = 0
    failed_count = 0
    reclaimed_bytes = 0

    for group_id, (digest, paths) in enumerate(sorted(duplicate_groups.items()), start=1):
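        # Paths were sorted when the group was built, so the first one is kept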
        keep_path = paths[0]
        csv_rows.append([str(group_id), str(keep_path), digest, "kept"])

        for duplicate_path in paths[1:]:
            try:
                if report_file is not None and duplicate_path.resolve() == report_file.resolve():
                    csv_rows.append([str(group_id), str(duplicate_path), digest, "skipped_report_file"])
                    continue
                size = duplicate_path.stat().st_size
                duplicate_path.unlink()
                deleted_count += 1
                reclaimed_bytes += size
                csv_rows.append([str(group_id), str(duplicate_path), digest, "deleted"])
                print(f"Deleted: {duplicate_path}")
            except Exception as e:
                failed_count += 1
                csv_rows.append([str(group_id), str(duplicate_path), digest, f"delete_failed: {e}"])
                print(f"Failed:  {duplicate_path} ({e})")
    return csv_rows, deleted_count, failed_count, reclaimed_bytes

def write_csv_report(report_file: Path, rows: list[list[str]]) -> None:
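    """Write the per-file action report: one header row plus one row per file."""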
    with report_file.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["group_id", "path", "hash", "action"])
        writer.writerows(rows)

def main() -> int:
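    """Entry point: parse arguments, scan, summarize, confirm, delete, and report."""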
    try:
        # Parse command-line arguments
        parser = argparse.ArgumentParser(description="Scan a directory for duplicate files and optionally delete them.")
        parser.add_argument("--root", type=Path, required=True, help="Root directory to scan for duplicates")
        parser.add_argument("--report-file", type=Path, help="Optional CSV report file to write")
        args = parser.parse_args()
        root: Path = args.root
        report_file: Path | None = args.report_file
        if not root.exists():
            raise FileNotFoundError(f"Root directory not found: {root}")
        if not root.is_dir():
            raise NotADirectoryError(f"Path is not a directory: {root}")

        # Scan for duplicate files
        print_header("Scanning files")
        total_files = count_files(root=root, report_file=report_file)
        print(f"Eligible files found   : {total_files}")
        duplicate_groups, total_files_scanned, total_bytes_scanned = find_duplicate_groups(
            root=root,
            total_files=total_files,
            report_file=report_file,
        )
        print(f"Root directory scanned : {root}")
        print(f"Total files scanned    : {total_files_scanned}")
        print(f"Total size scanned     : {total_bytes_scanned} bytes")
        print(f"Total size scanned     : {total_bytes_scanned / (1024 * 1024):.2f} MiB")

        # Display duplicate groups
        print_duplicate_groups(duplicate_groups)
        if not duplicate_groups:
            if report_file:
                write_csv_report(report_file, [])
                print()
                print(f"CSV report written to: {report_file}")
            return 0

        # Compute reclaimable size across duplicate groups
        total_groups = len(duplicate_groups)
        total_duplicates = sum(len(paths) - 1 for paths in duplicate_groups.values())
        total_reclaimable = sum(
            path.stat().st_size
            for paths in duplicate_groups.values()
            for path in paths[1:]
            if path.exists()
        )
        print_header("Summary before deletion")
        print(f"Duplicate groups       : {total_groups}")
        print(f"Files that can delete  : {total_duplicates}")
        print(f"Reclaimable bytes      : {total_reclaimable}")
        print(f"Reclaimable MiB        : {total_reclaimable / (1024 * 1024):.2f}")

        # Ask the user to confirm before deleting
        if not ask_yes_no("Do you want to continue"):
            raise UserDeclinedError("Operation cancelled by user.")

        # Delete duplicate files
        print_header("Deleting duplicate files")
        csv_rows, deleted_count, failed_count, reclaimed_bytes = delete_duplicates(duplicate_groups=duplicate_groups, report_file=report_file)
        if report_file:
            write_csv_report(report_file, csv_rows)
            print()
            print(f"CSV report written to: {report_file}")

        print_header("Deletion summary")
        print(f"Deleted files         : {deleted_count}")
        print(f"Failed deletions      : {failed_count}")
        print(f"Reclaimed bytes       : {reclaimed_bytes}")
        print(f"Reclaimed MiB         : {reclaimed_bytes / (1024 * 1024):.2f}")
        return 0

    except KeyboardInterrupt:
        print("\nOperation interrupted by user.")
        return 130
    except UserDeclinedError as e:
        print(e)
        return 0
    except Exception as e:
        print(f"Unexpected error: {e}")
        return 1

if __name__ == "__main__":
    sys.exit(main())
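
# Example invocation (script name and paths are illustrative):
#   python dedupe.py --root ~/Downloads --report-file dedupe_report.csv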