Skip to content

Instantly share code, notes, and snippets.

@andycasey
Created March 1, 2023 00:44
Show Gist options
  • Save andycasey/0a934fdd61465b7a4c6d413b642fb8d2 to your computer and use it in GitHub Desktop.
Save andycasey/0a934fdd61465b7a4c6d413b642fb8d2 to your computer and use it in GitHub Desktop.
from __future__ import annotations
import numpy as np
import warnings
from sklearn.decomposition._nmf import non_negative_factorization, _fit_multiplicative_update
from sklearn.exceptions import ConvergenceWarning
from astropy.nddata import InverseVariance
from typing import Optional, Union, Tuple, List
from specutils import SpectralAxis, Spectrum1D
class Continuum:
    """A base class to represent the stellar continuum."""

    def __init__(
        self,
        spectral_axis: Optional[SpectralAxis] = None,
        regions: Optional[List[Tuple[float, float]]] = None,
        mask: Optional[Union[str, np.ndarray]] = None,
        fill_value: Optional[Union[int, float]] = np.nan,
    ):
        """
        :param spectral_axis: [optional]
            If given, the spectral axis of the spectrum that will be fitted. This is useful when using the same
            class on many spectra where the spectral axis is the same.

        :param regions: [optional]
            A list of two-length tuples of the form (lower, upper) in the same units as the spectral axis.

        :param mask: [optional]
            A boolean array of the same length as the spectral axis, where False indicates a continuum pixel,
            and True indicates a pixel to be masked in the continuum fit. A string is treated as a path to a
            file that is loaded with `np.load`.

        :param fill_value: [optional]
            The value to use for pixels where the continuum is not defined.
        """
        self.mask = np.load(mask) if isinstance(mask, str) else mask
        self.regions = regions
        self.fill_value = fill_value
        self.spectral_axis = spectral_axis

    def _initialize(self, spectrum):
        """
        Return the per-region pixel slices and continuum pixel indices for `spectrum`.

        The value is computed from `spectrum.wavelength` on the first call and cached on the instance;
        later calls return the cached value regardless of which spectrum is passed.
        """
        try:
            return self._initialized_args
        except AttributeError:
            # Nothing cached yet. The arguments are built outside the handler (and with no `return`
            # inside a `finally`) so that errors raised while building them propagate unchanged.
            pass
        self._initialized_args = _pixel_slice_and_mask(
            spectrum.wavelength, self.regions, self.mask
        )
        return self._initialized_args

    @property
    def num_regions(self):
        """
        Return the number of regions used to fit the continuum.
        """
        return 1 if self.regions is None else len(self.regions)

    def _get_shape(self, spectrum: Spectrum1D):
        """
        Get the shape of the spectrum as ``(N, P)``.

        :param spectrum:
            A spectrum, which could be a 1D spectrum or multiple spectra with the same spectral axis.

        :returns:
            ``(N, P)`` for a 2D flux array, otherwise ``(1, flux.size)``.
        """
        try:
            N, P = spectrum.flux.shape
        except ValueError:
            # Not two-dimensional (e.g. a single 1D spectrum): treat it as one spectrum.
            N, P = (1, spectrum.flux.size)
        return (N, P)

    def fit(self, spectrum: Spectrum1D) -> Continuum:
        """
        Fit the continuum in the given spectrum.

        :param spectrum:
            A spectrum.
        """
        raise NotImplementedError("This should be implemented by the sub-classes")

    def __call__(
        self, spectrum: Spectrum1D, theta: Optional[Union[List, np.ndarray, Tuple]] = None
    ) -> np.ndarray:
        """
        Return the estimated continuum given a spectrum and parameters.

        :param spectrum:
            A spectrum.

        :param theta: [optional]
            A set of parameters for the continuum. If not provided, this defaults to the parameters
            previous fit to the spectrum.
        """
        raise NotImplementedError("This should be implemented by the sub-classes")
def _pixel_slice_and_mask(
spectral_axis: SpectralAxis,
regions: Optional[List[Tuple[float, float]]] = None,
mask: Optional[np.array] = None,
):
"""
Return region slices in pixel space, and the continuum masks to use in each region.
:param spectral_axis:
The spectral axis of the spectrum.
:param regions:
A list of two-length tuples of the form (lower, upper) in the same units as the spectral axis.
:param mask:
A boolean array of the same length as the spectral axis, where False indicates a continuum pixel,
and True indicates a pixel to be masked in the continuum fit.
:returns:
A tuple of two lists, the first containing the pixel slices for each region, and the second containing
the continuum mask for each region.
"""
if regions is None:
region_slices = [(0, spectral_axis.size)]
else:
region_slices = []
for lower, upper in regions:
# TODO: allow for units/quantities in (lower, upper)?
region_slices.append(spectral_axis.value.searchsorted([lower, upper]))
region_masks = []
if mask is None:
for lower, upper in region_slices:
# No mask, keep all pixels as continuum.
region_masks.append(np.arange(lower, upper, dtype=int))
else:
if len(mask) != len(spectral_axis):
# Assume mask is a list of regions to mask.
constructed_mask = np.zeros(len(spectral_axis), dtype=bool)
for lower, upper in mask:
idx_lower, idx_higher = np.clip(
spectral_axis.value.searchsorted([lower, upper]) - 1,
0,
spectral_axis.size - 1
)
constructed_mask[idx_lower:idx_higher] = True
for lower, upper in region_slices:
# Mask given, exclude those masked.
region_masks.append(np.where(~constructed_mask[lower:upper])[0] + lower)
else:
for lower, upper in region_slices:
# Mask given, exclude those masked.
region_masks.append(np.where(~mask[lower:upper])[0] + lower)
return (region_slices, region_masks)
class Sinusoids(Continuum):
    """Represent the stellar continuum with sine and cosine functions."""

    def __init__(
        self,
        deg: Optional[int] = 3,
        L: Optional[float] = 1400,
        scalar: Optional[float] = 1e-6,
        spectral_axis: Optional[SpectralAxis] = None,
        regions: Optional[List[Tuple[float, float]]] = None,
        mask: Optional[np.ndarray] = None,
        fill_value: Optional[Union[int, float]] = np.nan,
        **kwargs,
    ) -> None:
        """
        :param deg: [optional]
            The degree of sinusoids to include. Each region's design matrix has ``2 * deg + 1`` rows:
            a constant term, then a cosine and a sine for each harmonic ``1..deg``.

        :param L: [optional]
            The length scale for the sines and cosines, in the same units as the spectral axis values.

        :param scalar: [optional]
            Regularization strength: ``scalar`` times the largest eigenvalue of each region's normal
            matrix is added to its diagonal before solving.

        :param spectral_axis: [optional]
            Passed to `Continuum`.

        :param regions: [optional]
            A list of two-length tuples of the form (lower, upper) in the same units as the spectral axis.

        :param mask: [optional]
            A boolean array of the same length as the spectral axis, where False indicates a continuum pixel,
            and True indicates a pixel to be masked in the continuum fit.

        :param fill_value: [optional]
            The value to use for pixels where the continuum is not defined.

        Any extra keyword arguments are forwarded to `Continuum.__init__`.
        """
        super().__init__(
            spectral_axis=spectral_axis,
            regions=regions,
            mask=mask,
            fill_value=fill_value,
            **kwargs,
        )
        self.deg = int(deg)
        self.L = float(L)
        self.scalar = float(scalar)

    def fit(self, spectrum: Spectrum1D) -> Sinusoids:
        """
        Fit the continuum to every spectrum in `spectrum` and store the parameters in `self.theta`.

        :param spectrum:
            A spectrum (or several spectra sharing one spectral axis) with an uncertainty.

        :returns:
            This instance, with `theta` of shape ``(N, num_regions, 2 * deg + 1)``.
        """
        _initialized_args = self._initialize(spectrum)
        N, P, *_ = _initialized_args
        all_flux = spectrum.flux.value.reshape((N, P))
        all_ivar = spectrum.uncertainty.represent_as(InverseVariance).array.reshape((N, P))
        self.theta = self._fit(all_flux, all_ivar, _initialized_args)
        return self

    def _fit(self, flux, ivar, _initialized_args):
        """
        Solve the regularized weighted least-squares problem for each spectrum and region.

        :param flux:
            Flux array of shape ``(N, P)``.

        :param ivar:
            Inverse variance array of shape ``(N, P)``; zero-ivar pixels do not contribute.

        :param _initialized_args:
            The tuple returned by `_initialize`.

        :returns:
            ``theta`` of shape ``(N, num_regions, 2 * deg + 1)``. Regions where every continuum pixel has
            zero inverse variance get all-zero parameters.
        """
        N, P, *region_args = _initialized_args
        theta = np.empty((N, self.num_regions, 2 * self.deg + 1))
        for i, (flux_, ivar_) in enumerate(zip(flux, ivar)):
            for j, (_, indices, _, M_continuum) in enumerate(zip(*region_args)):
                ivar_region = ivar_[indices]
                if np.all(ivar_region == 0):
                    print(f"Region {j} is empty. Setting theta to zero.")
                    theta[i, j] = 0.0
                    continue
                MTM = M_continuum @ (ivar_region[:, None] * M_continuum.T)
                MTy = M_continuum @ (ivar_region * flux_[indices])
                # Ridge (Tikhonov) term scaled to the largest eigenvalue, so that MTM is positive
                # definite even when few pixels constrain the sinusoids.
                # TODO: warn on high condition number
                MTM[np.diag_indices_from(MTM)] += self.scalar * np.max(np.linalg.eigvalsh(MTM))
                theta[i, j] = np.linalg.solve(MTM, MTy)
        return theta

    def _evaluate(self, theta, initialized_args):
        """
        Evaluate the continuum for parameters `theta`.

        :returns:
            An ``(N, P)`` array; pixels outside every region are set to `fill_value`.
        """
        N, P, *region_args = initialized_args
        continuum = self.fill_value * np.ones((N, P))
        for i in range(N):
            for j, ((lower, upper), _, M_region, _) in enumerate(zip(*region_args)):
                continuum[i, lower:upper] = M_region.T @ theta[i, j]
        return continuum

    def __call__(
        self, spectrum: Spectrum1D, theta: Optional[Union[List, np.ndarray, Tuple]] = None, **kwargs
    ) -> np.ndarray:
        """
        Return the continuum for `spectrum`, using `theta` or (if None) the parameters from `fit`.
        """
        if theta is None:
            theta = self.theta
        _initialized_args = self._initialize(spectrum)
        return self._evaluate(theta, _initialized_args)

    def _initialize(self, spectrum: Spectrum1D):
        """
        Return ``(N, P, region_slices, region_continuum_indices, M_region, M_continuum)`` for `spectrum`.

        The region slices, continuum indices and design matrices are built from the first spectrum
        seen and cached; on later calls only ``(N, P)`` is refreshed from the new spectrum.
        NOTE(review): the cache assumes every spectrum shares the first spectrum's spectral axis.
        """
        N, P = self._get_shape(spectrum)
        try:
            _, _, *region_args = self._initialized_args
        except AttributeError:
            # First call: build the region arguments. Errors here propagate unchanged.
            region_args = self._build_region_args(spectrum)
        self._initialized_args = (N, P, *region_args)
        return self._initialized_args

    def _build_region_args(self, spectrum: Spectrum1D):
        """Build the region slices, continuum indices and per-region design matrices."""
        region_slices, region_continuum_indices = _pixel_slice_and_mask(
            spectrum.wavelength, self.regions, self.mask
        )
        wavelength = spectrum.wavelength.value
        M_region = []
        M_continuum = []
        for (lower, upper), indices in zip(region_slices, region_continuum_indices):
            # All pixels of the region (for evaluation) vs. only its continuum pixels (for fitting).
            M_region.append(self._design_matrix(wavelength[lower:upper]))
            M_continuum.append(self._design_matrix(wavelength[indices]))
        return [region_slices, region_continuum_indices, M_region, M_continuum]

    def _design_matrix(self, dispersion: np.ndarray) -> np.ndarray:
        """
        Return the ``(2 * deg + 1, dispersion.size)`` design matrix.

        Rows are: a constant, then ``cos(o * s * x)`` and ``sin(o * s * x)`` for ``o = 1..deg``,
        where ``s = 2 * pi / L``.
        """
        scale = 2 * (np.pi / self.L)
        rows = [np.ones_like(dispersion)]
        for o in range(1, self.deg + 1):
            rows.append(np.cos(o * scale * dispersion))
            rows.append(np.sin(o * scale * dispersion))
        return np.vstack(rows)
class Emulator:
    """
    Fit a spectrum as a continuum multiplied by a rectified stellar absorption model.

    The continuum is a `Sinusoids` model. The rectified flux is ``1 - W @ components``, where the
    non-negative amplitudes ``W`` are estimated by non-negative matrix factorization. The two parts
    are fitted alternately (see `fit`).
    """

    def __init__(
        self,
        components: np.ndarray,
        alpha_W: Optional[float] = 1e-5,
        nmf_solver: Optional[str] = "mu",
        nmf_max_iter: Optional[int] = 100,
        nmf_tol: Optional[float] = 1e-1,
        deg: Optional[int] = 3,
        L: Optional[float] = 1400,
        scalar: Optional[float] = 1e-6,
        spectral_axis: Optional[SpectralAxis] = None,
        regions: Optional[List[Tuple[float, float]]] = None,
        mask: Optional[np.ndarray] = None,
        fill_value: Optional[Union[int, float]] = np.nan,
        **kwargs,
    ) -> None:
        """
        :param components:
            The NMF components (``H``), of shape ``(n_components, n_pixels)``.

        :param alpha_W, nmf_solver, nmf_max_iter, nmf_tol: [optional]
            Passed to `non_negative_factorization` as `alpha_W`, `solver`, `max_iter` and `tol`.

        :param deg, L, scalar, spectral_axis, regions, mask, fill_value: [optional]
            Passed to the `Sinusoids` continuum model. `mask` also excludes pixels from the NMF step.

        Any extra keyword arguments are accepted and ignored.
        """
        self.continuum_model = Sinusoids(
            spectral_axis=spectral_axis,
            regions=regions,
            mask=mask,
            fill_value=fill_value,
            deg=deg,
            L=L,
            scalar=scalar,
        )
        self.mask = mask
        self.alpha_W = alpha_W
        self.nmf_solver = nmf_solver
        self.nmf_max_iter = nmf_max_iter
        self.nmf_tol = nmf_tol
        self.components = components
        self.phi_size = components.shape[0]
        self.theta_size = self.continuum_model.num_regions * (2 * self.continuum_model.deg + 1)

    def _check_data(self, spectrum):
        """
        Return copies of the flux and inverse variance reshaped to ``(N, P)``.

        Pixels with non-finite flux or ivar, or zero ivar, get flux = 0 and ivar = 0.
        """
        N, P = self.continuum_model._get_shape(spectrum)
        flux = spectrum.flux.value.reshape((N, P)).copy()
        ivar = spectrum.uncertainty.represent_as(InverseVariance).array.reshape((N, P)).copy()
        bad_pixels = ~np.isfinite(ivar) | ~np.isfinite(flux) | (ivar == 0)
        flux[bad_pixels] = 0
        ivar[bad_pixels] = 0
        return (flux, ivar)

    def _maximization(self, flux, ivar, continuum_args):
        """Fit the continuum model to `flux`/`ivar` and evaluate it; returns ``(theta, continuum)``."""
        theta = self.continuum_model._fit(flux, ivar, continuum_args)
        continuum = self.continuum_model._evaluate(theta, continuum_args)
        return (theta, continuum)

    def _expectation(self, flux, W, **kwargs):
        """
        Fit NMF amplitudes to a (stacked) rectified flux with the components held as the basis.

        :param flux:
            Rectified flux with one value per component pixel.

        :param W:
            Amplitudes from the previous iteration, or None on the first iteration.

        :returns:
            ``(W_next, rectified_model_flux, n_pixels_used, n_nmf_iterations)``.

        :raises ValueError:
            If `flux` does not have one value per component pixel.
        """
        absorption = 1 - flux
        n_components, n_pixels = self.components.shape
        if flux.size != n_pixels:
            raise ValueError(
                f"Expected {n_pixels} flux values (one per component pixel), got {flux.size}."
            )

        # Only use finite pixels with non-negative absorption (NMF needs non-negative data).
        use = np.isfinite(flux) & (absorption >= 0)
        if self.mask is not None:
            use *= ~self.mask

        n_used = np.sum(use)
        if n_used < n_components:
            print(f"Number of non-negative finite pixels ({n_used}) is less than the number of components ({n_components}).")

        X = absorption[use].reshape((1, -1))
        H = self.components[:, use]

        # TODO: Scale alpha_W based on the number of pixels being used in the mask?
        kwds = dict(
            # If W is None it means it's the first iteration.
            init=None if W is None else "custom",
            # The documentation says that custom matrices W and H can only be used if `update_H=True`.
            # Since we want it to use W from the previous iteration, we will set `update_H=True`, and
            # ignore the adjusted H that is returned.
            update_H=True,
            solver=self.nmf_solver,
            W=W,
            H=H,
            n_components=n_components,
            beta_loss="frobenius",
            tol=self.nmf_tol,
            max_iter=self.nmf_max_iter,
            # Only regularization on W, because we are at the test step here.
            alpha_W=self.alpha_W,
            alpha_H=0.0,
            l1_ratio=1.0,
            random_state=None,
            verbose=0,
            shuffle=False,
        )
        # Only include kwargs that non_negative_factorization accepts.
        kwds.update({k: v for k, v in kwargs.items() if k in kwds})

        W_next, _, n_iter = non_negative_factorization(X, **kwds)
        # The model uses the original (unadjusted) components over all pixels.
        rectified_model_flux = 1 - (W_next @ self.components)[0]
        return (W_next, rectified_model_flux, n_used, n_iter)

    def fit(
        self,
        spectrum: Spectrum1D,
        tol: float = 1e-1,
        max_iter: int = 1_000,
        *,
        raise_on_error: bool = True,
    ):
        r"""
        Simultaneously fit the continuum and stellar absorption.

        :param spectrum:
            The spectrum to fit continuum to. This can be multiple visits of the same spectrum.

        :param tol: [optional]
            The difference in \chi-squared value between iterations to establish convergence.

            What makes a good tolerance value? There are two parts that contribute to this
            tolerance: the stellar absorption model, and the continuum model. The continuum
            model is linear algebra, so it contributes very little to the tolerance between
            successive iterations. The stellar absorption model is a non-negative matrix
            factorization. This tolerance definitely should not be set to be smaller than
            the tolerance specified when building the non-negative matrix factorization (1e-4),
            because the stellar absorption model has no flexibility to predict absorption
            better than that average degree. For this reason, it's probably sufficient to
            set the tolerance a few orders of magnitude larger (1e-1 or 1e-2), with some
            sensible number of max iterations.

        :param max_iter: [optional]
            The maximum number of expectation-maximization iterations (at least 1).

        :param raise_on_error: [optional, keyword-only]
            If True (default), exceptions raised while fitting propagate. If False, they are caught
            and a failure result is returned (NaN continuum and rectified flux, zero parameters,
            ``meta["success"] = False``).

        :returns:
            A tuple of (phi, theta, continuum, model_rectified_flux, meta) where:
            - `phi` are the amplitudes for the non-negative matrix factorization (e.g., `W`)
            - `theta` is the parameters of the continuum model
            - `continuum` is the continuum model evaluated at the spectral axis
            - `model_rectified_flux` is the rectified flux evaluated at the spectral axis
            - `meta` is a dictionary of metadata

        :raises ValueError:
            If `max_iter` is less than 1 (and `raise_on_error` is True).
        """
        try:
            return self._fit(spectrum, tol, max_iter)
        except Exception:
            if raise_on_error:
                raise
            return self._failed_fit(spectrum, max_iter)

    def _failed_fit(self, spectrum, max_iter):
        """Return a placeholder result, shaped like a successful fit, for a fit that raised."""
        N, P = self.continuum_model._get_shape(spectrum)
        phi = np.zeros((1, self.components.shape[0]))
        theta = np.zeros(
            (N, self.continuum_model.num_regions, 2 * self.continuum_model.deg + 1)
        )
        continuum = np.full((N, P), np.nan)
        model_rectified_flux = np.full(P, np.nan)
        meta = dict(
            chi_sqs=[999],
            reduced_chi_sqs=[999],
            n_pixels_used_in_nmf=[],
            success=False,
            iter=max_iter,
            message="Failed to fit.",
            continuum_args=None,
            n_nmf_iters=[],
        )
        return (phi, theta, continuum, model_rectified_flux, meta)

    def _expectation_maximization(self, flux, ivar, stacked_flux, phi, continuum_args, **kwargs):
        """
        Run one E step (NMF on the stacked rectified flux), then one M step (continuum fit to the flux
        divided by the model rectified flux, with the inverse variance scaled by its square).

        :returns:
            ``(chi_sq, args)`` where ``args`` is ``(phi, theta, continuum, model_rectified_flux,
            n_pixels_used_in_nmf, n_finite_chisq_pixels, n_nmf_iterations)``.
        """
        phi_next, model_rectified_flux, n_pixels, n_nmf_iter = self._expectation(
            stacked_flux,
            W=phi.copy() if phi is not None else phi,  # make sure you copy
            **kwargs,
        )
        theta_next, continuum = self._maximization(
            flux / model_rectified_flux,
            model_rectified_flux * ivar * model_rectified_flux,
            continuum_args,
        )
        chi_sq = (flux - model_rectified_flux * continuum) ** 2 * ivar
        finite = np.isfinite(chi_sq)
        chi_sq = np.sum(chi_sq[finite])
        args = (phi_next, theta_next, continuum, model_rectified_flux, n_pixels, np.sum(finite), n_nmf_iter)
        return (chi_sq, args)

    def _fit(self, spectrum, tol, max_iter):
        """Run the alternating fit; see `fit` for the parameters and return value."""
        if max_iter < 1:
            raise ValueError(f"max_iter must be at least 1 (got {max_iter})")

        flux, ivar = self._check_data(spectrum)
        continuum_args = self.continuum_model._initialize(spectrum)

        with warnings.catch_warnings():
            for category in (RuntimeWarning, ConvergenceWarning):
                warnings.filterwarnings("ignore", category=category)

            phi = None  # phi is the same as W used in NMF; None lets NMF initialize it.
            theta, continuum = self._maximization(flux, ivar, continuum_args)

            chi_sqs, n_pixels_used_in_chisq, n_pixels_used_in_nmf, n_nmf_iters = ([], [], [], [])
            for iteration in range(max_iter):
                # Rectify each visit by its current continuum. Dividing the flux by the continuum
                # scales its inverse variance by continuum**2, so the stack is an inverse-variance
                # weighted mean of the rectified visits.
                conditional_flux = flux / continuum
                conditional_ivar = continuum * ivar * continuum
                stacked_flux = (
                    np.sum(conditional_flux * conditional_ivar, axis=0)
                    / np.sum(conditional_ivar, axis=0)
                )

                chi_sq, em_args = self._expectation_maximization(
                    flux,
                    ivar,
                    stacked_flux,
                    phi,
                    continuum_args,
                )

                if iteration > 0:
                    assert phi is not None
                    if chi_sq > chi_sqs[-1]:
                        # Stop and keep the previous iteration's parameters.
                        print(r"Failed to improve \chi^2")
                        success, message = (False, r"Failed to improve \chi^2")
                        break

                (phi, theta, continuum, model_rectified_flux, n_pixels, n_finite, n_nmf_iter) = em_args
                chi_sqs.append(chi_sq)
                n_pixels_used_in_nmf.append(n_pixels)
                n_pixels_used_in_chisq.append(n_finite)
                n_nmf_iters.append(n_nmf_iter)

                if iteration > 0:
                    # Increases break out above, so this change is <= 0 (or NaN, which never converges).
                    delta_chi_sq = chi_sqs[-1] - chi_sqs[-2]
                    if abs(delta_chi_sq) <= tol:
                        success, message = (True, f"Convergence reached after {iteration} iterations")
                        break
            else:
                if len(chi_sqs) > 1:
                    detail = f" ({abs(chi_sqs[-1] - chi_sqs[-2])} > {tol:.2e})"
                else:
                    detail = ""
                # NOTE(review): non-convergence is still reported as success=True; confirm intended.
                success, message = (True, f"Convergence not reached after {max_iter} iterations{detail}")
                warnings.warn(message)

        # Degrees of freedom: finite chi^2 pixels minus the number of fitted parameters, minus 1.
        reduced_chi_sqs = np.array(chi_sqs) / (np.array(n_pixels_used_in_chisq) - phi.size - theta.size - 1)
        meta = dict(
            chi_sqs=chi_sqs,
            reduced_chi_sqs=reduced_chi_sqs,
            n_pixels_used_in_nmf=n_pixels_used_in_nmf,
            success=success,
            iter=iteration,
            message=message,
            continuum_args=continuum_args,
            n_nmf_iters=n_nmf_iters,
        )
        return (phi, theta, continuum, model_rectified_flux, meta)
if __name__ == "__main__":
    # Example usage: fit the emulator to one spectrum. The input files are read from fixed absolute
    # paths, so this only runs where those files exist.
    import pickle
    from astropy import units as u

    # NMF components, loaded with pickle (only unpickle files from a trusted source).
    with open("/uufs/chpc.utah.edu/common/home/sdss50/sdsswork/mwm/spectro/astra/component_data/continuum/sgGK_200921nlte_nmf_components.pkl", "rb") as fp:
        components = pickle.load(fp)

    # Inverse-variance scalars; pixels where the scalar differs from 1 are masked below
    # (presumably sky-affected pixels, judging by the filename — TODO confirm).
    with open("/uufs/chpc.utah.edu/common/home/sdss50/sdsswork/mwm/spectro/astra/component_data/continuum/20230222_sky_mask_ivar_scalar.pkl", "rb") as f:
        ivar_scalar = pickle.load(f)

    emulator = Emulator(
        components,
        # Three (lower, upper) continuum regions, in the units of the spectral axis (Angstrom below).
        regions=[
            (15_100.0, 15_800.0),
            (15_840.0, 16_417.0),
            (16_500.0, 17_000.0)
        ],
        # Boolean mask: True marks pixels excluded from the fit.
        mask=(ivar_scalar != 1)
    )

    flux_unit = u.Unit("1e-17 erg / (Angstrom cm2 s)")
    # 8575 pixels, uniformly spaced in log10(wavelength).
    wavelength = 10**(4.179 + 6e-6 * np.arange(8575))

    # load flux, ivar
    # NOTE(review): `flux` and `ivar` are not defined anywhere in this script, so running it as-is
    # raises NameError on the next statement. They must be loaded first (presumably arrays on the
    # `wavelength` grid — TODO confirm).
    spectrum = Spectrum1D(
        spectral_axis=u.Quantity(wavelength, unit=u.Angstrom),
        flux=flux * flux_unit,
        uncertainty=InverseVariance(ivar)
    )
    phi, theta, continuum, model_rectified_flux, meta = emulator.fit(spectrum)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment