"""
1HN Pure Anti-phase Proton CPMG with [0013] Phase Cycle
=======================================================

Analyzes chemical exchange of amide protons using magnetization that is kept
anti-phase (with respect to the attached 15N) throughout the CPMG block. This
results in lower intrinsic relaxation rates and therefore better sensitivity.
The calculations use the (6n)×(6n) two-spin matrix, where n is the number of
states::

    { Hx(a), Hy(a), Hz(a), 2HxNz(a), 2HyNz(a), 2HzNz(a),
      Hx(b), Hy(b), Hz(b), 2HxNz(b), 2HyNz(b), 2HzNz(b), ... }

In this version, the CPMG refocusing pulses are applied with the [0013]
phase cycle to better compensate for off-resonance effects.

References
----------

Yuwen and Kay. J Biomol NMR (2019) 73:641-650


Note
----

A sample configuration file for this module is available using the command::

    $ chemex config cpmg_1hn_ap_0013

"""
import functools as ft

import numpy as np

import chemex.experiments.helper as ceh
import chemex.helper as ch
import chemex.nmr.liouvillian as cnl


_SCHEMA = {
    "type": "object",
    "properties": {
        "experiment": {
            "type": "object",
            "properties": {
                "time_t2": {"type": "number"},
                "carrier": {"type": "number"},
                "pw90": {"type": "number"},
                "ncyc_max": {"type": "integer"},
                "taua": {"type": "number", "default": 2.38e-3},
                "ipap_flg": {"type": "boolean", "default": False},
                "eburp_flg": {"type": "boolean", "default": False},
                "reburp_flg": {"type": "boolean", "default": False},
                "pw_eburp": {"type": "number", "default": 1.4e-3},
                "pw_reburp": {"type": "number", "default": 1.52e-3},
                "observed_state": {
                    "type": "string",
                    "pattern": "[a-z]",
                    "default": "a",
                },
            },
            "required": ["time_t2", "carrier", "pw90", "ncyc_max"],
        }
    },
}


def read(config):
    """Validate ``config`` and build the experiment from it.

    ``config`` is checked against ``_SCHEMA``. It is then given the two-spin
    ``"basis"`` and the default ``"fit"`` templates, in that order, before
    being handed to ``ceh.load_experiment`` with ``PulseSeq`` as the
    pulse-sequence class.
    """
    ch.validate(config, _SCHEMA)
    additions = {
        "basis": cnl.Basis(type="ixyzsz", spin_system="hn"),
        "fit": _fit_this(),
    }
    for name, entry in additions.items():
        config[name] = entry
    return ceh.load_experiment(config=config, pulse_seq_cls=PulseSeq)


def _fit_this():
    return {
        "rates": ["r2_i_{observed_state}"],
        "model_free": ["tauc_{observed_state}"],
    }


