Last active
June 25, 2022 18:43
-
-
Save davidwhogg/ac3cea0f0d52338f2816d70b4ea88d4b to your computer and use it in GitHub Desktop.
Testing components of Joaquin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Experiments related to Joaquin\n", | |
"Technically, this notebook implements something *even dumber* than *Joaquin*:\n", | |
"a k-nearest-neighbors (kNN) search in *Gaia*-only quantities, used to get a weighted-mean estimate of schmag\n", | |
"(the parallax multiplied by $10^{G/5}$, where $G$ is the *Gaia* $G$-band magnitude).\n", | |
"\n", | |
"## Authors:\n", | |
"- **Adrian Price-Whelan** (Flatiron)\n", | |
"- **David W. Hogg** (NYU) (MPIA) (Flatiron)\n", | |
"\n", | |
"## Hyper-parameters:\n", | |
"- `ncoeff`: The number of spectral coefficients to use from each of BP and RP (each normalized by the zeroth coefficient).\n", | |
"- `maxk`: The maximum number of nearest neighbors `k` that we retrieve for each object.\n", | |
"- Scalings or preprocessing of the input features (currently none).\n", | |
"- How we use the neighbors (weighted mean, weighted linear fit, or a mixture of some kind?)." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Read in and munge all data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-06-15T01:40:27.141892Z", | |
"start_time": "2022-06-15T01:40:27.135575Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"import pathlib\n", | |
"import astropy.table as at\n", | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline\n", | |
"import numpy as np\n", | |
"import h5py\n", | |
"from tqdm import tqdm\n", | |
"from sklearn.neighbors import KDTree" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-06-15T01:22:16.673443Z", | |
"start_time": "2022-06-15T01:22:07.797965Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"datadir = \"./\"\n", | |
"xm = at.Table.read(datadir + 'allStar-dr17-synspec-gaiadr3.fits')\n", | |
"xm2 = at.Table.read(datadir + 'allStar-dr17-synspec-gaiadr3-gaiasourcelite.fits')\n", | |
"xm2.rename_column('source_id', 'GAIADR3_SOURCE_ID')\n", | |
"allstar = at.Table.read(datadir + 'allStarLite-dr17-synspec_rev1.fits')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-06-15T17:47:13.865546Z", | |
"start_time": "2022-06-15T17:46:54.953972Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"tbl = at.unique(at.hstack((allstar, xm)), keys='APOGEE_ID')\n", | |
"tbl = tbl[tbl['GAIADR3_SOURCE_ID'] != 0]\n", | |
"tbl = at.join(tbl, xm2, keys='GAIADR3_SOURCE_ID')\n", | |
"len(tbl)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-06-15T17:47:15.631613Z", | |
"start_time": "2022-06-15T17:47:15.629159Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"apogee_xp_cont_filename = pathlib.Path(datadir + 'apogee-dr17-xpcontinuous.hdf5')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-06-15T17:47:23.349365Z", | |
"start_time": "2022-06-15T17:47:16.202584Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# Read data and lightly rearrange\n", | |
"xp_tbl = at.Table()\n", | |
"with h5py.File(apogee_xp_cont_filename, 'r') as f:\n", | |
" xp_tbl['GAIADR3_SOURCE_ID'] = f['source_id'][:]\n", | |
" xp_tbl['bp'] = f['bp_coefficients'][:]\n", | |
" xp_tbl['rp'] = f['rp_coefficients'][:]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-06-15T17:47:29.937161Z", | |
"start_time": "2022-06-15T17:47:23.350804Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# Read data and make simple cuts\n", | |
"# Hogg: Why these cuts?\n", | |
"xp_apogee_tbl = at.join(tbl, xp_tbl, keys='GAIADR3_SOURCE_ID')\n", | |
"xp_apogee_tbl = xp_apogee_tbl[\n", | |
" (xp_apogee_tbl['TEFF'] > 3500.) &\n", | |
" (xp_apogee_tbl['TEFF'] < 6000.) &\n", | |
" (xp_apogee_tbl['LOGG'] > -0.5) &\n", | |
" (xp_apogee_tbl['LOGG'] < 5.0) &\n", | |
" (xp_apogee_tbl['M_H'] > -2.)\n", | |
"]\n", | |
"len(xp_apogee_tbl)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Make rectangular data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# This does something useful!\n", | |
"xp_apogee_tbl = xp_apogee_tbl.filled()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-06-15T17:48:04.069510Z", | |
"start_time": "2022-06-15T17:48:03.887065Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# Make rectangular block of Gaia-only features (X) for training and validation\n", | |
"\n", | |
"# APW, HOGG: BUG: Why these cuts?\n", | |
"feature_mask = (\n", | |
" (xp_apogee_tbl['J'] < 13) &\n", | |
" (xp_apogee_tbl['H'] < 12) &\n", | |
" (xp_apogee_tbl['K'] < 11) &\n", | |
" (xp_apogee_tbl['AK_WISE'] > -0.1))\n", | |
"\n", | |
"ncoeff = 8 # MAGIC\n", | |
"features = np.hstack((\n", | |
" (xp_apogee_tbl['phot_bp_mean_mag'] - xp_apogee_tbl['phot_rp_mean_mag'])[feature_mask, None],\n", | |
" (xp_apogee_tbl['bp'][:, 1:ncoeff + 1] / xp_apogee_tbl['bp'][:, 0:1])[feature_mask],\n", | |
" (xp_apogee_tbl['rp'][:, 1:ncoeff + 1] / xp_apogee_tbl['rp'][:, 0:1])[feature_mask],\n", | |
"))\n", | |
"\n", | |
"feature_names = np.concatenate((\n", | |
" ['$BP-RP$ (mag)', ],\n", | |
" [f'BP[{i}]' for i in range(1, ncoeff + 1)],\n", | |
" [f'RP[{i}]' for i in range(1, ncoeff + 1)],\n", | |
"))\n", | |
"\n", | |
"print(features.shape)\n", | |
"print(len(feature_names), feature_names)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# rearrange feature order because Hogg has issues\n", | |
"\n", | |
"index = np.concatenate((\n", | |
" [0, ], \n", | |
" *([i, ncoeff + i, ] for i in range(1, ncoeff + 1))\n", | |
"))\n", | |
"print(feature_names[index])\n", | |
"\n", | |
"features = features[:, index]\n", | |
"feature_names = feature_names[index]\n", | |
"print(features.shape, feature_names.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Make list of labels (and label weights), aligned with the features.\n", | |
"\n", | |
"labels = (xp_apogee_tbl['parallax'] * 10 ** (1/5 * xp_apogee_tbl['phot_g_mean_mag']))[feature_mask]\n", | |
"print(labels.shape)\n", | |
"\n", | |
"label_errors = (xp_apogee_tbl['parallax_error'] * 10 ** (1/5 * xp_apogee_tbl['phot_g_mean_mag']))[feature_mask]\n", | |
"print(label_errors.shape)\n", | |
"\n", | |
"label_weights = 1. / (label_errors ** 2)\n", | |
"print(label_weights.shape)\n", | |
"\n", | |
"label_name = '$G$-band schmag (absmgy$^{-1/2}$)'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# check that the labels aren't wack\n", | |
"\n", | |
"plt.scatter(labels, labels / label_errors, c=\"k\", s=1., alpha=0.05)\n", | |
"plt.axhline(np.median(labels / label_errors), color=\"k\")\n", | |
"plt.xlim(-300., 1000.)\n", | |
"plt.ylim(-10., 200.)\n", | |
"plt.xlabel(label_name)\n", | |
"plt.ylabel(\"label SNR\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# check that the features aren't wack\n", | |
"\n", | |
"for i in range(features.shape[1]):\n", | |
" f = plt.figure()\n", | |
" foo = np.percentile(features[:, i], [2.5, 97.5])\n", | |
" lo = 0.5 * (foo[1] + foo[0]) - (foo[1] - foo[0])\n", | |
" hi = 0.5 * (foo[1] + foo[0]) + (foo[1] - foo[0])\n", | |
" plt.scatter(features[:, i], labels, c=\"k\", s=1., alpha=0.05)\n", | |
" plt.xlim(lo, hi)\n", | |
" plt.ylim(-300., 1500.)\n", | |
" plt.xlabel(feature_names[i])\n", | |
" plt.ylabel(label_name)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Make training and validation samples" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# cut to eighths #MAGIC\n", | |
"# BUG: Should fix random state more sensibly than this.\n", | |
"\n", | |
"np.random.seed(17)\n", | |
"rando = np.random.randint(8, size=len(features))\n", | |
"train = rando != 0\n", | |
"valid = ~train\n", | |
"X_train, X_valid = features[train], features[valid]\n", | |
"Y_train, Y_valid = labels[train], labels[valid]\n", | |
"W_train, W_valid = label_weights[train], label_weights[valid]\n", | |
"print(X_train.shape, X_valid.shape)\n", | |
"print(Y_train.shape, Y_valid.shape)\n", | |
"print(W_train.shape, W_valid.shape)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Build a kNN model and validate it" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Get all possibly useful validation-set neighbors up-front.\n", | |
"# We'll use them in various ways below.\n", | |
"\n", | |
"maxk = 128 # magic\n", | |
"tree = KDTree(X_train, leaf_size=32) # magic\n", | |
"dists, inds = tree.query(X_valid, k=maxk)\n", | |
"print(X_valid.shape, dists.shape, inds.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Let's look at a few objects\n", | |
"\n", | |
"for ii in range(8):\n", | |
" ff = plt.figure()\n", | |
" plt.axhline(Y_valid[ii], c=\"r\")\n", | |
" plt.errorbar(dists[ii], Y_train[inds[ii]], yerr = 1. / np.sqrt(W_train[inds[ii]]),\n", | |
" fmt=\"o\", color=\"k\", ecolor=\"k\")\n", | |
" plt.xlabel(\"distance to neighbor\")\n", | |
" plt.ylabel(\"label (schmag) of neighbor\")\n", | |
" plt.title(f\"validation-set object {ii}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Test weighted mean as a function of k\n", | |
"# HOGG: TBD" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Test a linear weighted least squares as a function of k\n", | |
"# HOGG: TBD" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Test some kind of mixture model maybe??" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Run this model on EVERYTHING" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# APW: We need to figure out the above tests and then run in the data center." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.4" | |
}, | |
"toc": { | |
"base_numbering": 1, | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"title_cell": "Table of Contents", | |
"title_sidebar": "Contents", | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": true, | |
"toc_window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment