#!/usr/bin/env python3
"""Unified GUI voucher scraper.

Features:
- Pull one URL
- Pull by search term
- Pull all accounts
- Run 1..25 workers with live monitoring
- Edit voucher_scanned_accounts.json
- Save voucher results directly into the ALL folder as markdown files

Runs on Linux/Windows with Python + Selenium + Firefox/Geckodriver.
"""

from __future__ import annotations

import json
import queue
import threading
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Sequence
from urllib.parse import urljoin

import tkinter as tk
from tkinter import filedialog, messagebox, ttk
from tkinter.scrolledtext import ScrolledText

from selenium.webdriver.common.by import By

try:
    from voucher_scraper_core import (
        _safe_fs_component,
        build_driver,
        drop_hidden_product_columns,
        extract_all_datatable_rows,
        extract_program_summary_block,
        filter_accounts_term,
        find_table_by_headers,
        get_account_name_from_summary_span,
        get_account_program_links,
        get_all_accounts_from_results,
        login,
        parse_company_information_from_text,
    )
except ModuleNotFoundError:
    from Voucher_List_Folder.voucher_scraper_core import (
        _safe_fs_component,
        build_driver,
        drop_hidden_product_columns,
        extract_all_datatable_rows,
        extract_program_summary_block,
        filter_accounts_term,
        find_table_by_headers,
        get_account_name_from_summary_span,
        get_account_program_links,
        get_all_accounts_from_results,
        login,
        parse_company_information_from_text,
    )


# Directory containing this script; the JSON file and output folder live beside it.
BASE_DIR = Path(__file__).resolve().parent
# Account list shown (and editable) in the GUI's JSON pane.
JSON_PATH = BASE_DIR / "voucher_scanned_accounts.json"
# Name of the top-level output folder. Record paths handed to FileWriter start
# with this label and are resolved against BASE_DIR, so OUTPUT_DIR is derived
# from it rather than spelled out a second time.
OUTPUT_LABEL = "ALL"
OUTPUT_DIR = BASE_DIR / OUTPUT_LABEL
# Accounts list page opened (after login) to search for accounts.
LOGIN_URL = "https://portal.redwingforbusiness.com/RWS_AccountsListPage?tab=account"


@dataclass
class VoucherRecord:
    """One rendered markdown file queued for ``FileWriter``.

    Fields are positional in the order ``rel_path, folder, name, content``.
    """

    # "/"-separated output path; FileWriter resolves it against the script directory.
    rel_path: str
    # "/"-separated folder part of ``rel_path``.
    folder: str
    # File name part of ``rel_path``.
    name: str
    # Markdown text written to the file.
    content: str


def _csv_cell(value: object) -> str:
    """Flatten one header/cell value for the comma-joined product listing.

    ``None`` becomes ``""``. Line breaks become spaces so a value never splits
    its line. Commas become ``;`` because the listing is joined with bare
    commas and is not quoted.
    """
    text = "" if value is None else str(value)
    return text.replace("\r\n", " ").replace("\r", " ").replace("\n", " ").replace(",", ";")


def _code_fence(text: str) -> str:
    """Return a backtick fence longer than any backtick run inside *text*.

    A fenced block is only closed by a backtick fence at least as long as the
    opening one, so page text containing ``` cannot end the block early.
    """
    longest = current = 0
    for ch in text:
        current = current + 1 if ch == "`" else 0
        if current > longest:
            longest = current
    return "`" * max(3, longest + 1)


