from __future__ import annotations

import collections
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar
from typing import Iterator
from typing import Mapping
from typing import Sequence

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import cast_to_comparable_string_types
from narwhals._arrow.utils import extract_py_scalar
from narwhals._compliant import EagerGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals.utils import generate_temporary_column_name

if TYPE_CHECKING:
    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.expr import ArrowExpr
    from narwhals._arrow.typing import Incomplete
    from narwhals._compliant.group_by import NarwhalsAggregation


class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr"]):
    """Group-by for `ArrowDataFrame`, built on `pyarrow.TableGroupBy`.

    Supports aggregating with simple expressions (`agg`) and iterating over
    the groups (`__iter__`).
    """

    # Narwhals aggregation name -> PyArrow grouped-aggregation function name.
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "approximate_median",
        "max": "max",
        "min": "min",
        "std": "stddev",
        "var": "variance",
        "len": "count",
        "n_unique": "count_distinct",
        "count": "count",
    }

    def __init__(
        self,
        compliant_frame: ArrowDataFrame,
        keys: Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Store the frame and keys, and build the native grouped table.

        Arguments:
            compliant_frame: Frame to group.
            keys: Names of the columns to group by.
            drop_null_keys: If true, `compliant_frame.drop_nulls(keys)` is
                applied before grouping.
        """
        if drop_null_keys:
            self._compliant_frame = compliant_frame.drop_nulls(keys)
        else:
            self._compliant_frame = compliant_frame
        self._keys: list[str] = list(keys)
        # NOTE(review): `self.compliant` is inherited; presumably it exposes
        # `self._compliant_frame` - confirm against `EagerGroupBy`.
        self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)

    def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
        """Aggregate every group with the given expressions.

        Arguments:
            *exprs: Aggregation expressions. They are first passed to the
                inherited `_ensure_all_simple` check.

        Returns:
            A frame holding the key columns and one column per aggregation,
            named after the expressions' aliases.

        Raises:
            AssertionError: If a depth-0 expression is not `len`, or if the
                column names returned by PyArrow differ from the expected
                `{column}_{function}` names.
        """
        self._ensure_all_simple(exprs)
        aggs: list[tuple[str, str, Any]] = []
        # Names we expect PyArrow to emit, and (index-aligned) the names we
        # want in the final result.
        expected_pyarrow_column_names: list[str] = self._keys.copy()
        new_column_names: list[str] = self._keys.copy()

        for expr in exprs:
            output_names, aliases = evaluate_output_names_and_aliases(
                expr, self.compliant, self._keys
            )

            if expr._depth == 0:
                # e.g. `agg(nw.len())`
                if expr._function_name != "len":  # pragma: no cover
                    msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
                    raise AssertionError(msg)

                # Group length is computed by counting every row (nulls
                # included) of the first key column.
                new_column_names.append(aliases[0])
                expected_pyarrow_column_names.append(f"{self._keys[0]}_count")
                aggs.append((self._keys[0], "count", pc.CountOptions(mode="all")))
                continue

            function_name = self._leaf_name(expr)
            if function_name in {"std", "var"}:
                option: Any = pc.VarianceOptions(ddof=expr._call_kwargs["ddof"])
            elif function_name in {"len", "n_unique"}:
                option = pc.CountOptions(mode="all")
            elif function_name == "count":
                option = pc.CountOptions(mode="only_valid")
            else:
                option = None

            function_name = self._remap_expr_name(function_name)
            new_column_names.extend(aliases)
            expected_pyarrow_column_names.extend(
                [f"{output_name}_{function_name}" for output_name in output_names]
            )
            aggs.extend(
                [(output_name, function_name, option) for output_name in output_names]
            )

        result_simple = self._grouped.aggregate(aggs)

        # Rename columns, being very careful: the same PyArrow name can be
        # expected more than once (e.g. two `len` aggregations), so keep a
        # FIFO queue of positions per expected name.
        expected_old_names_indices: dict[str, collections.deque[int]] = (
            collections.defaultdict(collections.deque)
        )
        for idx, item in enumerate(expected_pyarrow_column_names):
            expected_old_names_indices[item].append(idx)
        # Compare as multisets: a set + length comparison would let through a
        # mismatch in how often a duplicated name occurs, which would then
        # surface below as an unhelpful IndexError instead of this assertion.
        if collections.Counter(result_simple.column_names) != collections.Counter(
            expected_pyarrow_column_names
        ):  # pragma: no cover
            msg = (
                f"Safety assertion failed, expected {expected_pyarrow_column_names} "
                f"got {result_simple.column_names}, "
                "please report a bug at https://github.com/narwhals-dev/narwhals/issues"
            )
            raise AssertionError(msg)
        # For each column PyArrow actually returned (in its order), find the
        # position of the matching expected name.
        index_map: list[int] = [
            expected_old_names_indices[item].popleft()
            for item in result_simple.column_names
        ]
        new_column_names = [new_column_names[i] for i in index_map]
        result_simple = result_simple.rename_columns(new_column_names)
        if self.compliant._backend_version < (12, 0, 0):
            # Put the key columns first for these backend versions.
            # NOTE(review): presumably older PyArrow returns keys after the
            # aggregated columns - confirm.
            columns = result_simple.column_names
            result_simple = result_simple.select(
                [*self._keys, *[col for col in columns if col not in self._keys]]
            )
        return self.compliant._with_native(result_simple)

    def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
        """Iterate over the groups.

        Each row gets a single string identifier built by joining its key
        values (cast to strings) with a separator; rows sharing an identifier
        form a group.

        Yields:
            `(key, frame)` pairs, where `key` is a tuple with the key values
            taken from the first row of the group (each passed through
            `extract_py_scalar`) and `frame` holds the rows of that group.

        Notes:
            Null key values are represented by a fixed placeholder string and
            key values are joined with the ASCII unit separator. String keys
            that themselves contain the placeholder or the separator can
            therefore still be assigned to the wrong group.
        """
        col_token = generate_temporary_column_name(
            n_bytes=8, columns=self.compliant.columns
        )
        null_token: str = "__null_token_value__"  # noqa: S105
        # The separator must be non-empty: with an empty one, distinct key
        # tuples such as (1, 11) and (11, 1) both join to "111" and would be
        # merged into a single group. A control character is used because it
        # is unlikely to occur in real key values.
        key_separator = "\x1f"

        table = self.compliant.native
        it, separator_scalar = cast_to_comparable_string_types(
            *(table[key] for key in self._keys), separator=key_separator
        )
        # NOTE: stubs indicate `separator` must also be a `ChunkedArray`
        # Reality: `str` is fine
        concat_str: Incomplete = pc.binary_join_element_wise
        key_values = concat_str(
            *it,
            separator_scalar,
            null_handling="replace",
            null_replacement=null_token,
        )
        # Temporarily attach the identifier column so rows can be filtered by it.
        table = table.add_column(i=0, field_=col_token, column=key_values)
        for v in pc.unique(key_values):
            t = self.compliant._with_native(
                table.filter(pc.equal(table[col_token], v)).drop([col_token])
            )
            # All rows of a group share the same key, so read it from row 0.
            row = t.simple_select(*self._keys).row(0)
            yield tuple(extract_py_scalar(el) for el in row), t
