from __future__ import annotations

import re
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple

from app.models import ChartDataset, ChartPayload, ChartType, QueryIntent, SqlExecutionResult


def _is_number(v: Any) -> bool:
    """Return True for ints and floats, treating bools as non-numeric."""
    # bool is a subclass of int, so it has to be ruled out explicitly.
    if isinstance(v, bool):
        return False
    return isinstance(v, (int, float))


def _to_float(v: Any) -> Optional[float]:
    """Convert a cell value to ``float``, or return ``None`` when it is not numeric.

    Ints and floats (but not bools, see ``_is_number``) and strings that
    ``float()`` can parse are converted. ``None``, unparseable strings,
    ints too large to be represented as a float, and every other type
    yield ``None``.
    """
    if v is None:
        return None
    if _is_number(v) or isinstance(v, str):
        try:
            return float(v)
        except (ValueError, OverflowError):
            # ValueError: the string is not a number.
            # OverflowError: the int is outside float range (e.g. 10**400);
            # report it as unconvertible instead of crashing the caller.
            return None
    return None


def _column_numeric(rows: List[List[Any]], col_idx: int) -> bool:
    """Tell whether column ``col_idx`` can be treated as a numeric series.

    The column qualifies when it holds at least one non-``None`` cell and
    every non-``None`` cell converts through ``_to_float``. A row too short
    to contain the column disqualifies it; ``None`` cells are ignored.
    """
    found_value = False
    for record in rows:
        if len(record) <= col_idx:
            return False
        cell = record[col_idx]
        if cell is not None:
            if _to_float(cell) is None:
                return False
            found_value = True
    return found_value


def _labels_look_like_dates(labels: List[str]) -> bool:
    """Return True when most labels start with a ``YYYY-MM-DD`` prefix.

    Needs at least two labels. The match count must reach both 2 and
    ``int(0.7 * len(labels))``; non-string labels never count as matches.
    """
    total = len(labels)
    if total < 2:
        return False

    iso_prefix = re.compile(r"^\d{4}-\d{2}-\d{2}")
    hits = 0
    for label in labels:
        if isinstance(label, str) and iso_prefix.match(label):
            hits += 1

    threshold = max(2, int(0.7 * total))
    return hits >= threshold


def _aggregate_by_label(
    rows: List[List[Any]], label_idx: int, metric_indices: List[int]
) -> Tuple[List[str], Dict[int, List[float]]]:
    """Sum numeric values when the same category label appears more than once.

    Args:
        rows: Result grid; rows may be shorter than the column count.
        label_idx: Index of the column that supplies the category label.
        metric_indices: Indices of the columns to sum per label.

    Returns:
        ``(labels, series)`` where ``labels`` lists the distinct labels in
        first-seen order and ``series[mj][k]`` is the sum of column ``mj``
        over all rows whose label is ``labels[k]``.

    Labels are stringified with ``str()``; a ``None`` label, or a row too
    short to contain the label column, is grouped under ``""``. Metric cells
    that are missing, ``None`` or not convertible by ``_to_float`` add
    nothing, so a label with no usable values sums to ``0.0``.
    """
    # dict keeps insertion order, so its keys double as the first-seen label order.
    totals: Dict[str, Dict[int, float]] = {}
    for row in rows:
        try:
            raw = row[label_idx]
        except IndexError:
            # Ragged row: treat the absent label like a NULL instead of crashing.
            raw = None
        label = "" if raw is None else str(raw)

        bucket = totals.get(label)
        if bucket is None:
            bucket = totals[label] = dict.fromkeys(metric_indices, 0.0)

        for mj in metric_indices:
            try:
                cell = row[mj]
            except IndexError:
                continue  # ragged row: nothing to add for this metric
            fv = _to_float(cell)
            if fv is not None:
                bucket[mj] += fv

    labels = list(totals)
    series: Dict[int, List[float]] = {
        mj: [totals[lb][mj] for lb in labels] for mj in metric_indices
    }
    return labels, series


def build_chart_payload(intent: QueryIntent, result: SqlExecutionResult) -> Optional[ChartPayload]:
    """Build Chart.js-friendly data from a result grid, or ``None`` if it is not chartable.

    Three grid shapes are recognised, checked in this order:

    * a single row made only of two or more numeric columns: each column
      name becomes a label and the row forms one ``"Value"`` series;
    * at least one non-numeric and one numeric column: the first non-numeric
      column supplies the labels and every numeric column becomes a series,
      with repeated labels summed by ``_aggregate_by_label``;
    * exactly one column, and it is numeric: rows are labelled ``Row 1``,
      ``Row 2``, ... and the column is the only series (any ``None`` cell
      makes the grid unchartable).

    The title is assembled from ``intent.metrics`` and ``intent.dimensions``
    and falls back to ``"Results"``. The default chart is ``"line"`` when
    the labels look like dates and ``"bar"`` otherwise; ``"pie"`` is only
    suggested for a single series with at most 12 labels.
    """
    if not (result.rows and result.columns):
        return None

    rows = result.rows
    columns = result.columns

    # Partition column indices into numeric (metric) columns and the rest.
    metric_cols: List[int] = []
    category_cols: List[int] = []
    for idx in range(len(columns)):
        target = metric_cols if _column_numeric(rows, idx) else category_cols
        target.append(idx)

    labels: List[str] = []
    datasets: List[ChartDataset] = []

    if len(rows) == 1 and len(metric_cols) >= 2 and not category_cols:
        # One summary row: the columns themselves act as the categories.
        summary = rows[0]
        labels = [columns[idx] for idx in metric_cols]
        values: List[float] = []
        for idx in metric_cols:
            number = _to_float(summary[idx])
            if number is None:
                return None
            values.append(number)
        datasets.append(ChartDataset(label="Value", data=values))

    elif category_cols and metric_cols:
        # Categories x metrics: the first non-numeric column labels the axis.
        labels, per_metric = _aggregate_by_label(rows, category_cols[0], metric_cols)
        datasets.extend(
            ChartDataset(label=columns[idx], data=per_metric[idx]) for idx in metric_cols
        )

    elif len(metric_cols) == 1 and len(columns) == 1:
        # A lone numeric column: label the points by row position.
        only = metric_cols[0]
        labels = [f"Row {pos}" for pos in range(1, len(rows) + 1)]
        series: List[float] = []
        for record in rows:
            number = _to_float(record[only])
            if number is None:
                return None
            series.append(number)
        datasets.append(ChartDataset(label=columns[only], data=series))

    else:
        return None

    if not (labels and datasets):
        return None

    title_bits: List[str] = []
    if intent.metrics:
        title_bits.append(", ".join(intent.metrics))
    if intent.dimensions:
        title_bits.append("by " + ", ".join(intent.dimensions))
    title = " ".join(title_bits) if title_bits else "Results"

    default_chart: ChartType = "bar"
    if _labels_look_like_dates(labels):
        default_chart = "line"

    # Pie is only offered for a single series with a small number of slices.
    pie_allowed = len(labels) <= 12 and len(datasets) <= 1
    suggested: List[ChartType] = ["bar", "line", "pie"] if pie_allowed else ["bar", "line"]

    return ChartPayload(
        title=title,
        labels=labels,
        datasets=datasets,
        default_chart=default_chart,
        suggested_types=suggested,
    )