def _build_program_markdown(driver, account_url: str, program_url: str, program_name: str) -> str:
    """Open *program_url* in *driver* and render the program page as markdown.

    The document contains a header with the URLs and a UTC timestamp, the
    program summary (falling back to the whole page text), and the product
    table as comma-joined lines.

    Args:
        driver: Selenium WebDriver that is already logged in.
        account_url: Shown in the header only; may be empty.
        program_url: Page that is loaded and scraped.
        program_name: Used as the top-level heading.

    Returns:
        The markdown text, ending with a newline.
    """
    driver.get(program_url)
    time.sleep(0.6)  # fixed settle delay before reading the page

    summary = extract_program_summary_block(driver)
    if not summary:
        # Best effort: fall back to the raw page text when no summary block is found.
        try:
            summary = driver.find_element(By.TAG_NAME, "body").text
        except Exception:
            summary = ""
    summary = summary or ""

    # Prefer the most specific header match, then progressively looser ones.
    table_id = find_table_by_headers(driver, ["Style", "Product Name"])
    if not table_id:
        table_id = find_table_by_headers(driver, ["Style"]) or find_table_by_headers(driver, ["Product"])

    headers: List[str] = []
    rows: List[List[str]] = []
    if table_id:
        headers, rows = extract_all_datatable_rows(driver, table_id)
        headers, rows = drop_hidden_product_columns(headers, rows)

    summary_fence = _code_fence(summary)
    lines: List[str] = [
        f"# {program_name}",
        "",
        f"- **Account URL**: {account_url}",
        f"- **Program URL**: {program_url}",
        f"- **Scraped UTC**: {datetime.now(timezone.utc).isoformat(timespec='seconds')}",
        "",
        "## Program Summary",
        f"{summary_fence}text",
        summary,
        summary_fence,
        "",
        f"## Product List (all rows) - {len(rows)} rows",
    ]

    if headers and rows:
        width = len(headers)
        csv_lines = [",".join(_csv_cell(h) for h in headers)]
        for row in rows:
            # Truncate extra cells and pad missing ones so every line matches the header width.
            cells = [_csv_cell(c) for c in row[:width]]
            cells.extend([""] * (width - len(cells)))
            csv_lines.append(",".join(cells))
        csv_fence = _code_fence("\n".join(csv_lines))
        lines.append(f"{csv_fence}csv")
        lines.extend(csv_lines)
        lines.append(csv_fence)
    else:
        lines.append("_No product table found (or no rows)._")

    return "\n".join(lines) + "\n"


