{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"## JAX vmap\n",
"\n",
"This is the source material for a tweet thread I did recently: https://twitter.com/jakevdp/status/1612544608646606849\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/jakevdp/467da4f567d34c59c1f34559790ef85f)"
],
"metadata": {
"id": "4sEj41C8MWRj"
}
},
{
"cell_type": "markdown",
"source": [
"---\n",
"Let's talk about JAX's vmap! It's a transformation that can automatically create vectorized, batched versions of your functions... but what exactly it does is sometimes misunderstood. So let's dig-in!\n",
"\n",
"<img src=\"https://jax.readthedocs.io/en/latest/_static/jax_logo_250px.png\"/>\n",
"<font size=6>\n",
"\n",
"```python\n",
"from jax import vmap\n",
"```\n",
"\n",
"</font>\n"
],
"metadata": {
"id": "C-rrr7PTnWf3"
}
},
{
"cell_type": "markdown",
"source": [
"---\n",
"Suppose you've implemented a model that maps a vector input to a scalar output. As an example, here's a simple function similar to a single neuron in a neural net:"
],
"metadata": {
"id": "HSUIp2U_c-na"
}
},
{
"cell_type": "code",
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"rng = np.random.RandomState(8675309) # PRNGenny\n",
"W = rng.randn(3, 5) # weights\n",
"b = 1.0 # bias\n",
"\n",
"def model(v, W=W, b=b):\n",
" return jnp.tanh(W @ v + b).sum()"
],
"metadata": {
"id": "v_3lI5DxrCWL"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"---\n",
"This function accepts a single length-5 vector of inputs, and outputs a scalar:"
],
"metadata": {
"id": "-qoIeQhztUVS"
}
},
{
"cell_type": "code",
"source": [
"v = rng.randn(5)\n",
"print(model(v))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ydrO9URuuN3O",
"outputId": "124cef9d-4b3e-4d64-ec14-69bd73f491fd"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2.0699806\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"---\n",
"Now, suppose you want to apply this model across a 2D array, where each row of the array is an input. Passing this batched data directly leads to an error:"
],
"metadata": {
"id": "obtKROhinUnQ"
}
},
{
"cell_type": "code",
"source": [
"# This tells Jupyter to print one-line summaries of exceptions.\n",
"%xmode minimal"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RdETKwxQzmRa",
"outputId": "be02c84c-a159-442e-d1bb-362f51e03825"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Exception reporting mode: Minimal\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"v_batch = rng.randn(4, 5) # 4 batches\n",
"model(v_batch)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 97
},
"id": "6DUvyosYuw9i",
"outputId": "029301bb-da66-4f1e-d839-828a6abea55b"
},
"execution_count": 5,
"outputs": [
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 4 is different from 5)\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"---\n",
"This error arises because our function is not defined in a way that can handle batched input. So what do we do? The easiest approach might be to use a simple Python list comprehension:"
],
"metadata": {
"id": "FDa9OAzMvAH7"
}
},
{
"cell_type": "code",
"source": [
"jnp.array([model(v) for v in v_batch])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NFWmiu3EvOSX",
"outputId": "e55cb1c7-fa44-4042-d2c1-8f7855e403d6"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([-2.263083 , -1.4514356, 0.9401485, 2.9187164], dtype=float32)"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "markdown",
"source": [
"---\n",
"This works, of course, but if you're familiar with NumPy-style computing in Python you'll immediately recognize the problem: loops in Python are typically slow compared to the native vectorized operations offered by NumPy & JAX."
],
"metadata": {
"id": "GDc2O5BOu67p"
}
},
{
"cell_type": "markdown",
"source": [
"---\n",
"In the old days, you'd have to re-write your model to explicitly accept batched data. This sometimes takes some thought, for example here the simple matrix product becomes an Einstein summation:"
],
"metadata": {
"id": "d8fy5VFSv8Vs"
}
},
{
"cell_type": "code",
"source": [
"def batched_model(v_batch, W=W, b=b):\n",
" # Here are the dimensions for the batched matrix product:\n",
" # W: (m, k)\n",
" # v_batch: (n_batches, k)\n",
" # output: (n_batches, m)\n",
" return jnp.tanh(jnp.einsum(\"mk,nk->nm\", W, v_batch) + b).sum(1)\n",
"\n",
"# Results should match!\n",
"print(jnp.array([model(v) for v in v_batch]))\n",
"print(batched_model(v_batch))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GHZxQ4DlwIHE",
"outputId": "e4dd3798-17da-4215-e347-58e938d004cd"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[-2.263083 -1.4514356 0.9401485 2.9187164]\n",
"[-2.263083 -1.4514352 0.9401484 2.9187164]\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"---\n",
"As models get more complex, this sort of manual batchification can be complicated and error-prone. This is where jax.vmap comes in: it can transform your function into an efficient and correct batched version automatically!"
],
"metadata": {
"id": "aepo4NQHwt3H"
}
},
{
"cell_type": "code",
"source": [
"from jax import vmap\n",
"\n",
"print(batched_model(v_batch)) # manual batching\n",
"print(vmap(model)(v_batch)) # automatic batching!"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZrFT2m7DxxEK",
"outputId": "df9a3130-7f35-46ce-be91-524857889481"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[-2.263083 -1.4514352 0.9401484 2.9187164]\n",
"[-2.263083 -1.4514351 0.9401484 2.9187164]\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"---\n",
"You might ask now which approach is more efficient: surely vmap must come at a cost? In most cases, however, vmap will produce virtually identical operations as the manual implementation, which we can see by printing the jaxpr (JAX's internal function representation) for each."
],
"metadata": {
"id": "AzHxQrUkyAFV"
}
},
{
"cell_type": "code",
"source": [
"jax.make_jaxpr(batched_model)(v_batch)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uyhwf0NOzu2O",
"outputId": "a772259f-ddbb-4391-93e8-12b72e27ba9d"
},
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda a:f32[3,5]; b:f32[4,5]. let\n",
" c:f32[4,3] = xla_call[\n",
" call_jaxpr={ lambda ; d:f32[3,5] e:f32[4,5]. let\n",
" f:f32[4,3] = dot_general[\n",
" dimension_numbers=(((1,), (1,)), ((), ()))\n",
" precision=None\n",
" preferred_element_type=None\n",
" ] e d\n",
" in (f,) }\n",
" name=_einsum\n",
" ] a b\n",
" g:f32[4,3] = add c 1.0\n",
" h:f32[4,3] = tanh g\n",
" i:f32[4] = reduce_sum[axes=(1,)] h\n",
" in (i,) }"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"source": [
"jax.make_jaxpr(vmap(model))(v_batch)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QO4yt0ahywiB",
"outputId": "c82a8ea8-e38a-4e43-8ad8-9cc07792e5d7"
},
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda a:f32[3,5]; b:f32[4,5]. let\n",
" c:f32[3,4] = dot_general[\n",
" dimension_numbers=(((1,), (1,)), ((), ()))\n",
" precision=None\n",
" preferred_element_type=None\n",
" ] a b\n",
" d:f32[3,4] = add c 1.0\n",
" e:f32[3,4] = tanh d\n",
" f:f32[4] = reduce_sum[axes=(0,)] e\n",
" in (f,) }"
]
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "markdown",
"source": [
"---\n",
"The details differ slightly — for example, xla_call comes from the fact that einsum is jit compiled — but the essential steps in the computation match more-or-less exactly: dot_general(), then add(), then tanh(), then reduce_sum().\n",
"\n",
"<pre>\n",
"{ lambda a:f32[3,5]; b:f32[4,5]. let { lambda a:f32[3,5]; b:f32[4,5]. let\n",
" c:f32[4,3] = xla_call[ c:f32[3,4] = <mark>dot_general</mark>[\n",
" call_jaxpr={ lambda ; d:f32[3,5] e:f32[4,5]. let dimension_numbers=(((1,), (1,)), ((), ()))\n",
" f:f32[4,3] = <mark>dot_general</mark>[ precision=None\n",
" dimension_numbers=(((1,), (1,)), ((), ())) preferred_element_type=None\n",
" precision=None ] a b\n",
" preferred_element_type=None d:f32[3,4] = <mark>add</mark> c 1.0\n",
" ] e d e:f32[3,4] = <mark>tanh</mark> d\n",
" in (f,) } f:f32[4] = <mark>reduce_sum</mark>[axes=(0,)] e\n",
" name=_einsum in (f,) }\n",
" ] a b\n",
" g:f32[4,3] = <mark>add</mark> c 1.0\n",
" h:f32[4,3] = <mark>tanh</mark> g\n",
" i:f32[4] = <mark>reduce_sum</mark>[axes=(1,)] h\n",
" in (i,) }\n",
"</pre>"
],
"metadata": {
"id": "hupLvslAz8o2"
}
},
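{
"cell_type": "markdown",
"source": [
"---\n",
"As a rough sanity check, here is a minimal timing sketch comparing the three approaches. `big_batch` is just a made-up larger batch; the actual numbers will vary with your hardware, the batch size, and whether you wrap things in `jax.jit`. The `block_until_ready()` calls ensure JAX's asynchronous dispatch doesn't hide the real cost."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# A made-up larger batch so timing differences are visible.\n",
"big_batch = rng.randn(100, 5)\n",
"\n",
"# Python loop over the un-batched model (one call per row).\n",
"%timeit jnp.array([model(v) for v in big_batch]).block_until_ready()\n",
"\n",
"# Hand-batched version.\n",
"%timeit batched_model(big_batch).block_until_ready()\n",
"\n",
"# vmap-batched version.\n",
"%timeit vmap(model)(big_batch).block_until_ready()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},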
{
"cell_type": "markdown",
"source": [
"---\n",
"And this is what jax.vmap gives you: a way to automatically create efficient batched versions of your functions – that will lower to fast vectorized computations – without having to re-write your code by hand.\n",
"\n",
"You can read more about vmap and related transforms in the JAX docs: https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html"
],
"metadata": {
"id": "yVbYunFrddch"
}
}
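,
{
"cell_type": "markdown",
"source": [
"---\n",
"One knob the thread doesn't cover: by default vmap maps over axis 0 of every argument, but its `in_axes` argument lets you pick which arguments (and which axes) are batched. Here is a minimal sketch of that, reusing `model`, `v`, and `v_batch` from above; `W_batch` is just a made-up stack of weight matrices."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# A made-up stack of 4 weight matrices, each the same shape as W.\n",
"W_batch = rng.randn(4, 3, 5)\n",
"\n",
"# Batch over the weights only: leave v un-mapped, map axis 0 of W.\n",
"print(vmap(model, in_axes=(None, 0))(v, W_batch).shape)\n",
"\n",
"# Batch over both the inputs and the weights in lock-step.\n",
"print(vmap(model, in_axes=(0, 0))(v_batch, W_batch).shape)"
],
"metadata": {},
"execution_count": null,
"outputs": []
}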
]
}