#
# -*- coding: utf-8 -*-
#

import argparse
import os
from collections import defaultdict

from particles import Particle
from library.SM import particles as SMP, interactions as SMI
from library.NuMSM import particles as NuP, interactions as NuI
from evolution import Universe
from common import UNITS, Params, utils, HeuristicGrid, LogSpacedGrid


# Command-line interface. The sterile-neutrino mass (MeV), its mixing angle
# and its lifetime (seconds) are required; `--comment` is optional.
# `type=float` makes argparse reject non-numeric values with a proper usage
# message instead of an uncaught ValueError traceback further down.
parser = argparse.ArgumentParser(description='Run simulation for given mass and mixing angle')
parser.add_argument('--mass', required=True, type=float,
                    help='sterile neutrino mass in MeV')
parser.add_argument('--theta', required=True, type=float,
                    help='mixing angle of the sterile state with the electron flavour')
parser.add_argument('--tau', required=True, type=float,
                    help='sterile neutrino lifetime in seconds')
parser.add_argument('--comment', default='',
                    help='text appended verbatim to the output folder name')
args = parser.parse_args()

# Convert the CLI values into the internal unit system.
mass = args.mass * UNITS.MeV
theta = args.theta
lifetime = args.tau * UNITS.s

# Results go into a sub-directory placed next to this script. Its name
# encodes the run parameters (mass in MeV, lifetime in s, mixing angle),
# followed directly by the optional --comment text.
folder_name = "mass={:e}_tau={:e}_theta={:e}".format(
    mass / UNITS.MeV, lifetime / UNITS.s, theta
) + args.comment
folder = utils.ensure_dir(os.path.dirname(__file__), folder_name)

# Temperature window of the run: the universe starts at T_initial and is
# later evolved down to T_final via universe.evolve(T_final).
T_initial = 400.0 * UNITS.MeV
T_final = 0.0008 * UNITS.MeV

# dy is forwarded to Params as-is (presumably the step of the evolution
# variable - TODO confirm against Params).
params = Params(T=T_initial, dy=0.05)

# All output of the universe is written under the run-specific folder.
universe = Universe(params=params, folder=folder)

# --- Standard Model species ---------------------------------------------
photon = Particle(**SMP.photon)

# Charged leptons, constructed in the order e, mu, tau.
electron, muon, tau = (
    Particle(**spec)
    for spec in (SMP.leptons.electron, SMP.leptons.muon, SMP.leptons.tau)
)

# Active neutrino flavours.
neutrino_e, neutrino_mu, neutrino_tau = (
    Particle(**spec)
    for spec in (SMP.leptons.neutrino_e,
                 SMP.leptons.neutrino_mu,
                 SMP.leptons.neutrino_tau)
)

# Pions.
neutral_pion, charged_pion = (
    Particle(**spec)
    for spec in (SMP.hadrons.neutral_pion, SMP.hadrons.charged_pion)
)

# --- Sterile neutrino ---------------------------------------------------
# Dirac sterile state of the requested mass. It gets its own LogSpacedGrid
# (arguments 10 and 5 * T_initial - TODO confirm their meaning against
# LogSpacedGrid), and its decoupling temperature is the run's start
# temperature.
sterile = Particle(**NuP.dirac_sterile_neutrino(mass))
sterile_grid = LogSpacedGrid(10, T_initial * 5)
sterile.set_grid(sterile_grid)
sterile.decoupling_temperature = T_initial

# --- Active neutrino grid -----------------------------------------------
# The three active flavours share a single HeuristicGrid built from the
# sterile mass and lifetime. Their decoupling temperature is set to zero
# before the grid is attached (order kept deliberately).
grid = HeuristicGrid(mass, lifetime)
for neutrino in (neutrino_e, neutrino_mu, neutrino_tau):
    neutrino.decoupling_temperature = 0 * UNITS.MeV
    neutrino.set_grid(grid)


# Register every species with the universe, grouped by kind; the resulting
# list (and therefore the registration order) is:
# photon, charged leptons, active neutrinos, pions, sterile neutrino.
universe.add_particles(
    [photon]
    + [electron, muon, tau]
    + [neutrino_e, neutrino_mu, neutrino_tau]
    + [neutral_pion, charged_pion]
    + [sterile]
)

# Mixing angles of the sterile state per active flavour: only 'electron' is
# set explicitly; any other flavour key reads back as 0.0 (defaultdict).
thetas = defaultdict(float, electron=theta)

# Standard Model neutrino interactions with e and mu.
sm_neutrino_interactions = SMI.neutrino_interactions(
    leptons=[electron, muon],
    neutrinos=[neutrino_e, neutrino_mu, neutrino_tau],
)

# Sterile-neutrino interactions involving the three charged leptons.
sterile_lepton_interactions = NuI.sterile_leptons_interactions(
    thetas=thetas, sterile=sterile,
    neutrinos=[neutrino_e, neutrino_mu, neutrino_tau],
    leptons=[electron, muon, tau],
)

universe.interactions += sm_neutrino_interactions + sterile_lepton_interactions

# Hadronic sterile interactions are only added when the sterile mass
# exceeds the neutral pion mass.
heavier_than_pi0 = sterile.mass > neutral_pion.mass
if heavier_than_pi0:
    universe.interactions += NuI.sterile_hadrons_interactions(
        thetas=thetas, sterile=sterile,
        neutrinos=[neutrino_e, neutrino_mu, neutrino_tau],
        leptons=[electron, muon, tau],
        hadrons=[neutral_pion, charged_pion],
    )

# Set up the Kawano output (fed by the electron and electron neutrino) and
# flavour oscillations among the three active neutrinos.
universe.init_kawano(electron=electron, neutrino=neutrino_e)
universe.init_oscillations(
    SMP.leptons.oscillations_map(),
    (neutrino_e, neutrino_mu, neutrino_tau),
)

# Optional live monitoring of the sterile neutrino; plotting is imported
# lazily so runs without graphics do not need it.
if universe.graphics:
    from plotting import MassiveParticleMonitor, AbundanceMonitor
    sterile_monitors = [
        (sterile, monitor_cls)
        for monitor_cls in (MassiveParticleMonitor, AbundanceMonitor)
    ]
    universe.graphics.monitor(sterile_monitors)

universe.evolve(T_final)

# universe.graphics is re-checked after evolve() rather than cached, exactly
# as before, in case evolution changes it.
if universe.graphics:
    from tests.plots import articles_comparison_plots
    articles_comparison_plots(
        universe, [neutrino_e, neutrino_mu, neutrino_tau, sterile]
    )