class FileWriter(threading.Thread):
    """Daemon thread that writes queued voucher records to disk.

    Records are taken from *in_q* until a ``None`` sentinel arrives. Each
    record's "/"-separated ``rel_path`` is resolved against *base_dir*
    (``BASE_DIR`` when omitted). *output_root* is only created up front; it is
    not used to resolve paths.

    Messages posted to *log_q*:
      * ``("saved", saved_count, rel_path)`` after each successful write;
      * ``("error", message, "")`` when a record cannot be written;
      * ``("writer_done", saved_count, "")`` once, when the thread exits.

    A failed write is reported and skipped rather than ending the thread.
    Producers may be blocked on a bounded queue, so the writer must keep
    draining it or they would wait forever.
    """

    def __init__(
        self,
        output_root: Path,
        in_q: "queue.Queue[Optional[VoucherRecord]]",
        log_q: "queue.Queue[tuple]",
        base_dir: Optional[Path] = None,
    ) -> None:
        super().__init__(daemon=True)
        self.output_root = output_root
        self.in_q = in_q
        self.log_q = log_q
        self.saved = 0
        # Root for resolving record paths; None means the module-level BASE_DIR.
        self.base_dir = base_dir

    def _resolve(self, rel_path: str) -> Path:
        """Map a "/"-separated relative path onto the base directory.

        Raises:
            ValueError: if the path is empty or has "." / ".." segments, which
                could otherwise resolve outside the base directory.
        """
        parts = [part for part in rel_path.split("/") if part]
        if not parts or any(part in (".", "..") for part in parts):
            raise ValueError(f"unsafe output path: {rel_path!r}")
        base = self.base_dir if self.base_dir is not None else BASE_DIR
        return base.joinpath(*parts)

    def _write(self, rec: "VoucherRecord") -> None:
        """Write one record as UTF-8, creating parent folders as needed."""
        out_path = self._resolve(rec.rel_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(rec.content or "", encoding="utf-8")

    def run(self) -> None:
        try:
            try:
                self.output_root.mkdir(parents=True, exist_ok=True)
            except OSError as exc:
                # Keep going: individual writes will report their own errors.
                self.log_q.put(("error", f"Cannot create output folder {self.output_root}: {exc}", ""))

            while True:
                rec = self.in_q.get()
                if rec is None:
                    self.in_q.task_done()
                    break
                try:
                    self._write(rec)
                except Exception as exc:
                    # Report and skip this record; stopping here would deadlock producers.
                    self.log_q.put(("error", f"Failed to write {getattr(rec, 'rel_path', rec)}: {exc}", ""))
                else:
                    self.saved += 1
                    self.log_q.put(("saved", self.saved, rec.rel_path))
                finally:
                    self.in_q.task_done()
        finally:
            self.log_q.put(("writer_done", self.saved, ""))


class VoucherScraperGUI:
    """Tk window that runs the voucher scrape and edits the scanned-accounts JSON.

    Threads used by a run (all created from ``start``):

    * the Tk main thread owns every widget and Tk variable;
    * a dispatcher thread collects the task list, then starts and joins the
      worker threads;
    * 1..25 worker threads, each with its own Selenium driver, pull tasks from
      ``task_q``;
    * one ``FileWriter`` thread drains ``write_q`` to disk.

    Background threads never touch widgets or Tk variables. They post
    ``(kind, a, b)`` tuples to ``log_q``, which ``_pump_logs`` applies on the
    main thread every 150 ms.
    """

    def __init__(self, root: tk.Tk) -> None:
        self.root = root
        self.root.title("RW Voucher File Scraper")
        self.root.geometry("1300x860")

        self.log_q: "queue.Queue[tuple]" = queue.Queue()
        # Bounded (500 records): producers block while the writer catches up.
        self.write_q: "queue.Queue[Optional[VoucherRecord]]" = queue.Queue(maxsize=500)
        self.task_q: "queue.Queue[tuple]" = queue.Queue()
        self.stop_event = threading.Event()
        self.writer: Optional[FileWriter] = None
        self.workers: List[threading.Thread] = []
        self.dispatcher: Optional[threading.Thread] = None

        self.total_tasks = 0
        self.completed_tasks = 0
        self.failed_tasks = 0
        self.saved_rows = 0
        self.worker_states: Dict[int, str] = {}
        # Serialises updates to completed_tasks/failed_tasks from worker threads.
        self._counter_lock = threading.Lock()
        # Snapshot of Tk settings taken on the main thread by start(), so that
        # worker and dispatcher threads never call into Tk.
        self.run_headless = True
        self.run_term = ""

        self._build_ui()
        self._load_json_file()
        self._pump_logs()

    # ------------------------------------------------------------------ UI

    def _build_ui(self) -> None:
        """Create the control bar, worker/log pane and JSON editor pane."""
        top = ttk.Frame(self.root, padding=8)
        top.pack(fill="x")

        self.mode_var = tk.StringVar(value="url")
        ttk.Label(top, text="Mode:").grid(row=0, column=0, sticky="w")
        ttk.Radiobutton(top, text="Pull one URL", variable=self.mode_var, value="url", command=self._refresh_mode).grid(row=0, column=1, sticky="w")
        ttk.Radiobutton(top, text="Pull search term", variable=self.mode_var, value="search", command=self._refresh_mode).grid(row=0, column=2, sticky="w")
        ttk.Radiobutton(top, text="Pull all", variable=self.mode_var, value="all", command=self._refresh_mode).grid(row=0, column=3, sticky="w")

        self.url_var = tk.StringVar()
        self.term_var = tk.StringVar(value="aa")
        self.headless_var = tk.BooleanVar(value=True)
        self.workers_var = tk.IntVar(value=5)

        ttk.Label(top, text="URL:").grid(row=1, column=0, sticky="w")
        self.url_entry = ttk.Entry(top, textvariable=self.url_var, width=120)
        self.url_entry.grid(row=1, column=1, columnspan=5, sticky="ew", padx=(4, 0))

        ttk.Label(top, text="Search:").grid(row=2, column=0, sticky="w")
        self.term_entry = ttk.Entry(top, textvariable=self.term_var, width=40)
        self.term_entry.grid(row=2, column=1, sticky="w", padx=(4, 0))

        ttk.Checkbutton(top, text="Headless", variable=self.headless_var).grid(row=2, column=2, sticky="w", padx=(16, 0))
        ttk.Label(top, text="Workers (1-25):").grid(row=2, column=3, sticky="e")
        self.worker_spin = ttk.Spinbox(top, from_=1, to=25, textvariable=self.workers_var, width=5)
        self.worker_spin.grid(row=2, column=4, sticky="w", padx=(4, 0))

        self.start_btn = ttk.Button(top, text="Start", command=self.start)
        self.start_btn.grid(row=3, column=1, sticky="w", pady=(8, 0))
        self.stop_btn = ttk.Button(top, text="Stop", command=self.stop, state="disabled")
        self.stop_btn.grid(row=3, column=2, sticky="w", pady=(8, 0), padx=(8, 0))

        self.status_var = tk.StringVar(value=f"Output: {OUTPUT_DIR}")
        ttk.Label(top, textvariable=self.status_var).grid(row=3, column=3, columnspan=3, sticky="w", pady=(8, 0), padx=(12, 0))

        for i in range(6):
            top.columnconfigure(i, weight=1 if i in {1, 5} else 0)

        mid = ttk.Panedwindow(self.root, orient="horizontal")
        mid.pack(fill="both", expand=True, padx=8, pady=8)

        left = ttk.Frame(mid, padding=6)
        right = ttk.Frame(mid, padding=6)
        mid.add(left, weight=1)
        mid.add(right, weight=1)

        ttk.Label(left, text="Workers").pack(anchor="w")
        self.worker_tree = ttk.Treeview(left, columns=("status",), show="headings", height=16)
        self.worker_tree.heading("status", text="Status")
        self.worker_tree.pack(fill="x")

        self.progress_var = tk.StringVar(value="Idle")
        ttk.Label(left, textvariable=self.progress_var).pack(anchor="w", pady=(8, 0))

        ttk.Label(left, text="Log").pack(anchor="w", pady=(8, 0))
        self.log_text = ScrolledText(left, height=20, wrap="word")
        self.log_text.pack(fill="both", expand=True)

        ttk.Label(right, text="voucher_scanned_accounts.json (editable)").pack(anchor="w")
        btns = ttk.Frame(right)
        btns.pack(fill="x", pady=(4, 4))
        ttk.Button(btns, text="Load JSON", command=self._load_json_file).pack(side="left")
        ttk.Button(btns, text="Save JSON", command=self._save_json_file).pack(side="left", padx=(8, 0))
        ttk.Button(btns, text="Format JSON", command=self._format_json).pack(side="left", padx=(8, 0))
        ttk.Button(btns, text="Open JSON As...", command=self._open_json_other).pack(side="left", padx=(8, 0))

        self.json_text = ScrolledText(right, wrap="none")
        self.json_text.pack(fill="both", expand=True)

        self._refresh_mode()

    def _refresh_mode(self) -> None:
        """Enable only the input field that the selected mode uses."""
        mode = self.mode_var.get().strip().lower()
        if mode == "url":
            self.url_entry.configure(state="normal")
            self.term_entry.configure(state="disabled")
        elif mode == "search":
            self.url_entry.configure(state="disabled")
            self.term_entry.configure(state="normal")
        else:
            self.url_entry.configure(state="disabled")
            self.term_entry.configure(state="disabled")

    def _log(self, msg: str) -> None:
        """Append a timestamped line to the log pane (main thread only)."""
        ts = datetime.now().strftime("%H:%M:%S")
        self.log_text.insert("end", f"[{ts}] {msg}\n")
        self.log_text.see("end")

    def _set_running(self, running: bool) -> None:
        """Toggle the Start/Stop buttons for the running or idle state."""
        self.start_btn.configure(state="disabled" if running else "normal")
        self.stop_btn.configure(state="normal" if running else "disabled")

    # ---------------------------------------------------------------- JSON

    def _load_json_file(self) -> None:
        """Load JSON_PATH into the editor, or "{}" when the file does not exist."""
        if not JSON_PATH.exists():
            self.json_text.delete("1.0", "end")
            self.json_text.insert("1.0", "{}\n")
            return
        txt = JSON_PATH.read_text(encoding="utf-8", errors="ignore")
        self.json_text.delete("1.0", "end")
        self.json_text.insert("1.0", txt)

    def _save_json_file(self) -> None:
        """Validate the editor text as JSON and write it, re-indented, to JSON_PATH."""
        raw = self.json_text.get("1.0", "end").strip() or "{}"
        try:
            obj = json.loads(raw)
        except Exception as e:
            messagebox.showerror("Invalid JSON", str(e))
            return
        try:
            JSON_PATH.write_text(json.dumps(obj, indent=2, sort_keys=False) + "\n", encoding="utf-8")
        except OSError as e:
            messagebox.showerror("Save failed", str(e))
            return
        self._log(f"Saved JSON: {JSON_PATH}")

    def _format_json(self) -> None:
        """Re-indent the editor contents in place if they parse as JSON."""
        raw = self.json_text.get("1.0", "end").strip() or "{}"
        try:
            obj = json.loads(raw)
        except Exception as e:
            messagebox.showerror("Invalid JSON", str(e))
            return
        self.json_text.delete("1.0", "end")
        self.json_text.insert("1.0", json.dumps(obj, indent=2, sort_keys=False) + "\n")

    def _open_json_other(self) -> None:
        """Load a user-chosen file into the editor; Save still writes JSON_PATH."""
        path = filedialog.askopenfilename(initialdir=str(BASE_DIR), filetypes=[("JSON", "*.json"), ("All", "*")])
        if not path:
            return
        try:
            txt = Path(path).read_text(encoding="utf-8", errors="ignore")
        except OSError as e:
            messagebox.showerror("Open failed", str(e))
            return
        self.json_text.delete("1.0", "end")
        self.json_text.insert("1.0", txt)

    # ------------------------------------------------------------- helpers

    @staticmethod
    def _clamp_workers(requested: int, limit: int = 25) -> int:
        """Clamp a requested worker count into ``1..limit``."""
        return max(1, min(limit, requested))

    @staticmethod
    def _discard_pending(q: "queue.Queue") -> int:
        """Remove every queued item (marking each done) and return how many were removed."""
        dropped = 0
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return dropped
            q.task_done()
            dropped += 1

    @staticmethod
    def _code_fence(text: str) -> str:
        """Return a backtick fence longer than any backtick run inside *text*."""
        longest = current = 0
        for ch in text:
            current = current + 1 if ch == "`" else 0
            if current > longest:
                longest = current
        return "`" * max(3, longest + 1)

    def _read_worker_count(self) -> Optional[int]:
        """Return the clamped spinbox value, or None (after an error dialog) if it is not an integer."""
        try:
            requested = int(self.workers_var.get())
        except (tk.TclError, ValueError):
            # IntVar.get() raises TclError when the spinbox text is not a number.
            messagebox.showerror("Invalid workers", "Workers must be a whole number between 1 and 25.")
            return None
        return self._clamp_workers(requested)

    def _record_result(self, ok: bool) -> None:
        """Count one finished task (thread-safe) and post a progress message."""
        with self._counter_lock:
            if ok:
                self.completed_tasks += 1
            else:
                self.failed_tasks += 1
            snapshot = ("progress", self.completed_tasks, self.failed_tasks)
        self.log_q.put(snapshot)

    def _put_for_writer(self, item: "Optional[VoucherRecord]") -> bool:
        """Put *item* on write_q, giving up if no live writer is draining it.

        Returns:
            True when queued; False when the queue stayed full and the writer
            thread is missing or has exited (a blocking put would never return).
        """
        while True:
            try:
                self.write_q.put(item, timeout=0.5)
                return True
            except queue.Full:
                if self.writer is None or not self.writer.is_alive():
                    return False

    # ---------------------------------------------------------- task setup

    def _queue_account_tasks(self, term: Optional[str]) -> List[tuple]:
        """Log in with a throwaway driver and list account tasks matching *term*.

        Returns:
            ``("account", absolute_url, label)`` tuples.
        """
        driver = build_driver(headless=self.run_headless)
        try:
            login(driver)
            driver.get(LOGIN_URL)
            time.sleep(0.8)
            filter_accounts_term(driver, term or "")
            accounts = get_all_accounts_from_results(driver)
            return [("account", urljoin(driver.current_url, a.href), a.text or "") for a, _row in accounts if a.href]
        finally:
            driver.quit()

    def _pull_one_url_task(self, url: str) -> List[tuple]:
        """Wrap a single URL as a one-element task list."""
        return [("url", url.strip(), "")]

    # ------------------------------------------------------------ run flow

    def start(self) -> None:
        """Validate the form, reset run state, then launch the writer and dispatcher threads."""
        if self.dispatcher and self.dispatcher.is_alive():
            return

        workers = self._read_worker_count()
        if workers is None:
            return
        mode = self.mode_var.get().strip().lower()

        if mode == "url":
            url = self.url_var.get().strip()
            if not url:
                messagebox.showerror("Missing URL", "Enter one URL.")
                return
        else:
            url = ""

        # Read Tk state here, on the main thread; background threads use these copies.
        self.run_headless = bool(self.headless_var.get())
        self.run_term = self.term_var.get().strip()

        # Tasks abandoned by a previous stopped or failed run must not leak into this one.
        self._discard_pending(self.task_q)

        self.total_tasks = 0
        with self._counter_lock:
            self.completed_tasks = 0
            self.failed_tasks = 0
        self.saved_rows = 0
        self.status_var.set(f"Output: {OUTPUT_DIR}")
        self.worker_states.clear()
        self.worker_tree.delete(*self.worker_tree.get_children())

        for wid in range(1, workers + 1):
            self.worker_tree.insert("", "end", iid=str(wid), values=("idle",))

        self.stop_event.clear()
        self._set_running(True)
        self.progress_var.set("Preparing...")

        self.writer = FileWriter(OUTPUT_DIR, self.write_q, self.log_q)
        self.writer.start()

        self.dispatcher = threading.Thread(
            target=self._dispatch, args=(mode, url, workers, self.run_term), daemon=True
        )
        self.dispatcher.start()

    def stop(self) -> None:
        """Ask workers to finish their current task and stop taking new ones."""
        self.stop_event.set()
        self._log("Stop requested.")

    def _collect_tasks(self, mode: str, url: str, term: Optional[str]) -> List[tuple]:
        """Build the task list for *mode* ("url", "search", anything else = all accounts)."""
        if mode == "url":
            return self._pull_one_url_task(url)
        if mode == "search":
            return self._queue_account_tasks(term if term is not None else self.run_term)
        return self._queue_account_tasks("")

    def _run_workers(self, tasks: List[tuple], workers: int) -> None:
        """Queue *tasks*, run up to *workers* worker threads, and wait for them."""
        self.total_tasks = len(tasks)
        self.log_q.put(("info", f"Queued {self.total_tasks} tasks", ""))

        for task in tasks:
            self.task_q.put(task)

        # Each worker launches and logs in its own browser; skip the ones that could never get a task.
        count = min(workers, len(tasks))
        self.workers = [
            threading.Thread(target=self._worker_loop, args=(wid,), daemon=True)
            for wid in range(1, count + 1)
        ]
        for thread in self.workers:
            thread.start()
        for thread in self.workers:
            thread.join()

        # Anything left was never attempted (stop requested, or every worker failed to start).
        skipped = self._discard_pending(self.task_q)
        if skipped:
            self.log_q.put(("info", f"{skipped} task(s) were not run", ""))

    def _dispatch(self, mode: str, url: str, workers: int, term: Optional[str] = None) -> None:
        """Dispatcher thread body: collect tasks, run workers, then shut the writer down.

        "done" is always posted, even on error, so the UI leaves the running state.
        """
        try:
            tasks = self._collect_tasks(mode, url, term)
            if tasks and self.stop_event.is_set():
                self.log_q.put(("info", "Stopped before any task was started.", ""))
            elif tasks:
                self._run_workers(tasks, workers)
        except Exception as e:
            self.log_q.put(("error", f"Dispatcher error: {e}", ""))
        finally:
            # Exactly one sentinel per run: a leftover None would make the next run's writer exit at once.
            self._put_for_writer(None)
            if self.writer is not None:
                self.writer.join()
            self.log_q.put(("done", self.completed_tasks, self.failed_tasks))

    def _worker_loop(self, wid: int) -> None:
        """Worker thread body: own one driver and process tasks until the queue is empty or stop is set."""
        self.log_q.put(("worker", wid, "starting"))
        driver = None
        try:
            driver = build_driver(headless=self.run_headless)
            login(driver)
            self.log_q.put(("worker", wid, "ready"))

            while not self.stop_event.is_set():
                try:
                    mode, target, _label = self.task_q.get_nowait()
                except queue.Empty:
                    break

                self.log_q.put(("worker", wid, f"running: {target[:80]}"))
                try:
                    if mode == "account":
                        self._scrape_account(driver, target)
                    else:
                        self._scrape_url(driver, target)
                except Exception as e:
                    self.log_q.put(("error", f"Worker {wid} failed {target}: {e}", ""))
                    self._record_result(ok=False)
                else:
                    self._record_result(ok=True)
                finally:
                    self.task_q.task_done()

            self.log_q.put(("worker", wid, "idle"))
        except Exception as e:
            self.log_q.put(("error", f"Worker {wid} setup failed: {e}", ""))
            self.log_q.put(("worker", wid, "setup failed"))
        finally:
            if driver is not None:
                try:
                    driver.quit()
                except Exception:
                    pass  # best-effort browser shutdown; the run result is already recorded

    # ------------------------------------------------------------ scraping

    def _build_account_folder_label(self, company_name: str, account_number: str, parent_account: str) -> str:
        """Join company name (default "Account") with any non-empty account numbers using "_"."""
        parts = [company_name or "Account", account_number, parent_account]
        return _safe_fs_component("_".join(part for part in parts if part), max_len=220)

    def _enqueue_voucher(self, folder: str, program_name: str, content: str) -> None:
        """Queue one markdown file at ``OUTPUT_LABEL/<folder>/<program>.md`` for the writer.

        Raises:
            RuntimeError: if the file writer is no longer running.
        """
        name = f"{_safe_fs_component(program_name, max_len=200)}.md"
        rel_folder = f"{OUTPUT_LABEL}/{folder}" if folder else OUTPUT_LABEL
        record = VoucherRecord(rel_path=f"{rel_folder}/{name}", folder=rel_folder, name=name, content=content)
        if not self._put_for_writer(record):
            raise RuntimeError(f"File writer is not running; could not save {record.rel_path}")

    def _scrape_account(self, driver, account_url: str) -> None:
        """Save each active program of an account, or the page text if none are active."""
        driver.get(account_url)
        time.sleep(0.6)

        try:
            body_text = driver.find_element(By.TAG_NAME, "body").text
        except Exception:
            body_text = ""
        body_text = body_text or ""

        info = parse_company_information_from_text(body_text)
        company_name = get_account_name_from_summary_span(driver) or info.get("company_name", "Account")
        account_number = info.get("account_number", "")
        parent_account = info.get("parent_account", "")
        folder = self._build_account_folder_label(company_name, account_number, parent_account)

        programs = get_account_program_links(driver, only_active=True)
        active = [p for p in programs if bool(p.get("active")) and str(p.get("href") or "").strip()]

        if not active:
            fence = self._code_fence(body_text)
            lines = [
                "# Account Summary",
                "",
                f"- **Account URL**: {account_url}",
                f"- **Scraped UTC**: {datetime.now(timezone.utc).isoformat(timespec='seconds')}",
                "",
                "## Account Page Text",
                f"{fence}text",
                body_text,
                fence,
                "",
            ]
            self._enqueue_voucher(folder, "Account Summary", "\n".join(lines))
            return

        for p in active:
            if self.stop_event.is_set():
                return
            program_name = str(p.get("text") or "Program").strip() or "Program"
            program_url = urljoin(account_url, str(p.get("href") or "").strip())
            content = _build_program_markdown(driver, account_url, program_url, program_name)
            self._enqueue_voucher(folder, program_name, content)
            driver.get(account_url)
            time.sleep(0.4)

    def _scrape_url(self, driver, target_url: str) -> None:
        """Scrape one URL: account pages go to ``_scrape_account``, others are saved under Direct_URL."""
        driver.get(target_url)
        time.sleep(0.6)
        cur = driver.current_url

        if "AccountSummary" in cur:
            self._scrape_account(driver, cur)
            return

        # Best effort: use the first non-empty heading as the program name.
        program_name = "Voucher"
        try:
            for el in driver.find_elements(By.CSS_SELECTOR, "h1,h2,legend"):
                txt = (el.text or "").strip()
                if txt:
                    program_name = txt
                    break
        except Exception:
            pass

        content = _build_program_markdown(driver, "", cur, program_name)
        self._enqueue_voucher("Direct_URL", program_name, content)

    # ----------------------------------------------------------- log pump

    def _pump_logs(self) -> None:
        """Apply every pending ``log_q`` message to the UI, then reschedule in 150 ms."""
        try:
            while True:
                msg = self.log_q.get_nowait()
                kind = msg[0]

                if kind == "info":
                    self._log(str(msg[1]))
                elif kind == "error":
                    self._log(f"ERROR: {msg[1]}")
                elif kind == "worker":
                    wid, state = int(msg[1]), str(msg[2])
                    self.worker_states[wid] = state
                    iid = str(wid)
                    if self.worker_tree.exists(iid):
                        self.worker_tree.item(iid, values=(state,))
                elif kind == "saved":
                    self.saved_rows = int(msg[1])
                    self.status_var.set(f"Output: {OUTPUT_DIR} | saved files: {self.saved_rows}")
                elif kind == "progress":
                    done, fail = int(msg[1]), int(msg[2])
                    self.progress_var.set(
                        f"Tasks total={self.total_tasks} done={done} failed={fail} saved={self.saved_rows}"
                    )
                elif kind == "writer_done":
                    self._log(f"Writer finished. Rows saved this run: {msg[1]}")
                elif kind == "done":
                    done, fail = int(msg[1]), int(msg[2])
                    self.progress_var.set(
                        f"Completed. total={self.total_tasks} done={done} failed={fail} saved={self.saved_rows}"
                    )
                    self._set_running(False)

                self.log_q.task_done()
        except queue.Empty:
            pass
        finally:
            self.root.after(150, self._pump_logs)


def main() -> int:
    """Open the scraper window and block until it is closed.

    Returns:
        Process exit status; always 0 after the Tk event loop ends.
    """
    window = tk.Tk()
    _gui = VoucherScraperGUI(window)  # kept referenced for the life of the event loop
    window.mainloop()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
