import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import numpy as np
import astropy.units as u
from expecto import get_spectrum

from jax import numpy as jnp
from fleck.jax import ActiveStar, bin_spectrum


# Times [d] at which the light curves are evaluated:
times = np.linspace(-0.04, 0.04, 250)
# Alternative time grid with a much longer baseline:
# times = np.linspace(-0.04, 3.3, 250)

# 101 log-spaced wavelengths between 0.5 and 5 micron; handed to
# ``bin_spectrum`` as ``bins`` and used to set its ``min``/``max`` limits:
wavelength = np.geomspace(0.5, 5, 101) * u.um

# Binning options shared by every ``bin_spectrum`` call below:
kwargs = {
    'bins': wavelength,
    'min': wavelength.min(),
    'max': wavelength.max(),
    'log': False,
}

# Download (with caching) one PHOENIX model spectrum per effective
# temperature and bin it. The three temperatures are unpacked, in order,
# as the photosphere, the cool spot and the hot spot; the binned spectra
# are what set the spot contrasts:
phot, cool, hot = (
    bin_spectrum(
        get_spectrum(T_eff=T_eff, log_g=5.0, cache=True),
        **kwargs
    )
    for T_eff in (2600, 2400, 2800)
)

def plot_transit_contamination(
    active_star, planet_params,
    norm_oot_per_wavelength=True,
    norm_stellar_spectrum=True
):
    """
    Plot chromatic transit light curves, the contaminated transit depth
    as a function of wavelength, and the star with its planet.

    The figure has three panels: light curves for a subset of roughly ten
    wavelengths (top left), the transit depth at every wavelength with the
    same subset highlighted (bottom left), and the output of
    ``active_star.plot_star`` (right). The figure is displayed with
    ``plt.show()``; nothing is returned.

    Parameters
    ----------
    active_star : object
        Must provide ``transit_model(**planet_params)`` returning the
        5-tuple ``(lc, contam, X, Y, spectrum_at_transit)``, a
        ``plot_star`` method, and ``times`` and ``wavelength`` attributes.
        ``lc`` is indexed as ``[time, wavelength]``. Wavelengths are
        multiplied by 1e6 for the axis labelled in microns, i.e. they are
        assumed to be in meters.
    planet_params : dict
        Keyword arguments forwarded to ``active_star.transit_model``. The
        keys ``'t0'``, ``'rp'``, ``'a'``, ``'ecc'`` and ``'inclination'``
        are required, since they are also passed to ``plot_star``.
    norm_oot_per_wavelength : bool, optional
        If True (default), divide each plotted light curve by its own
        mean over all times.
    norm_stellar_spectrum : bool, optional
        If True (default), plot the light curves as returned. If False,
        first multiply them by ``spectrum_at_transit`` divided by its
        mean.
    """
    lc, contam, _, _, spectrum_at_transit = active_star.transit_model(
        **planet_params
    )
    wavelength = active_star.wavelength
    n_wavelengths = len(wavelength)

    fig = plt.figure(figsize=(9.5, 5), dpi=150)
    gs = GridSpec(2, 2, figure=fig)

    ax = [
        fig.add_subplot(gs[0, 0]),
        fig.add_subplot(gs[1, 0]),
        # right-hand column, spanning both rows:
        fig.add_subplot(gs[:, 1]),
    ]

    # Plot about ten wavelengths. The stride must never be zero, which
    # would make ``[::skip]`` raise for grids of ten or fewer wavelengths.
    skip = max(1, (n_wavelengths - 1) // 10)
    # Indices of the highlighted wavelengths. Slicing an index array with
    # the same stride guarantees the same count as ``wavelength[::skip]``
    # and never produces an out-of-range index.
    sampled = np.arange(n_wavelengths)[::skip]

    # ``max - min`` instead of the ``ptp`` method, which NumPy 2.0 removed
    # from ``ndarray``.
    wl_min = wavelength.min()
    wl_range = wavelength.max() - wl_min
    if wl_range == 0:
        # single-wavelength grid: avoid 0 / 0 in the color lookup
        wl_range = 1

    def color_at(i):
        """Map wavelength index (or index array) ``i`` onto Spectral_r."""
        return plt.cm.Spectral_r((wavelength[i] - wl_min) / wl_range)

    if norm_stellar_spectrum:
        scale_relative_to_flux_at_wavelength = 1
    else:
        # Keep the scale at full spectral resolution so that it broadcasts
        # against ``lc``; the product is subsampled afterwards.
        # NOTE(review): assumes ``spectrum_at_transit`` broadcasts against
        # ``lc`` along the wavelength axis -- confirm with the
        # ``transit_model`` documentation.
        scale_relative_to_flux_at_wavelength = (
            spectrum_at_transit / spectrum_at_transit.mean()
        )

    sampled_lc = (lc * scale_relative_to_flux_at_wavelength)[:, ::skip]

    for i, lc_i in zip(sampled, sampled_lc.T):
        if norm_oot_per_wavelength:
            # out-of-place division: does not depend on whether ``lc_i``
            # is a writable view
            lc_i = lc_i / lc_i.mean()

        ax[0].plot(active_star.times, lc_i, color=color_at(i))

    ax[0].set(
        xlabel='Time [d]',
        ylabel='$\\left(F(t)/\\bar{F}\\right)_{\\lambda}$',
    )

    # fractional depth -> parts per million
    contaminated_depth = 1e6 * contam

    ax[1].plot(
        wavelength * 1e6,
        contaminated_depth,
        zorder=-3, lw=2.5, color='silver'
    )
    ax[1].scatter(
        wavelength[::skip] * 1e6, contaminated_depth[::skip].T,
        c=color_at(sampled),
        s=50, edgecolor='k', zorder=4
    )
    ax[1].set(
        xlabel='Wavelength [µm]',
        ylabel='Transit depth [ppm]',
        xscale='log',
        xlim=[
            1e6 * 0.9 * wavelength.min(),
            1e6 * 1.1 * wavelength.max()
        ],
    )

    active_star.plot_star(
        t0=planet_params['t0'],
        rp=planet_params['rp'],
        a=planet_params['a'],
        ecc=planet_params['ecc'],
        inclination=planet_params['inclination'],
        ax=ax[2]
    )

    for axis in ax:
        for sp in ('right', 'top'):
            axis.spines[sp].set_visible(False)

    fig.tight_layout()
    plt.show()

# Star whose photosphere is described by the 2600 K binned spectrum;
# ``to_value(u.m)`` hands the wavelengths over as plain numbers in meters:
active_star = ActiveStar(
    times=times,
    inclination=np.pi / 2,
    T_eff=2600,
    wavelength=phot.wavelength.to_value(u.m),
    phot=phot.flux.value,
)

# Add one cool and one hot spot. Each entry pairs the binned spectrum of
# the spot with its position (lon, lat in [rad]) and radius (in [R_star]):
for spot_model, spot_geometry in (
    (cool, dict(lon=-0.2, lat=1.65, rad=0.15)),  # cool spot
    (hot, dict(lon=0.95, lat=1.75, rad=0.08)),  # hot spot
):
    active_star.add_spot(
        spectrum=spot_model.flux.value,
        temperature=spot_model.meta['PHXTEFF'],
        **spot_geometry
    )

# Planet parameters for TRAPPIST-1 c from Agol 2021:
t1c = {
    'inclination': np.radians(89.778),
    'a': 28.549,
    'rp': 0.08440,
    'period': 2.421937,
    't0': 0,
    'ecc': 0,
    'u1': 0.1,
    'u2': 0.3,
}

# Build and show the three-panel figure:
plot_transit_contamination(active_star, t1c)