@davidwhogg
Last active June 25, 2022 18:43
Testing components of Joaquin
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Experiments related to Joaquin\n",
"Technically, this notebook implements something *even dumber* than *Joaquin*.\n",
"It implements kNN in *Gaia*-only quantities to get a weighted-mean estimate of schmag.\n",
"\n",
"## Authors:\n",
"- **Adrian Price-Whelan** (Flatiron)\n",
"- **David W. Hogg** (NYU) (MPIA) (Flatiron)\n",
"\n",
"## Hyper-parameters:\n",
"- `ncoeff`: The number of BP and RP spectral coefficients to use.\n",
"- `maxk`: The maximum `k` to which we take neighbors.\n",
"- scalings or preprocessing of input features (currently null).\n",
"- how we use the neighbors (weighted mean, weighted linear fit, mixture of some kind?)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read in and munge all data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-06-15T01:40:27.141892Z",
"start_time": "2022-06-15T01:40:27.135575Z"
}
},
"outputs": [],
"source": [
"import pathlib\n",
"import astropy.table as at\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import numpy as np\n",
"import h5py\n",
"from tqdm import tqdm\n",
"from sklearn.neighbors import KDTree"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-06-15T01:22:16.673443Z",
"start_time": "2022-06-15T01:22:07.797965Z"
}
},
"outputs": [],
"source": [
"datadir = \"./\"\n",
"xm = at.Table.read(datadir + 'allStar-dr17-synspec-gaiadr3.fits')\n",
"xm2 = at.Table.read(datadir + 'allStar-dr17-synspec-gaiadr3-gaiasourcelite.fits')\n",
"xm2.rename_column('source_id', 'GAIADR3_SOURCE_ID')\n",
"allstar = at.Table.read(datadir + 'allStarLite-dr17-synspec_rev1.fits')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-06-15T17:47:13.865546Z",
"start_time": "2022-06-15T17:46:54.953972Z"
}
},
"outputs": [],
"source": [
"tbl = at.unique(at.hstack((allstar, xm)), keys='APOGEE_ID')\n",
"tbl = tbl[tbl['GAIADR3_SOURCE_ID'] != 0]\n",
"tbl = at.join(tbl, xm2, keys='GAIADR3_SOURCE_ID')\n",
"len(tbl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-06-15T17:47:15.631613Z",
"start_time": "2022-06-15T17:47:15.629159Z"
}
},
"outputs": [],
"source": [
"apogee_xp_cont_filename = pathlib.Path(datadir + 'apogee-dr17-xpcontinuous.hdf5')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-06-15T17:47:23.349365Z",
"start_time": "2022-06-15T17:47:16.202584Z"
}
},
"outputs": [],
"source": [
"# Read data and lightly rearrange\n",
"xp_tbl = at.Table()\n",
"with h5py.File(apogee_xp_cont_filename, 'r') as f:\n",
" xp_tbl['GAIADR3_SOURCE_ID'] = f['source_id'][:]\n",
" xp_tbl['bp'] = f['bp_coefficients'][:]\n",
" xp_tbl['rp'] = f['rp_coefficients'][:]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-06-15T17:47:29.937161Z",
"start_time": "2022-06-15T17:47:23.350804Z"
}
},
"outputs": [],
"source": [
"# Read data and make simple cuts\n",
"# Hogg: Why these cuts?\n",
"xp_apogee_tbl = at.join(tbl, xp_tbl, keys='GAIADR3_SOURCE_ID')\n",
"xp_apogee_tbl = xp_apogee_tbl[\n",
" (xp_apogee_tbl['TEFF'] > 3500.) &\n",
" (xp_apogee_tbl['TEFF'] < 6000.) &\n",
" (xp_apogee_tbl['LOGG'] > -0.5) &\n",
" (xp_apogee_tbl['LOGG'] < 5.0) &\n",
" (xp_apogee_tbl['M_H'] > -2.)\n",
"]\n",
"len(xp_apogee_tbl)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Make rectangular data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This does something useful!\n",
"xp_apogee_tbl = xp_apogee_tbl.filled()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-06-15T17:48:04.069510Z",
"start_time": "2022-06-15T17:48:03.887065Z"
}
},
"outputs": [],
"source": [
"# Make rectangular block of Gaia-only features (X) for training and validation\n",
"\n",
"# APW, HOGG: BUG: Why these cuts?\n",
"feature_mask = (\n",
" (xp_apogee_tbl['J'] < 13) &\n",
" (xp_apogee_tbl['H'] < 12) &\n",
" (xp_apogee_tbl['K'] < 11) &\n",
" (xp_apogee_tbl['AK_WISE'] > -0.1))\n",
"\n",
"ncoeff = 8 # MAGIC\n",
"features = np.hstack((\n",
" (xp_apogee_tbl['phot_bp_mean_mag'] - xp_apogee_tbl['phot_rp_mean_mag'])[feature_mask, None],\n",
" (xp_apogee_tbl['bp'][:, 1:ncoeff + 1] / xp_apogee_tbl['bp'][:, 0:1])[feature_mask],\n",
" (xp_apogee_tbl['rp'][:, 1:ncoeff + 1] / xp_apogee_tbl['rp'][:, 0:1])[feature_mask],\n",
"))\n",
"\n",
"feature_names = np.concatenate((\n",
" ['$BP-RP$ (mag)', ],\n",
" [f'BP[{i}]' for i in range(1, ncoeff + 1)],\n",
" [f'RP[{i}]' for i in range(1, ncoeff + 1)],\n",
"))\n",
"\n",
"print(features.shape)\n",
"print(len(feature_names), feature_names)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# rearrange feature order because Hogg has issues\n",
"\n",
"index = np.concatenate((\n",
" [0, ], \n",
" *([i, ncoeff + i, ] for i in range(1, ncoeff + 1))\n",
"))\n",
"print(feature_names[index])\n",
"\n",
"features = features[:, index]\n",
"feature_names = feature_names[index]\n",
"print(features.shape, feature_names.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Make list of labels (and label weights), aligned with the features.\n",
"\n",
"labels = (xp_apogee_tbl['parallax'] * 10 ** (1/5 * xp_apogee_tbl['phot_g_mean_mag']))[feature_mask]\n",
"print(labels.shape)\n",
"\n",
"label_errors = (xp_apogee_tbl['parallax_error'] * 10 ** (1/5 * xp_apogee_tbl['phot_g_mean_mag']))[feature_mask]\n",
"print(label_errors.shape)\n",
"\n",
"label_weights = 1. / (label_errors ** 2)\n",
"print(label_weights.shape)\n",
"\n",
"label_name = '$G$-band schmag (absmgy$^{-1/2}$)'"
]
},
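{
"cell_type": "markdown",
"metadata": {},
"source": [
"A note on the label, as a sketch rather than anything the rest of the notebook depends on. Assuming the parallax $\\varpi$ is in mas (the *Gaia* archive unit), the distance modulus gives $M_G = G + 5\\log_{10}(\\varpi / \\mathrm{mas}) - 10$, so the schmag $\\varpi\\,10^{G/5}$ equals $10^{(M_G + 10)/5}$. The helper `schmag_to_absmag` below is a hypothetical name introduced here for illustration; the conversion is only defined for positive schmag, which is one reason to keep working in schmag for noisy parallaxes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative helper (not part of the original analysis).\n",
"def schmag_to_absmag(schmag):\n",
"    # Invert schmag = 10 ** ((M_G + 10) / 5); non-positive schmag maps to NaN.\n",
"    schmag = np.atleast_1d(np.asarray(schmag, dtype=float))\n",
"    out = np.full(schmag.shape, np.nan)\n",
"    good = schmag > 0\n",
"    out[good] = 5. * np.log10(schmag[good]) - 10.\n",
"    return out\n",
"\n",
"# A star with M_G = 0 has schmag = 10 ** 2 = 100.\n",
"print(schmag_to_absmag([100., 1000., -5.]))"
]
},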
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check that the labels aren't wack\n",
"\n",
"plt.scatter(labels, labels / label_errors, c=\"k\", s=1., alpha=0.05)\n",
"plt.axhline(np.median(labels / label_errors), color=\"k\")\n",
"plt.xlim(-300., 1000.)\n",
"plt.ylim(-10., 200.)\n",
"plt.xlabel(label_name)\n",
"plt.ylabel(\"label SNR\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check that the features aren't wack\n",
"\n",
"for i in range(features.shape[1]):\n",
" f = plt.figure()\n",
" foo = np.percentile(features[:, i], [2.5, 97.5])\n",
" lo = 0.5 * (foo[1] + foo[0]) - (foo[1] - foo[0])\n",
" hi = 0.5 * (foo[1] + foo[0]) + (foo[1] - foo[0])\n",
" plt.scatter(features[:, i], labels, c=\"k\", s=1., alpha=0.05)\n",
" plt.xlim(lo, hi)\n",
" plt.ylim(-300., 1500.)\n",
" plt.xlabel(feature_names[i])\n",
" plt.ylabel(label_name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Make training and validation samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# cut to eighths #MAGIC\n",
"# BUG: Should fix random state more sensibly than this.\n",
"\n",
"np.random.seed(17)\n",
"rando = np.random.randint(8, size=len(features))\n",
"train = rando != 0\n",
"valid = ~train\n",
"X_train, X_valid = features[train], features[valid]\n",
"Y_train, Y_valid = labels[train], labels[valid]\n",
"W_train, W_valid = label_weights[train], label_weights[valid]\n",
"print(X_train.shape, X_valid.shape)\n",
"print(Y_train.shape, Y_valid.shape)\n",
"print(W_train.shape, W_valid.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build a kNN model and validate it"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get all possibly useful validation-set neighbors up-front.\n",
"# We'll use them in various ways below.\n",
"\n",
"maxk = 128 # magic\n",
"tree = KDTree(X_train, leaf_size=32) # magic\n",
"dists, inds = tree.query(X_valid, k=maxk)\n",
"print(X_valid.shape, dists.shape, inds.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Let's look at a few objects\n",
"\n",
"for ii in range(8):\n",
" ff = plt.figure()\n",
" plt.axhline(Y_valid[ii], c=\"r\")\n",
" plt.errorbar(dists[ii], Y_train[inds[ii]], yerr = 1. / np.sqrt(W_train[inds[ii]]),\n",
" fmt=\"o\", color=\"k\", ecolor=\"k\")\n",
" plt.xlabel(\"distance to neighbor\")\n",
" plt.ylabel(\"label (schmag) of neighbor\")\n",
" plt.title(f\"validation-set object {ii}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test weighted mean as a function of k\n",
"# HOGG: TBD"
]
},
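{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way the weighted-mean test might look, as a sketch under assumptions and not the authors' implementation: each validation object gets the inverse-variance weighted mean of the labels of its `k` nearest training neighbors, and each `k` is scored by the median weighted squared residual on the validation set. The function name `knn_weighted_mean`, the grid of `k` values, and the score are all choices made here for illustration. The score ignores the uncertainty of the prediction itself."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch; relies on Y_train, W_train, Y_valid, W_valid, inds, maxk from above.\n",
"def knn_weighted_mean(Y, W, inds, k):\n",
"    # Weighted mean of the labels of the k nearest neighbors of each query point.\n",
"    y, w = Y[inds[:, :k]], W[inds[:, :k]]\n",
"    return np.sum(w * y, axis=1) / np.sum(w, axis=1)\n",
"\n",
"# Plain ndarrays, in case the labels are astropy Columns.\n",
"Yt, Wt = np.asarray(Y_train), np.asarray(W_train)\n",
"Yv, Wv = np.asarray(Y_valid), np.asarray(W_valid)\n",
"\n",
"for k in [1, 2, 4, 8, 16, 32, 64, 128]:\n",
"    if k > maxk:\n",
"        break\n",
"    Y_pred = knn_weighted_mean(Yt, Wt, inds, k)\n",
"    print(k, np.median(Wv * (Yv - Y_pred) ** 2))"
]
},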
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test a linear weighted least squares as a function of k\n",
"# HOGG: TBD"
]
},
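{
"cell_type": "markdown",
"metadata": {},
"source": [
"A comparable sketch for the local linear fit, again an illustration under assumptions and not the authors' method: for each validation object, fit the neighbor labels as a constant plus a linear function of the feature offsets from the query point, by weighted least squares, and take the intercept as the prediction. `knn_weighted_linear` is a hypothetical helper. `k` needs to comfortably exceed the number of free parameters (one plus the number of features) for the fit to be well posed, so the grid of `k` values starts at 32."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch; relies on X_train, X_valid, Y_train, W_train, Y_valid, W_valid, inds, maxk.\n",
"def knn_weighted_linear(X, Y, W, x0, ind):\n",
"    # Design matrix: constant term plus feature offsets from the query point x0.\n",
"    A = np.hstack((np.ones((len(ind), 1)), X[ind] - x0))\n",
"    sw = np.sqrt(W[ind])\n",
"    coeffs = np.linalg.lstsq(A * sw[:, None], Y[ind] * sw, rcond=None)[0]\n",
"    return coeffs[0]  # the fit evaluated at x0 is the intercept\n",
"\n",
"# Plain ndarrays, in case the inputs are astropy Columns.\n",
"Xt, Xv = np.asarray(X_train), np.asarray(X_valid)\n",
"Yt, Wt = np.asarray(Y_train), np.asarray(W_train)\n",
"Yv, Wv = np.asarray(Y_valid), np.asarray(W_valid)\n",
"\n",
"for k in [32, 64, 128]:\n",
"    if k > maxk:\n",
"        break\n",
"    Y_pred = np.array([knn_weighted_linear(Xt, Yt, Wt, Xv[ii], inds[ii, :k])\n",
"                       for ii in tqdm(range(len(Xv)))])\n",
"    print(k, np.median(Wv * (Yv - Y_pred) ** 2))"
]
},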
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test some kind of mixture model maybe??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run this model on EVERYTHING"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# APW: We need to figure out the above tests and then run in the data center."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}