from __future__ import annotations

from functools import reduce
from typing import Literal

import numpy as np
from pydantic import Field, computed_field

from chemex.configuration.base import ExperimentConfiguration, ToBeFitted
from chemex.configuration.conditions import ConditionsWithValidations
from chemex.configuration.data import RelaxationDataSettings
from chemex.configuration.experiment import CpmgSettings
from chemex.configuration.types import Delay, Frequency, PulseWidth
from chemex.containers.data import Data
from chemex.containers.dataset import load_relaxation_dataset
from chemex.experiments.factories import Creators, factories
from chemex.filterers import PlanesFilterer
from chemex.nmr.basis import Basis
from chemex.nmr.liouvillian import LiouvillianIS
from chemex.nmr.spectrometer import Spectrometer
from chemex.parameters.spin_system import SpinSystem
from chemex.plotters.cpmg import CpmgPlotter
from chemex.printers.data import CpmgPrinter
from chemex.typing import Array

# Return type of ``Cpmg15N0013IpSequence._get_delays``:
# (tau_cp keyed by ncyc, delta keyed by ncyc, flat list of every delay whose
# propagator must be computed).
_DelayByNcyc = dict[float, float]
Delays = tuple[_DelayByNcyc, _DelayByNcyc, list[float]]


# Registry key for this experiment; must match the ``name`` literal accepted by
# ``Cpmg15N0013IpSettings``.
EXPERIMENT_NAME = "cpmg_15n_ip_0013"


class Cpmg15N0013IpSettings(CpmgSettings):
    """User settings for the 15N in-phase CPMG experiment (0013 phase scheme)."""

    name: Literal["cpmg_15n_ip_0013"]
    time_t2: Delay = Field(description="Total CPMG relaxation delay (seconds)")
    carrier: Frequency = Field(description="15N carrier frequency (Hz)")
    pw90: PulseWidth = Field(description="15N 90-degree pulse width (seconds)")
    time_equil: float = Field(
        default=0.0, ge=0.0, description="Equilibration delay (seconds)"
    )
    ncyc_max: int = Field(gt=0, description="Maximum number of CPMG cycles")

    @computed_field  # type: ignore[misc]
    @property
    def t_neg(self) -> float:
        """Negative delay compensating for pulse imperfections.

        Returns:
            ``-2 * pw90 / pi``, in seconds.

        """
        # Negating after the division is bit-identical to multiplying by -2.0:
        # IEEE rounding is sign-symmetric.
        return -(2.0 * self.pw90 / np.pi)

    @computed_field  # type: ignore[misc]
    @property
    def t_pos(self) -> float:
        """Positive delay compensating for pulse imperfections.

        Returns:
            ``4 * pw90 / pi``, in seconds.

        """
        return self.pw90 * 4.0 / np.pi

    @computed_field  # type: ignore[misc]
    @property
    def start_terms(self) -> list[str]:
        """Magnetization terms the simulation starts from.

        Returns:
            A single-element list holding the ``iz`` term for the start state.

        """
        initial_term = f"iz{self.suffix_start}"
        return [initial_term]

    @computed_field  # type: ignore[misc]
    @property
    def detection(self) -> str:
        """Observable term handed to the spectrometer for detection.

        Returns:
            The bracketed ``iz`` term for the detected state.

        """
        return "[iz{}]".format(self.suffix_detect)


class Cpmg15N0013IpConfig(
    ExperimentConfiguration[
        Cpmg15N0013IpSettings,
        ConditionsWithValidations,
        RelaxationDataSettings,
    ],
):
    # Full configuration (experiment settings, conditions, data settings) for
    # the ``cpmg_15n_ip_0013`` experiment.

    @property
    def to_be_fitted(self) -> ToBeFitted:
        """Parameters this experiment fits for the observed state.

        Returns:
            ``ToBeFitted`` listing ``r2_i_<state>`` as a rate and
            ``tauc_<state>`` under ``model_free``.

        """
        observed = self.experiment.observed_state
        rates = [f"r2_i_{observed}"]
        model_free = [f"tauc_{observed}"]
        return ToBeFitted(rates=rates, model_free=model_free)