class PulseSeq:
    """Simulate the 1HN anti-phase CPMG experiment with the [0013] phase cycle.

    Settings are read from ``config["experiment"]`` (validated by ``_SCHEMA``).
    The supplied ``propagator`` is configured in place: carrier, B1 field and
    detected term. ``calculate`` returns one simulated intensity per entry of
    ``ncycs``.
    """

    def __init__(self, config, propagator):
        self.prop = propagator
        settings = config["experiment"]
        self.time_t2 = settings["time_t2"]
        self.prop.carrier_i = settings["carrier"]
        self.pw90 = settings["pw90"]
        # Negative delay of -2*pw90/pi. Presumably it cancels the effective
        # free precession accrued during the flanking 90-degree pulses
        # (TODO confirm).
        self.t_neg = -2.0 * self.pw90 / np.pi
        # Field strength for which a 90-degree rotation lasts pw90 (a quarter
        # period); in Hz if pw90 is in seconds.
        self.prop.b1_i = 1 / (4.0 * self.pw90)
        self.taua = settings["taua"]
        self.ncyc_max = settings["ncyc_max"]
        self.ipap_flg = settings["ipap_flg"]
        self.eburp_flg = settings["eburp_flg"]
        self.reburp_flg = settings["reburp_flg"]
        self.pw_eburp = settings["pw_eburp"]
        self.pw_reburp = settings["pw_reburp"]
        self.observed_state = settings["observed_state"]
        # Only the Hz term of the observed state is detected.
        self.prop.detection = f"[iz_{self.observed_state}]"
        # "Perfect" pulses used for the INEPT element. Note that
        # ``self.p90_i`` (perfect) differs from ``self.prop.p90_i``, which
        # ``calculate`` uses for the CPMG/central elements.
        self.p90_i = self.prop.perfect90_i
        # Simultaneous perfect 180s on both spins, phase index 0.
        self.p180_isx = self.prop.perfect180_i[0] @ self.prop.perfect180_s[0]

    @ft.lru_cache(maxsize=10000)
    def calculate(self, ncycs, params_local):
        """Return the simulated intensity for each ``ncyc`` in ``ncycs``.

        Results are memoised with ``functools.lru_cache``. Consequently:

        * ``ncycs`` and ``params_local`` must be hashable (e.g. ``ncycs`` as a
          tuple).
        * The cache key includes ``self``.
        * On a cache hit ``self.prop.update`` is NOT called, so the
          propagator keeps the parameters of the last cache miss.

        The returned array follows the order of ``ncycs``, and duplicates are
        repeated. ``ncyc == 0`` is simulated without any CPMG echoes. When
        ``ipap_flg`` is set, each value is the sum of two detected pathways.
        """
        self.prop.update(params_local)

        # Calculation of the propagators corresponding to all the delays
        tau_cps, deltas, all_delays = self._get_delays(ncycs)
        # Propagators keyed by delay value. The lookups below must use the
        # exact float expressions produced in ``_get_delays`` (for example
        # ``0.5 * self.pw_reburp``) or the dict lookup will miss.
        delays = dict(zip(all_delays, self.prop.delays(all_delays)))
        d_neg = delays[self.t_neg]
        d_taua = delays[self.taua]
        d_eburp = delays[self.pw_eburp]
        d_reburp = delays[0.5 * self.pw_reburp]
        d_delta = {ncyc: delays[delay] for ncyc, delay in deltas.items()}
        d_cp = {ncyc: delays[delay] for ncyc, delay in tau_cps.items()}

        # Calculation of the propagators corresponding to pulses
        # Leading axis is the pulse phase, indexed 0-3 by the phase tables
        # (presumably x, y, -x, -y -- TODO confirm).
        p90 = self.prop.p90_i
        p180 = self.prop.p180_i

        # Calculation of the propagators for INEPT and purge elements
        inept = self.p90_i[1] @ d_taua @ self.p180_isx @ d_taua @ self.p90_i[0]
        zfilter = self.prop.zfilter

        # Getting the starting magnetization
        # Start from 2HzNz (two-spin order) of the observed state.
        start = self.prop.get_start_magnetization(terms=f"2izsz_{self.observed_state}")

        # Calculating the central refocusing block
        # Each branch uses phase indices [1, 3] (named "pmy", presumably +/-y).
        if self.eburp_flg:
            p180pmy = p180[[1, 3]]
            pp90pmy = self.prop.perfect90_i[[1, 3]]
            # 90 - delay - 180 - delay - 90 element with pw_eburp delays.
            e180e_pmy = pp90pmy @ d_eburp @ p180pmy @ d_eburp @ pp90pmy
            # Both orderings of the hard 180 and the element above.
            middle = [p180pmy @ e180e_pmy, e180e_pmy @ p180pmy]
        elif self.reburp_flg:
            pp180pmy = self.prop.perfect180_i[[1, 3]]
            middle = d_reburp @ pp180pmy @ d_reburp
        else:
            middle = p180[[1, 3]]
        # NOTE(review): this averages over the first axis. In the reburp and
        # hard-pulse branches that axis is the [1, 3] phase pair, so the phases
        # are averaged away. In the eburp branch it is the two-element list of
        # orderings, so the [1, 3] phase axis survives. It then appears to be
        # paired element-wise with the two rows of the CPMG phase tables, and
        # it also makes ``centre[0]`` batched. Confirm this asymmetry is
        # intended and that ``self.prop.detect`` handles the extra axis.
        middle = np.mean(middle, axis=0)

        # Calculating the intensities as a function of ncyc
        # ncyc == 0: no echoes and no d_neg compensation, only the padding
        # delay and the 90 / middle / 90 element.
        centre = {0: d_delta[0] @ p90[0] @ middle @ p90[0]}

        for ncyc in set(ncycs) - {0}:
            # Each phase array is (2, ncyc): two rows of the phase table
            # carried through in parallel as a leading axis of size 2.
            phases1, phases2 = self._get_phases(ncyc)
            # Echo propagators for every pulse phase (leading phase axis).
            echo = d_cp[ncyc] @ p180 @ d_cp[ncyc]
            # echo[phases.T] has shape (ncyc, 2, ...). The reduce chains the
            # ncyc echoes, and the leftmost factor is applied last.
            cpmg1 = ft.reduce(np.matmul, echo[phases1.T])
            cpmg2 = ft.reduce(np.matmul, echo[phases2.T])
            # Rightmost factor acts first: 90, d_neg, first echo train,
            # central element, second echo train, d_neg, 90, padding delay.
            centre[ncyc] = (
                d_delta[ncyc] @ p90[0] @ d_neg @ cpmg2 @ middle @ cpmg1 @ d_neg @ p90[0]
            )

        # Main pathway: central block, then z-filter, then INEPT.
        intst = {
            ncyc: self.prop.detect(inept @ zfilter @ centre[ncyc] @ start)
            for ncyc in set(ncycs)
        }

        if self.ipap_flg:
            # Add a second pathway in which the INEPT element is applied
            # before (rather than after) the central block.
            intst = {
                ncyc: intst[ncyc]
                + self.prop.detect(centre[ncyc] @ zfilter @ inept @ start)
                for ncyc in set(ncycs)
            }

        # Return profile
        return np.array([intst[ncyc] for ncyc in ncycs])

    @ft.lru_cache()
    def _get_delays(self, ncycs):
        """Return ``(tau_cps, deltas, delays)`` for the given ``ncycs``.

        The three values are:

        * ``tau_cps``: ``{ncyc: time_t2 / (4 * ncyc) - 0.75 * pw90}`` for each
          ``ncyc > 0``.
        * ``deltas``: ``{ncyc: pw90 * (ncyc_max - ncyc)}`` for each
          ``ncyc > 0``, plus ``deltas[0] = pw90 * ncyc_max``. The ``0`` entry
          is always present.
        * ``delays``: a flat list of every delay whose propagator
          ``calculate`` looks up.

        Keys for ``ncyc > 0`` are NumPy integer scalars taken from
        ``np.asarray(ncycs)``; they hash and compare equal to Python ints.
        Cached per ``(self, ncycs)``, so ``ncycs`` must be hashable.
        """
        ncycs_ = np.asarray(ncycs)
        ncycs_ = ncycs_[ncycs_ > 0]
        tau_cps = dict(zip(ncycs_, self.time_t2 / (4.0 * ncycs_) - 0.75 * self.pw90))
        deltas = dict(zip(ncycs_, self.pw90 * (self.ncyc_max - ncycs_)))
        deltas[0] = self.pw90 * self.ncyc_max
        delays = [
            self.taua,
            self.t_neg,
            self.pw_eburp,
            0.5 * self.pw_reburp,
            *tau_cps.values(),
            *deltas.values(),
        ]

        return tau_cps, deltas, delays

    @staticmethod
    @ft.lru_cache()
    def _get_phases(ncyc):
        """Return ``(phases1, phases2)``, the echo phase indices for ``ncyc``.

        Each result is a ``(2, ncyc)`` integer array that tiles its 16-step,
        two-row table cyclically (``mode="wrap"``). ``phases1`` is taken in
        reverse. Combined with the right-to-left product in ``calculate``:

        * the first echo train applies ``cp_phases1`` entries 0..ncyc-1 in
          order;
        * the second train applies ``cp_phases2`` entries from ncyc-1 down to
          0, mirroring the order about the central element.

        Results are cached and shared between calls, so callers must not
        mutate them.
        """
        cp_phases1 = [
            [1, 1, 2, 0, 1, 1, 0, 2, 1, 1, 0, 2, 1, 1, 2, 0],
            [2, 0, 3, 3, 0, 2, 3, 3, 0, 2, 3, 3, 2, 0, 3, 3],
        ]
        cp_phases2 = [
            [3, 3, 2, 0, 3, 3, 0, 2, 3, 3, 0, 2, 3, 3, 2, 0],
            [2, 0, 1, 1, 0, 2, 1, 1, 0, 2, 1, 1, 2, 0, 1, 1],
        ]
        phases1 = np.take(cp_phases1, np.flip(np.arange(ncyc)), mode="wrap", axis=1)
        phases2 = np.take(cp_phases2, np.arange(ncyc), mode="wrap", axis=1)
        return phases1, phases2

    def ncycs_to_nu_cpmgs(self, ncycs):
        """Return ``ncyc / time_t2`` for every ``ncyc > 0`` in ``ncycs``.

        Entries with ``ncyc <= 0`` are dropped, so the result may be shorter
        than ``ncycs``.
        """
        ncycs_ = np.asarray(ncycs)
        ncycs_ = ncycs_[ncycs_ > 0]
        return ncycs_ / self.time_t2
