"""Source code for fleck.jax."""

from jax import jit, numpy as jnp, random, lax, vmap
from jax.tree_util import register_pytree_node_class
from jax.scipy.integrate import trapezoid

import numpy as np

import astropy.units as u
import jaxoplanet.core

import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from matplotlib.colors import to_hex

from scipy.stats import binned_statistic
from specutils import Spectrum1D

# Default PRNG key at module scope.
key = random.PRNGKey(0)

# Zero-length array used as the default value for unset ActiveStar inputs.
empty = jnp.array([])

# Names exported by ``from fleck.jax import *``.
__all__ = ["ActiveStar", "bin_spectrum"]


@register_pytree_node_class
class ActiveStar:
    """
    Model for a star with active regions and rotation, with optional
    planetary transit models and spot occultations.

    Instances are JAX pytrees whose children are the eleven constructor
    arguments (see ``tree_flatten``). This is what allows the
    ``@jit``-decorated methods to receive ``self`` as a traced argument.
    """

    # Number of Monte Carlo samples to use when computing planet+spot overlap
    n_mc = 1_000
    # Random key seed for the Monte Carlo overlap samples
    key = random.PRNGKey(0)

    # Attribute names in pytree-children order (shared by flatten/unflatten)
    _tree_fields = (
        "times", "lon", "lat", "rad", "spectrum", "T_eff", "temperature",
        "inclination", "wavelength", "phot", "P_rot",
    )

    def __init__(
        self, times=empty, lon=empty, lat=empty, rad=empty,
        spectrum=empty, T_eff=None, temperature=empty,
        inclination=empty, wavelength=None, phot=None, P_rot=3.3
    ):
        """
        Parameters
        ----------
        times : array
            Times at which to compute the flux
        lon : array
            Active region longitudes in radians on (0, 2pi)
        lat : array
            Active region latitudes in radians on (0, pi)
        rad : array
            Active region radii in units of stellar radii
        spectrum : array
            One spectrum for each active region
        T_eff : float
            Effective temperature of the photosphere
        temperature : array
            Effective temperature of the active regions
        inclination : float or array
            Stellar inclination [radians]
        wavelength : array
            Wavelength for each flux observation in ``phot`` [meters]
        phot : array
            Photospheric flux at each ``wavelength``.
        P_rot : float
            Stellar rotation period
        """
        self.times = jnp.array(times)
        self.lon = jnp.array(lon)
        self.lat = jnp.array(lat)
        self.rad = jnp.array(rad)
        self.spectrum = jnp.array(spectrum)
        self.T_eff = T_eff
        self.temperature = jnp.array(temperature)
        self.inclination = inclination
        self.wavelength = wavelength
        self.phot = phot
        self.P_rot = P_rot

    def tree_flatten(self):
        """Return the pytree children (all constructor arguments) and no aux data."""
        children = tuple(getattr(self, name) for name in self._tree_fields)
        aux_data = None
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from ``tree_flatten`` children."""
        # Bypass ``__init__``: it converts several inputs with ``jnp.array``,
        # but JAX may rebuild pytrees with placeholder leaves (tracers, None,
        # sentinels) that must be stored untouched.
        obj = object.__new__(cls)
        for name, value in zip(cls._tree_fields, children):
            setattr(obj, name, value)
        return obj

    @staticmethod
    def _mu_and_visible_mu(spot_position_x, spot_position_y, spot_position_z):
        """
        Return ``mu = sqrt(1 - x**2 - y**2)`` for each active region, and a
        copy of ``mu`` set to zero wherever ``z >= 0``. Throughout this model
        only regions with ``z < 0`` contribute flux.
        """
        rsq = spot_position_x ** 2 + spot_position_y ** 2
        mu = jnp.sqrt(1 - rsq)
        return mu, jnp.where(spot_position_z < 0, mu, 0)

    def _spotted_spectrum(self, f_S):
        """
        Disk-integrated spectrum of the spotted star at each time.

        The photosphere ``self.phot`` is weighted by the disk fraction not
        covered by visible regions. Each region's ``self.spectrum`` is
        weighted by its filling factor ``f_S``, which has time on axis 0 and
        regions on axis 1.
        """
        coverage = f_S[..., 0]
        # photospheric component:
        photosphere = (1 - coverage.sum(axis=1)) * self.phot[None, :]
        spot_coverages, spot_spectra = jnp.broadcast_arrays(
            coverage, self.spectrum[None, ...]
        )
        return jnp.squeeze(
            # sum of the active region components:
            photosphere + jnp.sum(spot_coverages * spot_spectra, axis=1)
        )

    @jit
    def rotation_model(self, f0=0, t0_rot=0, u1=0, u2=0):
        """
        Spectrophotometry of stellar rotation.

        Parameters
        ----------
        f0 : float
            Baseline flux of an unspotted star (usually zero or one)
        t0_rot : float
            Zero-point reference time for stellar rotation
        u1 : float
            Quadratic limb-darkening parameter :math:`u_1`
        u2 : float
            Quadratic limb-darkening parameter :math:`u_2`

        Returns
        -------
        spot_model : array
            Relative flux as a function of time and wavelength
        f_S : array
            Filling factor ``rad**2 * mu`` of each region with ``z < 0``
            (zero otherwise), i.e. its projected area relative to the disk
        """
        (
            spot_position_x, spot_position_y, spot_position_z,
            major_axis, minor_axis, angle, rad, contrast
        ) = self.spot_coords(t0_rot=t0_rot)
        mu, visible_mu = self._mu_and_visible_mu(
            spot_position_x, spot_position_y, spot_position_z
        )

        # Total limb-darkened flux of an unspotted disk. The integrand is
        # 2*pi*x*I(x) with x used as mu, which equals the disk integral over
        # r. With the normalization in ``limb_darkening`` this is ~1.
        grid = 1 - jnp.geomspace(1e-5, 1, 100)[::-1]
        unspotted_total_flux = trapezoid(
            y=2 * np.pi * grid * self.limb_darkening(grid, u1, u2),
            x=grid
        )

        # Morris 2020 Eqn 6-7
        spot_model = f0 - jnp.sum(
            np.pi * rad ** 2 * (1 - contrast) *
            self.limb_darkening(mu, u1, u2) * visible_mu,
            axis=1
        ) / unspotted_total_flux
        f_S = rad ** 2 * mu * (spot_position_z < 0).astype(int)
        return spot_model, f_S

    @jit
    def spot_coords(self, times=None, t0_rot=0):
        """
        Compute the spatial coordinates and projected dimensions of active
        regions.

        Parameters
        ----------
        times : array
            Times on which to compute spectrophotometry (defaults to
            ``self.times``)
        t0_rot : float
            Zero-point reference time for stellar rotation

        Returns
        -------
        spot_position_x : array
            x-position of the active region in the observer oriented
            coordinate system [1]_.
        spot_position_y : array
            y-position of the active region in the observer oriented
            coordinate system [1]_.
        spot_position_z : array
            z-position of the active region in the observer oriented
            coordinate system [1]_.
        major_axis : array
            Apparent semimajor axis of the circular active region, which is
            (in general) elliptical when projected onto the sky plane
        minor_axis : array
            Apparent semiminor axis of the circular active region, which is
            (in general) elliptical when projected onto the sky plane
        angle : array
            Angle between the +x-axis and the projected active region's
            semimajor axis [degrees]
        rad : array
            Active region radius [stellar radii]
        contrast: array
            Ratio of the active region spectrum and the photosphere spectrum

        References
        ----------
        .. [1] Fabrycky & Winn (2009) https://arxiv.org/abs/0902.0737
        """
        contrast = self.spectrum / self.phot[None, :]
        if contrast.ndim == 1:
            contrast = contrast[None, :]

        if times is None:
            times = self.times

        # Limits:
        #     lat: (0, pi)
        #     lon: (0, 2pi)
        #     rad: (0, None)
        #     contrast: (0, inf)
        #     inclination: (0, pi/2)
        #
        # broadcasting dimensions:
        #     0. phase
        #     1. spot location (lat, lon, rad)
        #     2. contrast/wavelength
        #     3. inclination
        phase = jnp.expand_dims(2 * np.pi * (times - t0_rot) / self.P_rot, [1, 2, 3])
        lon = jnp.expand_dims(self.lon, [0, 2, 3])
        lat = jnp.expand_dims(self.lat, [0, 2, 3])
        rad = jnp.expand_dims(self.rad, [0, 2, 3])
        contrast = jnp.expand_dims(contrast, [0, 3])
        inclination = jnp.expand_dims(jnp.asarray(self.inclination), [0, 1, 2])

        comp_inclination = np.pi / 2 - inclination
        phi = np.pi / 2 - phase - lon

        sin_lat = jnp.sin(lat)
        cos_lat = jnp.cos(lat)
        sin_c_inc = jnp.sin(comp_inclination)
        cos_c_inc = jnp.cos(comp_inclination)

        spot_position_x = (
            jnp.cos(phi - np.pi / 2) * sin_c_inc * sin_lat +
            cos_c_inc * cos_lat
        )
        spot_position_y = -jnp.sin(phi - np.pi / 2) * sin_lat
        spot_position_z = (
            cos_lat * sin_c_inc -
            jnp.sin(phi) * cos_c_inc * sin_lat
        )

        rsq = spot_position_x ** 2 + spot_position_y ** 2
        major_axis = rad
        minor_axis = rad * jnp.sqrt(1 - rsq)
        angle = -jnp.degrees(jnp.arctan2(spot_position_y, spot_position_x))

        return (
            spot_position_x, spot_position_y, spot_position_z,
            major_axis, minor_axis, angle, rad, contrast
        )

    def add_spot(self, lon, lat, rad, contrast=None, temperature=None,
                 spectrum=None):
        """
        Add one or more active regions to the stellar model.

        The regions' spectrum is chosen in this order:

        1. ``spectrum``, if given;
        2. otherwise ``contrast`` times ``ActiveStar.phot``;
        3. otherwise blackbodies at ``temperature`` for the regions and at
           ``ActiveStar.T_eff`` for the photosphere. In this case
           ``ActiveStar.phot`` is replaced by the photosphere blackbody.

        Parameters
        ----------
        lon : float or array
            Active region longitudes in radians on (0, 2pi)
        lat : float or array
            Active region latitudes in radians on (0, pi)
        rad : float or array
            Active region radii in units of stellar radii
        contrast : float or array, optional
            Ratio of the active region's flux to the photospheric flux at
            each ``ActiveStar.wavelength``
        temperature : float or array, optional
            Effective temperature of the active region [K]. If omitted, NaN
            is stored so the per-region arrays stay aligned. Within this
            class, temperatures are only used by ``plot_star``.
        spectrum : array, optional
            The spectrum of the active region on the same wavelength grid as
            ``ActiveStar.phot``

        Raises
        ------
        ValueError
            If none of ``spectrum``, ``contrast`` or ``temperature`` is
            given, or if ``contrast`` is given while ``ActiveStar.phot`` is
            unset.
        """
        if spectrum is None:
            if contrast is not None:
                if self.phot is None:
                    raise ValueError(
                        "``contrast`` requires ``ActiveStar.phot`` to be set"
                    )
                spectrum = jnp.asarray(contrast) * jnp.asarray(self.phot)
            elif temperature is not None:
                self.phot = self._blackbody(self.wavelength, self.T_eff)
                spectrum = self._blackbody(self.wavelength, temperature)
            else:
                raise ValueError(
                    "Specify one of ``spectrum``, ``contrast`` or ``temperature``"
                )

        lon = jnp.atleast_1d(jnp.asarray(lon))
        if temperature is None:
            temperature = jnp.full(lon.shape, jnp.nan)

        # Compute every updated attribute before assigning any of them, so a
        # shape error cannot leave the star half-updated.
        updates = {}
        # Per-region scalars are stored as flat 1-D arrays.
        for name, value in (("lon", lon), ("lat", lat), ("rad", rad),
                            ("temperature", temperature)):
            value = jnp.atleast_1d(jnp.asarray(value))
            current = getattr(self, name)
            if current is not None:
                value = jnp.concatenate([jnp.ravel(current), value])
            updates[name] = value

        # Spectra are stacked as rows (one per region). A lone first spectrum
        # keeps the shape it was given.
        spectrum = jnp.asarray(spectrum)
        current = self.spectrum
        if current is not None and jnp.size(current) > 0:
            spectrum = jnp.vstack(
                [jnp.atleast_2d(current), jnp.atleast_2d(spectrum)]
            )
        updates["spectrum"] = spectrum

        for name, value in updates.items():
            setattr(self, name, value)

    @jit
    def _blackbody(self, wavelength_meters, temperature):
        """
        Compute a blackbody spectrum (Planck's law in wavelength).
        """
        h = 6.62607015e-34  # J s
        c = 299792458.0  # m/s
        k_B = 1.380649e-23  # J/K

        return (
            2 * h * c ** 2 / jnp.power(wavelength_meters, 5) /
            jnp.expm1(h * c / (wavelength_meters * k_B * temperature))
        )

    @jit
    def limb_darkening(self, mu, u1, u2):
        """
        Compute quadratic limb darkening as a function of :math:`\\mu`,
        normalized so that the disk-integrated intensity is one.
        """
        return (
            1 / np.pi * (1 - u1 * (1 - mu) - u2 * (1 - mu) ** 2) /
            (1 - u1 / 3 - u2 / 6)
        )

    @jit
    def transit_model(self, t0, period, rp, a, inclination, omega=np.pi / 2,
                      ecc=0, f0=1, t0_rot=0, u1=0, u2=0):
        """
        Compute spectrophotometry with rotation and a planetary transit.

        The transit is computed with ``jaxoplanet`` for a star with quadratic
        limb darkening.

        Parameters
        ----------
        t0 : float
            Mid-transit time
        period : float
            Orbital period of the transiting planet
        rp : float
            Exoplanet radius in units of stellar radii
        a : float
            Planetary semi-major axis in units of stellar radii
        inclination : float
            Planetary orbital inclination [radians]
        omega : float
            Argument of periapse [radians], default is :math:`\\pi/2`.
        ecc : float
            Orbital eccentricity, default is zero.
        f0 : float
            Out-of-transit flux for an unspotted star, default is one.
        t0_rot : float
            Zero-point in time for stellar rotation, default is zero
        u1 : float or array
            Limb-darkening parameter :math:`u_1` (scalar, or one value per
            wavelength)
        u2 : float or array
            Limb-darkening parameter :math:`u_2` (scalar, or one value per
            wavelength)

        Returns
        -------
        lc : array
            Flux as a function of time and wavelength
        apparent_rprs2 : array
            The apparent squared ratio of planet-to-star radius with stellar
            spectral contamination by active regions
        X : array
            x-position of the planet in the observer oriented coordinate
            system [1]_.
        Y : array
            y-position of the planet in the observer oriented coordinate
            system [1]_.
        spectrum_at_transit : array
            Disk-integrated spectrum of the spotted star at the time sample
            closest to ``t0``

        References
        ----------
        .. [1] Fabrycky & Winn (2009) https://arxiv.org/abs/0902.0737
        """
        u1 = jnp.atleast_1d(u1)
        u2 = jnp.atleast_1d(u2)
        u_ld = jnp.column_stack([u1, u2])

        # handle the out-of-transit spectroscopic rotational modulation:
        (
            spot_position_x, spot_position_y, spot_position_z,
            major_axis, minor_axis, angle, rad, contrast
        ) = self.spot_coords(t0_rot=t0_rot)
        mu, visible_mu = self._mu_and_visible_mu(
            spot_position_x, spot_position_y, spot_position_z
        )

        # Unspotted disk flux for each (u1, u2) pair. The integrand is
        # 2*pi*x*I(x) with x used as mu; this is ~1 by the limb-darkening
        # normalization.
        grid = 1 - jnp.geomspace(1e-5, 1, 100)[::-1]
        unspotted_total_flux = trapezoid(
            y=(
                2 * np.pi * grid[:, None] *
                self.limb_darkening(grid[:, None], *u_ld.T)
            ).T,
            x=grid
        )
        limb_dark = self.limb_darkening(
            mu, u1=u1[None, None, :, None], u2=u2[None, None, :, None]
        )

        # Morris 2020 Eqn 6-7
        out_of_transit = f0 - jnp.sum(
            np.pi * rad ** 2 * (1 - contrast) * limb_dark * visible_mu /
            unspotted_total_flux[None, None, :, None],
            axis=1
        )
        f_S = rad ** 2 * mu * (spot_position_z < 0).astype(int)

        # compute the transit model
        mean_anomaly = 2 * np.pi * (self.times - t0) / period
        true_anomaly = jnp.arctan2(
            *jaxoplanet.core.kepler(M=mean_anomaly, ecc=ecc)
        )
        # Winn 2011 Eqn 1
        r = a * (1 - ecc ** 2) / (1 + ecc * jnp.cos(true_anomaly))
        # Winn 2011 Eqn 3-4
        X = -r * jnp.cos(omega + true_anomaly)
        Y = -r * jnp.sin(omega + true_anomaly) * jnp.cos(inclination)

        time_series_spectrum = self._spotted_spectrum(f_S)

        transit = vmap(
            lambda ld: jaxoplanet.core.light_curve(
                u1=ld[0], u2=ld[1], b=jnp.hypot(X, Y), r=rp
            ),
            in_axes=0, out_axes=1
        )(u_ld)

        contaminated_transit = (
            time_series_spectrum - jnp.abs(transit) * self.phot[None, :]
        ) / time_series_spectrum

        t_ind = jnp.argmin(jnp.abs(self.times - t0))
        uncontaminated_max_depth = -transit[t_ind]
        contaminated_max_depth = (
            contaminated_transit.max(0) - contaminated_transit[t_ind]
        ) / contaminated_transit.max(0)
        depth_ratio = contaminated_max_depth / uncontaminated_max_depth
        apparent_rprs2 = rp ** 2 * depth_ratio

        # Sky-plane x/y are the model's y/x (``plot_star`` draws regions at
        # (spot_position_y, spot_position_x)), hence the crossed pairing.
        planet_spot_distance = jnp.hypot(
            spot_position_y - X[:, None, None, None],
            spot_position_x - Y[:, None, None, None]
        )
        occultation_possible = jnp.squeeze(
            (planet_spot_distance < (major_axis + rp)) & (spot_position_z < 0)
        )
        n_spots = spot_position_x.shape[1]

        def time_step(carry, j):
            # Only run the Monte Carlo overlap at times when some visible
            # region is close enough to the planet to be occulted.
            overlap = lax.cond(
                jnp.any(occultation_possible[j]),
                lambda _: self._area_union_per_time(
                    x0_ellipse=spot_position_y[j],
                    y0_ellipse=spot_position_x[j],
                    x0_circle=X[j],
                    y0_circle=Y[j],
                    alpha=jnp.squeeze(major_axis[j]),
                    beta=jnp.squeeze(minor_axis[j]),
                    angle=jnp.squeeze(angle[j]),
                    radius=rp,
                    occultation_possible=occultation_possible[j],
                ),
                lambda _: jnp.zeros((n_spots, self.n_mc), dtype=bool),
                False
            )
            return carry, overlap

        # shape: (n_times, n_spots, n_mc_samples)
        occultation_per_time_per_spot_per_mc_sample = lax.scan(
            time_step, 0.0, jnp.arange(self.times.shape[0])
        )[1]

        frac_occulted_per_time_per_spot = jnp.count_nonzero(
            occultation_per_time_per_spot_per_mc_sample, axis=2
        ) / self.n_mc

        occultation = (
            (1 - contrast) *
            jnp.expand_dims(frac_occulted_per_time_per_spot, axis=(2, 3))
        )
        scaled_occultation = (
            (1 - contaminated_transit) * jnp.sum(occultation, axis=1)[..., 0]
        )
        spectrum_at_transit = time_series_spectrum[t_ind]

        return (
            out_of_transit[..., 0] * (contaminated_transit + scaled_occultation),
            apparent_rprs2, X, Y, spectrum_at_transit
        )

    @jit
    def _area_union_per_time(
        self, x0_ellipse, y0_ellipse, x0_circle, y0_circle,
        alpha, beta, angle, radius, occultation_possible,
    ):
        """
        Monte Carlo overlap between the planet's disk and each projected
        active region at a single time.

        Returns a boolean array with one row of ``n_mc`` samples per region.
        Entry ``[k, i]`` is True when planet sample ``i`` lies inside
        ellipse ``k`` (center ``x0/y0_ellipse``, semi-axes ``alpha/beta``,
        rotation ``angle`` in degrees) and on the stellar disk. Rows where
        ``occultation_possible`` is False are all-False.
        """
        # Independent keys for the two coordinates (a key must not be both
        # consumed and split).
        theta_key, radius_key = random.split(self.key)
        theta_p = random.uniform(
            theta_key, minval=0, maxval=2 * np.pi, shape=(self.n_mc,)
        )
        # r = R * sqrt(U) makes the samples uniform in *area* over the
        # planet's disk, so the fraction of samples inside a region estimates
        # the fractional overlap area. (r = R * U would over-weight the
        # planet's center.)
        rad_p = radius * jnp.sqrt(random.uniform(radius_key, shape=(self.n_mc,)))
        xp = rad_p * jnp.cos(theta_p) + x0_circle
        yp = rad_p * jnp.sin(theta_p) + y0_circle

        # ensure overlap only occurs on the stellar surface
        on_star = jnp.hypot(xp, yp) < 1

        # ensure the one-spot case is indexed correctly below:
        x0_ellipse = jnp.atleast_1d(x0_ellipse)
        y0_ellipse = jnp.atleast_1d(y0_ellipse)
        alpha = jnp.atleast_1d(alpha)
        beta = jnp.atleast_1d(beta)
        angle_rad = jnp.radians(jnp.atleast_1d(angle))
        cos_a = jnp.cos(angle_rad)
        sin_a = jnp.sin(angle_rad)
        possible = jnp.atleast_1d(occultation_possible)

        def find_overlap(k):
            # planet samples inside the elliptical (projected circular) region
            dx = xp - x0_ellipse[k]
            dy = yp - y0_ellipse[k]
            in_ellipse = jnp.hypot(
                (dx * cos_a[k] + dy * sin_a[k]) / alpha[k],
                (dx * sin_a[k] - dy * cos_a[k]) / beta[k]
            ) < 1
            return in_ellipse & on_star

        def spot_step(carry, k):
            # where occultations are possible, compute the overlap
            return carry, lax.cond(
                possible[k],
                lambda _: jnp.squeeze(find_overlap(k)),
                lambda _: jnp.zeros(self.n_mc, dtype=bool),
                False
            )

        return lax.scan(spot_step, 0, jnp.arange(x0_ellipse.shape[0]))[1]

    def plot_star(self, t0, rp, a, inclination, ecc=0, t0_rot=0,
                  multiply_radii=1, ax=None, annotate=False, omega=np.pi / 2):
        """
        Plot a 2D representation of the star and transit chord.

        Parameters
        ----------
        t0 : float
            Mid-transit time
        rp : float
            Exoplanet radius in units of stellar radii
        a : float
            Planetary semi-major axis in units of stellar radii
        inclination : float
            Planetary orbital inclination [radians]
        ecc : float
            Orbital eccentricity, default is zero.
        t0_rot : float
            Zero-point in time for stellar rotation, default is zero
        multiply_radii : float
            Visually represent scaled-up active regions, with radii increased
            by the factor ``multiply_radii``. Default is one.
        ax : matplotlib.axes.Axes
            Add the visualization to this matplotlib axis
        annotate : bool
            Add a text label with active region indices and temperatures to
            the visualization
        omega : float
            Argument of periapse [radians] used for the impact parameter,
            default is :math:`\\pi/2` (as in ``transit_model``).

        Returns
        -------
        ax : matplotlib.axes.Axes
        """
        if ax is None:
            ax = plt.gca()

        # Color scale spans the (finite) log active-region temperatures.
        log_temps = np.log10(np.atleast_1d(np.asarray(self.temperature, dtype=float)))
        finite = log_temps[np.isfinite(log_temps)]
        if finite.size:
            log_lo, log_hi = finite.min(), finite.max()
        else:
            log_lo = log_hi = np.log10(self.T_eff)
        log_span = log_hi - log_lo
        if log_span == 0:
            # A single (or uniform) region temperature would give 0/0; fall
            # back to the region/photosphere contrast, else unit width.
            log_span = abs(np.log10(self.T_eff) - log_lo) or 1.0

        def temp_cmap(temp):
            scaled = (np.log10(temp) - log_lo) / log_span
            return to_hex(plt.cm.YlOrRd_r(scaled * 0.6 + 0.4))

        star = plt.Circle((0, 0), 1, color=to_hex(temp_cmap(self.T_eff)))
        ax.add_patch(star)
        ax.set(xlim=[-1.05, 1.05], ylim=[-1.05, 1.05])

        # One time sample (and, presumably, a scalar inclination) leaves one
        # entry per region. Unlike np.squeeze, ravel keeps a length-1 axis
        # when there is a single active region.
        coords = self.spot_coords(times=jnp.array([t0]), t0_rot=t0_rot)
        spot_x, spot_y, spot_z = (np.ravel(c) for c in coords[:3])

        for i, (x, y, z) in enumerate(zip(spot_x, spot_y, spot_z)):
            if not z < 0:
                continue  # only z < 0 regions are drawn
            short = np.sqrt(1 - (x ** 2 + y ** 2))
            angle = -np.degrees(np.arctan2(y, x))
            ell = Ellipse(
                (y, x),
                width=multiply_radii * 2 * self.rad[i],
                height=multiply_radii * 2 * self.rad[i] * short,
                angle=angle,
                facecolor=temp_cmap(self.temperature[i]),
                edgecolor='k'
            )
            ax.add_patch(ell)

            if annotate:
                temp = float(self.temperature[i])
                label = (
                    f"{i+1}: {int(temp)} K" if np.isfinite(temp) else f"{i+1}"
                )
                ax.annotate(
                    label, (y, x), va='center', ha='center', fontsize=6
                )

        ax.set_aspect('equal')

        # impact parameter at conjunction (Winn 2010 Eqn 7)
        b = (a * np.cos(inclination) *
             (1 - ecc ** 2) / (1 + ecc * np.sin(omega)))
        planet_lower_extent = -b - rp
        planet_upper_extent = -b + rp
        ax.axhline(planet_lower_extent, color='gray', ls='--')
        ax.axhline(planet_upper_extent, color='gray', ls='--')
        ax.axis('off')
        return ax

    @jit
    def rotation_spectrum(self, t0_rot=0):
        """
        Compute spectrophotometry during a rotation.

        Parameters
        ----------
        t0_rot : float
            Zero-point in time for stellar rotation, default is zero

        Returns
        -------
        time_series_spectrum : array
            Flux as a function of time and wavelength. Flux units are the
            same as the units for the input spectra.
        """
        # handle the out-of-transit spectroscopic rotational modulation:
        (
            spot_position_x, spot_position_y, spot_position_z,
            _major_axis, _minor_axis, _angle, rad, _contrast
        ) = self.spot_coords(t0_rot=t0_rot)
        mu, _ = self._mu_and_visible_mu(
            spot_position_x, spot_position_y, spot_position_z
        )
        f_S = rad ** 2 * mu * (spot_position_z < 0).astype(int)
        return self._spotted_spectrum(f_S)
def bin_spectrum(spectrum, bins=None, log=True, min=None, max=None, **kwargs):
    """
    Bin a spectrum onto wavelength bins, by default evenly spaced in log
    wavelength.

    Each bin's flux is computed by ``spectral_binning`` (a trapezoid-rule
    mean over the samples in the bin). Bins that come out NaN are filled by
    linear interpolation between the finite bins.

    Parameters
    ----------
    spectrum : `specutils.Spectrum1D`
        Spectrum to bin.
    bins : int or ~numpy.ndarray, optional
        Number of bins, or the bin edges. Edges are in log10(wavelength / um)
        when ``log`` is true and in um otherwise. ``None`` uses 10 bins,
        which is `scipy.stats.binned_statistic`'s default.
    log : bool
        If true, compute bin edges based on the log base 10 of the wavelength.
    min, max : `~astropy.units.Quantity`, optional
        Only samples with ``min < wavelength < max`` are binned. ``None`` (the
        default) leaves that side unbounded.
    **kwargs
        Ignored.

    Returns
    -------
    new_spectrum : `specutils.Spectrum1D`
        Binned spectrum with bin-center wavelengths in microns and the flux
        unit of ``spectrum``, carrying ``spectrum.meta``.
    """
    # ``min``/``max`` shadow the builtins; they are kept for API compatibility.
    wavelength = spectrum.wavelength
    in_range = np.ones(wavelength.shape, dtype=bool)
    if min is not None:
        in_range &= np.asarray(wavelength > min)
    if max is not None:
        in_range &= np.asarray(wavelength < max)
    wavelength = wavelength[in_range]
    flux = spectrum.flux[in_range]

    wl_um = wavelength.to(u.um).value
    wl_axis = np.log10(wl_um) if log else wl_um
    flux_values = flux.value

    # Bin the spectrum:
    bs = binned_statistic(
        wl_axis, flux_values,
        statistic=lambda y: spectral_binning(
            y, all_x=wl_axis, all_y=flux_values
        ),
        bins=10 if bins is None else bins
    )

    centers = 0.5 * (bs.bin_edges[1:] + bs.bin_edges[:-1])
    wl_bins = (10 ** centers if log else centers) * u.um

    interp_fluxes = bs.statistic.copy()
    nans = np.isnan(interp_fluxes)
    # Interpolation needs at least one finite bin to anchor on.
    if nans.any() and not nans.all():
        wl_values = wl_bins.value
        interp_fluxes[nans] = np.interp(
            wl_values[nans], wl_values[~nans], interp_fluxes[~nans]
        )

    return Spectrum1D(
        flux=interp_fluxes * flux.unit,
        spectral_axis=wl_bins,
        meta=spectrum.meta
    )
def spectral_binning(y, all_x, all_y):
    """
    Spectral binning via trapezoidal approximation.

    Locate the contiguous run of ``all_y`` equal to ``y`` (the flux samples
    of one bin), and return the trapezoid-rule integral of ``y`` over the
    matching ``all_x`` divided by the run's x-extent. That is the mean flux
    over the bin.

    Parameters
    ----------
    y : array-like
        Flux samples falling in one bin, in the same order as in ``all_y``.
    all_x : array-like
        Abscissa (e.g. wavelength) for every sample in the spectrum.
    all_y : array-like
        Flux for every sample in the spectrum.

    Returns
    -------
    float
        Mean flux over the bin, or NaN when the bin has fewer than two
        samples or cannot be located in ``all_y``.
    """
    y = np.asarray(y)
    all_x = np.asarray(all_x)
    all_y = np.asarray(all_y)
    n = y.size
    if n < 2 or n > all_y.size:
        return np.nan

    # np.trapz is deprecated in NumPy 2 in favor of np.trapezoid.
    trapezoid_rule = getattr(np, "trapezoid", None) or np.trapz

    # Flux values can repeat elsewhere in the spectrum, so the first match of
    # y[0] is not necessarily this bin: check each candidate start for a full
    # run match.
    for start in np.flatnonzero(all_y[:all_y.size - n + 1] == y[0]):
        stop = start + n
        if np.array_equal(all_y[start:stop], y):
            x = all_x[start:stop]
            return trapezoid_rule(y, x) / (x[-1] - x[0])
    return np.nan