def build_spectrometer(
    config: Cpmg15N0013IpConfig,
    spin_system: SpinSystem,
) -> Spectrometer:
    """Create the spectrometer that simulates this experiment.

    Args:
        config: Experiment configuration. Its ``experiment`` settings supply the
            carrier, 90-degree pulse width and detection term, and its
            ``conditions`` are passed to the Liouvillian.
        spin_system: Spin system the Liouvillian is built for.

    Returns:
        A ``Spectrometer`` wrapping a ``LiouvillianIS`` in the ``ixyz`` basis of
        an ``nh`` spin system, with carrier, B1 field and detection set.

    """
    experiment = config.experiment
    liouvillian = LiouvillianIS(
        spin_system,
        Basis(type="ixyz", spin_system="nh"),
        config.conditions,
    )
    spectrometer = Spectrometer(liouvillian)

    spectrometer.carrier_i = experiment.carrier
    # A 90-degree pulse is a quarter rotation, so the B1 field (in Hz) implied
    # by its width is 1 / (4 * pw90).
    spectrometer.b1_i = 1 / (4.0 * experiment.pw90)
    spectrometer.detection = experiment.detection
    return spectrometer


class Cpmg15N0013IpSequence:
    """Sequence for 15N in-phase CPMG experiment with 0013 variant.

    Intensities are simulated for every ``ncyc`` value in the dataset metadata.
    ``ncyc == 0`` marks the reference plane (see ``is_reference``).
    """

    # Phase indices of the refocusing pulses. They index the stack of echo
    # propagators built from ``spectrometer.p180_i``. Each row yields one CPMG
    # train; ``_get_phases`` reads the table backwards and wraps around it.
    # NOTE(review): the two rows presumably form a phase cycle that
    # ``spectrometer.detect`` combines — confirm.
    _CP_PHASES = np.array(
        [
            [0, 0, 1, 3, 0, 0, 3, 1, 0, 0, 3, 1, 0, 0, 1, 3],
            [1, 3, 2, 2, 3, 1, 2, 2, 3, 1, 2, 2, 1, 3, 2, 2],
        ],
    )

    def __init__(self, settings: Cpmg15N0013IpSettings) -> None:
        self.settings = settings
        # Phase tables depend only on ncyc, not on fitted parameters, so each
        # distinct ncyc is computed once and reused across calls.
        self._phase_cache: dict[float, Array] = {}

    def _get_delays(self, ncycs: Array) -> Delays:
        """Compute every free-precession delay the sequence needs.

        Args:
            ncycs: Number of CPMG cycles for each data point (0 = reference).

        Returns:
            ``(tau_cps, deltas, delays)``:

            - ``tau_cps`` maps each positive ``ncyc`` to the delay placed on
              each side of every 180-degree pulse.
            - ``deltas`` maps each positive ``ncyc``, plus the reference
              ``0.0``, to the delay applied after the final 90-degree pulse.
            - ``delays`` is the flat list of all distinct delays whose
              propagators must be computed.

        """
        settings = self.settings
        ncycs_no_ref = ncycs[ncycs > 0]
        tau_cps = {
            ncyc: settings.time_t2 / (4.0 * ncyc) - 0.75 * settings.pw90
            for ncyc in ncycs_no_ref
        }
        deltas = {
            ncyc: settings.pw90 * (settings.ncyc_max - ncyc) + settings.time_equil
            for ncyc in ncycs_no_ref
        }
        # The reference plane uses the same delta as ncyc = 1.
        deltas[0.0] = settings.pw90 * (settings.ncyc_max - 1) + settings.time_equil
        # ``time_equil`` is already folded into every delta, so it needs no
        # propagator of its own.
        delays = [
            settings.t_neg,
            settings.t_pos,
            *deltas.values(),
            *tau_cps.values(),
        ]
        return tau_cps, deltas, delays

    def _get_phases(self, ncyc: float) -> Array:
        """Return the (2, 2 * ncyc) table of refocusing-pulse phase indices.

        The table only depends on ``ncyc``, so results are memoized per value.

        Args:
            ncyc: Number of CPMG cycles (integral value, possibly a NumPy float).

        Returns:
            ``_CP_PHASES`` read backwards over ``2 * ncyc`` columns, wrapping
            around its 16 columns as needed.

        """
        ncyc_key = float(ncyc)
        phases = self._phase_cache.get(ncyc_key)
        if phases is None:
            indexes = np.flip(np.arange(2 * int(ncyc)))
            phases = np.take(self._CP_PHASES, indexes, mode="wrap", axis=1)
            self._phase_cache[ncyc_key] = phases
        return phases

    def calculate(self, spectrometer: Spectrometer, data: Data) -> Array:
        """Simulate the intensity profile for the points in ``data``.

        Args:
            spectrometer: Spectrometer configured for this experiment.
            data: Dataset whose ``metadata`` holds the ``ncyc`` of each point.

        Returns:
            One simulated intensity per entry of ``data.metadata``, in order.

        """
        ncycs = data.metadata

        # Propagators for all free-precession delays, computed in one batch.
        tau_cps, deltas, all_delays = self._get_delays(ncycs)
        delays = dict(zip(all_delays, spectrometer.delays(all_delays), strict=True))
        d_neg = delays[self.settings.t_neg]
        d_pos = delays[self.settings.t_pos]
        d_delta = {ncyc: delays[delay] for ncyc, delay in deltas.items()}
        d_cp = {ncyc: delays[delay] for ncyc, delay in tau_cps.items()}

        # Pulse propagators. The leading axis selects the pulse phase
        # (presumably 0..3 = x, y, -x, -y — confirm against Spectrometer).
        p90 = spectrometer.p90_i
        p180 = spectrometer.p180_i

        # Getting the starting magnetization
        start = spectrometer.get_start_magnetization(self.settings.start_terms)

        # Reference plane: the CPMG block is replaced by two 180-degree pulses
        # separated by ``t_pos``. This is evaluated for phase pairs (0, 0) and
        # (3, 1) via broadcasting.
        intensities = {
            0.0: spectrometer.detect(
                d_delta[0]
                @ p90[3]
                @ p180[[0, 3]]
                @ d_pos
                @ p180[[0, 1]]
                @ p90[1]
                @ start,
            ),
        }
        # Each distinct ncyc is simulated once, even if it appears repeatedly.
        for ncyc in set(ncycs) - {0.0}:
            phases = self._get_phases(ncyc)
            echo = d_cp[ncyc] @ p180 @ d_cp[ncyc]
            # Chain the echoes in phase-table order: one CPMG train per table row.
            cpmg = reduce(np.matmul, echo[phases.T])
            intensities[ncyc] = spectrometer.detect(
                d_delta[ncyc] @ p90[3] @ d_neg @ cpmg @ d_neg @ p90[1] @ start,
            )

        # Return profile in the order of the input data points
        return np.array([intensities[ncyc] for ncyc in ncycs])

    @staticmethod
    def is_reference(metadata: Array) -> Array:
        """Return a boolean mask flagging reference points (``ncyc == 0``)."""
        return metadata == 0


def register() -> None:
    """Register this experiment's creators with ``factories`` under ``EXPERIMENT_NAME``."""
    factories.register(
        name=EXPERIMENT_NAME,
        creators=Creators(
            config_creator=Cpmg15N0013IpConfig,
            spectrometer_creator=build_spectrometer,
            sequence_creator=Cpmg15N0013IpSequence,
            dataset_creator=load_relaxation_dataset,
            filterer_creator=PlanesFilterer,
            printer_creator=CpmgPrinter,
            plotter_creator=CpmgPlotter,
        ),
    )
