Skip to content

Instantly share code, notes, and snippets.

@camriddell
Last active March 8, 2024 20:56
Show Gist options
  • Save camriddell/886d5cf0f268c88bf6955fe692f5281f to your computer and use it in GitHub Desktop.
Save camriddell/886d5cf0f268c88bf6955fe692f5281f to your computer and use it in GitHub Desktop.
A collection of univariate plots
from functools import partial
from textwrap import fill

import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.lines import Line2D
from matplotlib.pyplot import rc, show, subplots
from numpy import (
    add, array, atleast_2d, histogram, linspace, mean, quantile, std,
)
from numpy.lib.stride_tricks import sliding_window_view
from scipy.stats import gaussian_kde, norm, skewnorm, triang, uniform
# Global matplotlib style: 14pt text, and all four spines of every Axes hidden.
rc('font', size=14)
rc('axes.spines', **dict.fromkeys(('top', 'right', 'left', 'bottom'), False))
# Reference distributions: one column of the figure per entry.
dists = [
    norm(loc=10, scale=2),            # symmetric bell curve
    uniform(loc=0, scale=20),         # flat on [0, 20]
    skewnorm(a=6, loc=10, scale=2),   # positive shape parameter: right-skewed
    triang(c=1, loc=5, scale=7),      # mode at the upper edge: left-skewed
]

# 200 draws from each distribution. The fixed seed (random_state=0) makes the
# figure reproducible between runs.
samples = []
for dist in dists:
    samples.append(dist.rvs(size=200, random_state=0))
def tufte_quartiles(ax, data):
    """Draw a Tufte-style quartile plot of *data* on *ax* at y=0.

    Two horizontal segments run from the minimum to Q1 and from Q3 to the
    maximum. The interquartile range is left as an empty gap, and the median
    is marked with a single scatter point.
    """
    lowest, q1, median, q3, highest = quantile(data, [0, .25, .5, .75, 1])
    # Segment starts (min, Q3) and ends (Q1, max), both at y=0.
    ax.hlines([0, 0], [lowest, q3], [q1, highest])
    ax.scatter([median], [0])
def color_density(ax, data):
    """Paint a strip on *ax* whose color follows the Gaussian KDE of *data*.

    The density is evaluated at 400 evenly spaced points between the sample
    minimum and maximum. It is drawn with ``pcolormesh`` over the y
    coordinates 0 and 1, using the ``'Blues'`` colormap.
    """
    xs = linspace(data.min(), data.max(), 400)
    kde = gaussian_kde(data)
    # pcolormesh needs a 2-D color array, so duplicate the 1-D density into
    # two identical rows (one per y coordinate).
    strip = atleast_2d(kde(xs)).repeat(2, axis=0)
    ax.pcolormesh(xs, [0, 1], strip, cmap='Blues')
def point_decile(ax, data):
    """Draw the deciles of *data* on *ax* as ten horizontal segments at y=0.

    Each segment spans one decile interval. Line width grows from 4 to 20
    toward the middle intervals and shrinks back symmetrically. The median is
    marked by a white dot drawn on top (zorder 7).
    """
    edges = quantile(data, linspace(0, 1, 11))   # 11 edges -> 10 intervals
    widths = 4 * array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1])
    pairs = sliding_window_view(edges, 2)        # rows of (left, right)
    starts, stops = pairs[:, 0], pairs[:, 1]
    ax.hlines([0] * len(pairs), starts, stops, linewidths=widths)
    # edges[5] is the 0.5 quantile, i.e. the median.
    ax.scatter(edges[5], 0, color='white', zorder=7)
def point_multi_sigmas(ax, data):
    """Draw mean +/- 1, 2 and 3 standard deviations of *data* on *ax* at y=0.

    Five segments cover [-3s, -2s], [-2s, -s], [-s, +s], [+s, +2s] and
    [+2s, +3s] around the mean. Inner bands are drawn thicker (width 15 in
    the middle, 5 at the ends). The mean is marked by a white dot. The spread
    is ``numpy.std`` with its default ``ddof=0``.
    """
    center, spread = mean(data), std(data)
    offsets = spread * array([-3, -2, -1, 1, 2, 3])
    segments = sliding_window_view(offsets + center, 2)   # rows of (left, right)
    ax.hlines(
        [0] * len(segments),
        segments[:, 0],
        segments[:, 1],
        linewidths=array([1, 2, 3, 2, 1]) * 5,
    )
    ax.scatter(center, 0, color='white', zorder=7)
# (row label, plotting callable) pairs; each entry becomes one row of the
# figure. seaborn callables are invoked with keyword data (x=..., ax=...);
# the others are invoked positionally as func(ax, data).
univariate_funcs = [
    # --- individual observations ---
    ('strip', partial(sns.stripplot, jitter=.3, ec='white', size=3)),
    ('swarm', partial(sns.swarmplot, size=3)),
    ('rug', partial(Axes.eventplot, alpha=.4)),
    # --- density / distribution estimates ---
    ('kernel density (area)', partial(sns.kdeplot, fill=True)),
    ('kernel density (color)', color_density),
    ('cumulative KDE', partial(sns.kdeplot, cumulative=True)),
    ('empirical CDF', sns.ecdfplot),
    ('histogram', partial(sns.histplot, bins='auto')),
    # --- summary statistics ---
    ('Box', sns.boxplot),
    ('Boxen', sns.boxenplot),
    ('Tufte Quartile', tufte_quartiles),
    (r'Point $\bar{x}\pm\sigma$', partial(sns.pointplot, orient='h', errorbar='sd')),
    ('Point Deciles', point_decile),
    (r'Point $\bar{x}\pm$ 3$\sigma$,2$\sigma$,1$\sigma$', point_multi_sigmas),
]
# Grid layout: a header row for the true pdfs plus one row per plotting
# method, and one column per distribution. Columns share x; rows share y.
gridspec_kw = {'hspace': .1, 'wspace': .02, 'left': .15, 'right': .9, 'bottom': .05}
fig, axes = subplots(
    nrows=len(univariate_funcs) + 1,
    ncols=len(dists),
    sharex='col',
    sharey='row',
    figsize=(16, 12),
    gridspec_kw=gridspec_kw,
    dpi=106,
)
# Header row: the exact pdf of each distribution between its 0.1% and 99.9%
# quantiles. Each panel is titled with the distribution name and parameters.
for header_ax, dist in zip(axes[0], dists):
    lo, hi = dist.ppf([.001, .999])
    xs = linspace(lo, hi, 400)
    pdf = dist.pdf(xs)
    header_ax.plot(xs, pdf)
    header_ax.fill_between(xs, pdf, alpha=.4)
    # e.g. "loc=10, scale=2"
    params = ', '.join('='.join(map(str, item)) for item in dist.kwds.items())
    header_ax.set_title(f"{dist.dist.name.title()}\n{params}")
# Fill rows 1..N of the grid: row `row` applies univariate_funcs[row - 1] to
# every sample, one column per distribution (row 0 holds the true pdfs).
for row, (name, func) in enumerate(univariate_funcs, start=1):
    # Unwrap functools.partial so the underlying callable's module can be
    # inspected; its frozen positional/keyword arguments are forwarded below.
    if isinstance(func, partial):
        func, args, kwargs = func.func, func.args, func.keywords
    else:
        args, kwargs = (), {}

    # seaborn functions take the data and target axes as keywords; the local
    # helpers and Axes methods take (ax, data) positionally. This does not
    # depend on the column, so decide it once per row.
    is_seaborn = func.__module__.partition('.')[0] == 'seaborn'

    # Row label: capitalize each word unless it is already all caps
    # (e.g. "KDE"), then wrap at 20 characters without breaking long words.
    label = ' '.join(
        word if word.isupper() else word.capitalize() for word in name.split()
    )
    label = fill(label, width=20, break_long_words=False)

    for col, sample in enumerate(samples):
        ax = axes[row, col]
        if is_seaborn:
            # Frozen positional args go first, as the partial itself would
            # apply them, instead of being silently dropped.
            func(*args, x=sample, ax=ax, **kwargs)
        else:
            func(ax, sample, *args, **kwargs)
        if col == 0:  # only the first column carries the row label
            ax.set_ylabel(label, rotation=0, ha='right', va='center')
# Hide y ticks and tick labels on every panel. Hide x ticks and tick labels on
# every row except the last, so only the bottom row shows the x axes.
for panel in axes.ravel():
    panel.yaxis.set_tick_params(length=0, width=0, labelleft=False)
for panel in axes[:-1].ravel():
    panel.xaxis.set_tick_params(length=0, width=0, labelbottom=False)
# Draw a horizontal black rule across the figure, halfway between the bottom
# of the header row (true pdfs) and the top of the first plotting row.
# Line2D is imported at the top of the file.
header_bbox = axes[0, 0].get_position()
row_bbox = axes[1, 0].get_position()
sep_y = (header_bbox.y0 - row_bbox.y1) / 2 + row_bbox.y1
sepline = Line2D(
    [.1, .9], [sep_y, sep_y],
    color='k',
    # get_position() returns figure-fraction boxes, so place the line in
    # figure coordinates explicitly rather than relying on add_artist's default.
    transform=fig.transFigure,
)
fig.add_artist(sepline)
# Place the title horizontally centered over the grid of axes (between the
# gridspec's left and right margins), near the top edge of the figure.
gs = fig.axes[0].get_gridspec()
centered = gs.left + (gs.right - gs.left) / 2
fig.text(
    x=centered,
    y=.98,
    s='A Collection of Univariate Plots',
    fontsize='xx-large',
    ha='center',
    va='top',
)
# Write the figure to a PNG in the current working directory
# (call show() instead to view it interactively).
output_path = 'univariateplots.png'
fig.savefig(output_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment