A couple of years ago I made the jump from PyTorch to JAX.
Now, the skill of writing autodifferentiable code turns out to translate pretty smoothly between different frameworks. In this case, PyTorch and JAX really aren’t that different: replace torch.foo(...) with jax.numpy.foo(...) and you’re 95% of the way there! What about the other 5%? That’s the purpose of this article!
Assuming you already know PyTorch, this is what I’ve found that you need to know to get up to speed with JAX.
(This isn’t really an article to try to sell you on JAX. My assumption is that you already know you want to use it for its cool autoparallel / advanced autodiff / ludicrous speed / scientific ecosystem / etc. And I completely acknowledge that these awesome features come at the cost of a slightly steeper learning curve than PyTorch.)
Let’s get started. We’ll cover 9 bullet points in total.
1. JIT compile everything in one go
Most autodifferentiable programs will have a “numerical bit” (e.g. a training step for your model) and a “normal programming bit” (e.g. saving models to disk).
JAX makes this difference explicit. All the numerical work should go inside a single big JIT region, within which all numerical operations are compiled. A common example is a training loop:
@jax.jit
def make_step(model, x, y):
    # Inside JIT region
    grads = compute_loss(model, x, y)
    model = stochastic_gradient_descent(model, grads)
    return model

@jax.grad
def compute_loss(model, x, y):
    # Called in `make_step`, so still inside the JIT region.
    ...

def stochastic_gradient_descent(model, grads):
    # Likewise, also inside the JIT region.
    ...

for step, (x, y) in zip(range(number_of_steps), dataloader):
    # Outside JIT region here.
    model = make_step(model, x, y)
    # Outside JIT region again here.
A common mistake would be to put jax.jit on the compute_loss function instead of the overall make_step function. This would mean doing numerical work (the stochastic_gradient_descent) outside of JIT. That would run, but without the JIT it would be unnecessarily slow.
See this RNN example as an example of good practice. The whole make_step function is JIT compiled in one go. Nested jax.jits do nothing. Only the top-level JIT matters.
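To make the pattern concrete, here is a minimal runnable sketch of the training loop above. The linear model, the mean-squared-error loss, the learning rate, and the fixed full batch standing in for a dataloader are all assumptions made for illustration; only the placement of jax.jit comes from the text.

```python
import jax
import jax.numpy as jnp

def predict(model, x):
    # Hypothetical model: a single linear layer stored as a (weight, bias) tuple.
    weight, bias = model
    return x @ weight + bias

@jax.grad
def compute_loss(model, x, y):
    # jax.grad differentiates with respect to the first argument (`model`),
    # so calling this returns gradients with the same structure as `model`.
    return jnp.mean((predict(model, x) - y) ** 2)

def stochastic_gradient_descent(model, grads, learning_rate=0.1):
    # Plain gradient step applied to every parameter array.
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, model, grads)

@jax.jit  # the only jit: everything called from here is compiled together
def make_step(model, x, y):
    grads = compute_loss(model, x, y)
    return stochastic_gradient_descent(model, grads)

# Toy data for y = 3x + 1, used as one fixed batch instead of a dataloader.
x = jnp.linspace(-1.0, 1.0, 32).reshape(32, 1)
y = 3.0 * x[:, 0] + 1.0
model = (jnp.zeros((1,)), jnp.zeros(()))

for step in range(200):
    # Outside the JIT region: ordinary Python.
    model = make_step(model, x, y)
```

After the loop, the weight and bias in model should be close to 3 and 1 respectively.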
2. Understand tracing
When you call a JIT’d function, then (a) JAX replaces all your arrays with special “tracer” objects, and (b) runs your Python program, recording all the computations applied to those tracers. The recorded computation graph is then passed to the backend, then compiled, and then run.
This has an important implication. It means that you can’t do something like this:
@jax.jit
def f(x):
    if x > 0:
        ...
x is being traced, so its value isn’t known at the time the Python code itself is being run. The Python code just sees a tracer object representing “any array with this shape and this dtype and this parallelism strategy”. (And it will recompile if any of these things change, e.g. if you use a different batch size then the array shape changes.)
So, use things like Python if statements if you want to control the computation graph that is being built. And use JAX-level statements like jax.numpy.where if you want to branch on values inside the computation graph. These two different operations perform two different things, by design.
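As another example: if you have any Python print statements inside a JIT region, they run only while the function is being traced, and they show tracers rather than values. To print actual values at runtime, use jax.debug.print. Once again, these two operations are designed to do different things.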
Remember, the Python code is just specifying a way to build a computation graph.
(Fun fact: as of PyTorch 2.0, torch.compile does something very similar under the hood.)
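A small sketch of the trace-time versus runtime distinction. The function names, the trace_log list, and the use_abs flag are hypothetical, introduced only for this illustration; marking an argument as static is one way (an assumption here, not something the text prescribes) to let a Python if statement see a real value.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_log = []  # filled by a Python side effect, so only while tracing

@jax.jit
def clip_negative(x):
    trace_log.append(x.shape)                # Python code: runs at trace time only
    jax.debug.print("runtime value: {}", x)  # part of the graph: runs on every call
    return jnp.where(x > 0, x, 0.0)          # value-dependent branch inside the graph

first = clip_negative(jnp.array([-1.0, 2.0]))   # traced and compiled for shape (2,)
second = clip_negative(jnp.array([3.0, -4.0]))  # same shape and dtype: no retrace
third = clip_negative(jnp.arange(3.0))          # new shape: traced and compiled again

@partial(jax.jit, static_argnames=("use_abs",))
def maybe_abs(x, use_abs):
    if use_abs:  # fine: `use_abs` is a static Python bool, not a tracer
        x = jnp.abs(x)
    return x

flipped = maybe_abs(jnp.array([-1.0, 2.0]), use_abs=True)
```

Here trace_log ends up with two entries, one per distinct input shape, even though clip_negative was called three times.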
3. Remember vmap
Remember that JAX has jax.vmap! This simplifies a lot of things.
For example, instead of writing neural network layers that accept an array of shape (batch, channels), you can write an easier-to-understand layer that accepts an array of shape (channels,) – and then vmap it.
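As a sketch of that idea, the layer below is written for a single example and then batched with jax.vmap. The linear function and the array shapes are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def linear(weight, bias, x):
    # Written for one example: x has shape (in_channels,).
    return weight @ x + bias

# Share weight and bias across the batch (None); map over axis 0 of x.
batched_linear = jax.vmap(linear, in_axes=(None, None, 0))

weight = jnp.ones((4, 3))
bias = jnp.zeros(4)
xs = jnp.arange(6.0).reshape(2, 3)   # shape (batch, in_channels)

out = batched_linear(weight, bias, xs)   # shape (batch, out_channels)
```

The single-example function never mentions the batch dimension; in_axes says which arguments carry it.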
4. PyTrees
When you do something like this:
x = [array1, array2]

@jax.jit
def f(x):
    ...

f(x)
then JAX knows to unpack the x argument (which is a list) in order to find the arrays array1 and array2. This is how it can replace arrays with tracers when JIT compiling – and this unpacking is also how JAX can find the arrays to create gradients for when using jax.grad, or the arrays to vectorise when using jax.vmap, and so on.
The objects that JAX knows how to unpack are lists, tuples, dictionaries, user-registered custom nodes (we’ll come back to these), and any arbitrarily-nested collection of these. Overall, these are called “PyTrees”. It’s typical to represent a model as some PyTree of its parameters.
It’s almost always a mistake to use raw Python classes with JAX. These aren’t PyTrees, so JAX can’t know how you want to handle these unless you tell it how to. As such, we generally register classes as custom pytree nodes. (One way to do this is by subclassing equinox.Module, see below.)
This is directly analogous to torch.nn.Module, and how you must use self.foo = torch.nn.ModuleList(...) rather than self.foo = [...]. You can’t use raw Python classes with PyTorch either. The only important difference is that PyTorch Modules are treated as directed acyclic graphs (=the same Module can appear in multiple places), whilst JAX PyTrees are treated as, well, trees (=multiple appearances of the same object are treated as independent copies).
To get started with the practicalities of building your models as PyTrees, you might like to check out the Equinox library on GitHub. Note that in general, the leaves of a PyTree can be things other than JAX arrays: you might also have activation functions, boolean flags, etc. Each Equinox model is a single big PyTree containing all of the above.
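For a feel of what this unpacking looks like without any library, here is a sketch using jax.tree_util directly. The params dictionary and the Linear class are hypothetical; registering the class by hand with jax.tree_util.register_pytree_node_class is shown only to illustrate the mechanism (subclassing equinox.Module saves you from writing this yourself).

```python
import jax
import jax.numpy as jnp

# Nested containers of arrays are already PyTrees.
params = {"layer1": (jnp.ones((3, 2)), jnp.zeros(3)), "layer2": [jnp.ones((1, 3))]}
leaves = jax.tree_util.tree_leaves(params)                   # the three arrays
doubled = jax.tree_util.tree_map(lambda a: 2 * a, params)    # same structure, new leaves

@jax.tree_util.register_pytree_node_class
class Linear:
    # Hypothetical layer, registered by hand as a custom PyTree node.
    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # Children are the leaves JAX should trace; no static aux data here.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@jax.jit
def apply(model, x):
    return model(x)

model = Linear(jnp.ones((3, 2)), jnp.zeros(3))
out = apply(model, jnp.ones(2))

# Gradients come back with the same structure as the model.
grads = jax.grad(lambda m, x: jnp.sum(m(x)))(model, jnp.ones(2))
```

Because Linear is registered, jax.jit and jax.grad can both find its arrays; grads is itself a Linear whose attributes hold the gradients.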
5. Gotcha: in-place updates
JAX uses a functional style. Instead of performing in-place updates like you might with PyTorch or NumPy:
some_array[2:5] = 0
you would write the following:
some_array = some_array.at[2:5].set(0)
Be careful if you’re using this together with autodifferentiation! The above might actually copy all of some_array, then apply the update – and so not make the update in-place. The reason is that sometimes an in-place update would overwrite information needed for the backward pass. To avoid this, JAX makes a copy… but this can be a silent performance footgun. (The compiler does do some analysis to try and avoid this if possible, but it isn’t perfect.)
You may sometimes prefer to write e.g.
some_array = jax.numpy.concatenate([some_array[:2], jax.numpy.zeros(3), some_array[5:]])
(As a comparison: PyTorch does allow you to overwrite information, but will throw an error on the backward pass if it thinks you’ve overwritten needed information. And this will sometimes throw false positives: it also uses an imperfect analysis, tracked per-array.)
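A quick sketch of the functional update style. The array contents are arbitrary, and the try/except is only there to show that NumPy-style item assignment is rejected.

```python
import jax
import jax.numpy as jnp

some_array = jnp.arange(8.0)

try:
    some_array[2:5] = 0          # NumPy/PyTorch-style in-place update
except TypeError:
    in_place_failed = True       # JAX arrays are immutable
else:
    in_place_failed = False

# Functional updates return a new array; `some_array` itself is unchanged.
updated = some_array.at[2:5].set(0)
accumulated = some_array.at[2:5].add(10.0)

# The concatenate formulation produces the same values as `.at[...].set`.
via_concat = jnp.concatenate([some_array[:2], jnp.zeros(3), some_array[5:]])
```

Whether the update happens in place or via a copy is decided by the compiler, so the two formulations can differ in performance even though they give the same result.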
6. Random numbers
JAX uses an explicit approach to generating random numbers: you have to provide a PRNG key. So instead of torch.randn(shape), you have jax.random.normal(key, shape). As such e.g. initialising the weights of a linear layer would look like:
def make_linear(in_size, out_size, key):
    wkey, bkey = jax.random.split(key, 2)
    weight = jax.random.normal(wkey, (out_size, in_size))
    bias = jax.random.normal(bkey, (out_size,))
    return (weight, bias)

seed = 1234
key = jax.random.key(seed)
model_key, some_other_key, another_key = jax.random.split(key, 3)
parameters = make_linear(6, 32, model_key)
# use the other two keys for whatever other purposes you have
(Of course in practice you’d use a neural network library, and not write this all out yourself.)
Don’t use the same random key twice unless you explicitly want to generate the same random numbers twice.
This is actually a JAX superpower! Explicitly threading the random state like this means you get trivially reproducible behaviour. (Something you have to work very hard for normally.)
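The sketch below shows both halves of that advice: reusing a key repeats the numbers, and splitting gives fresh, reproducible streams. The noisy_steps helper is hypothetical and only illustrates threading a key through a loop.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# Same key twice: identical samples.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Split the key to get independent streams.
k1, k2 = jax.random.split(key)
c = jax.random.normal(k1, (3,))
d = jax.random.normal(k2, (3,))

def noisy_steps(key, num_steps):
    # Thread the key through the loop: split off a fresh subkey at each step.
    samples = []
    for _ in range(num_steps):
        key, subkey = jax.random.split(key)
        samples.append(jax.random.normal(subkey, ()))
    return jnp.stack(samples)

run1 = noisy_steps(key, 4)
run2 = noisy_steps(key, 4)   # same starting key: the whole run is reproduced
```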
7. Gotcha: call expensive functions as few times as possible
Something like this:
@jax.jit
def f(x):
    for i in range(100):
        x = my_complicated_function(x)
    return x
can take a long time to compile.
The reason is that when JAX traces through this, it can’t see the for loop. (All it sees are the operations applied to the tracers!) As a result you’ll get 100 independent copies of my_complicated_function, which all get compiled separately.
If you ever get long compile times, then 99% of the time you’ve probably made some version of this mistake.
(In this case, a jax.lax.scan is probably what you want. Likewise it’s usually also preferable to rewrite even simple stuff like
x2 = f(x1)
x3 = f(x2)
as a little length-2 scan.)
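As a sketch of the rewrite, here is the same loop expressed both ways. The body of my_complicated_function is a cheap hypothetical stand-in; the equation counts from jax.make_jaxpr are used here as a rough proxy for how much the compiler has to process.

```python
import jax
import jax.numpy as jnp

def my_complicated_function(x):
    # Hypothetical stand-in for something expensive to compile.
    return jnp.tanh(x) + 0.1 * x

def unrolled(x):
    # The Python loop is invisible to tracing: the body is recorded 100 times.
    for _ in range(100):
        x = my_complicated_function(x)
    return x

def scanned(x):
    def body(carry, _):
        # scan expects (new_carry, per_step_output); there is no output here.
        return my_complicated_function(carry), None

    x, _ = jax.lax.scan(body, x, xs=None, length=100)
    return x

f_unrolled = jax.jit(unrolled)
f_scan = jax.jit(scanned)

x0 = jnp.array([0.5, -1.0, 2.0])

# Number of top-level operations recorded by tracing each version.
num_eqns_unrolled = len(jax.make_jaxpr(unrolled)(x0).jaxpr.eqns)
num_eqns_scan = len(jax.make_jaxpr(scanned)(x0).jaxpr.eqns)
```

Both versions compute the same values, but the scanned one records the loop body once rather than once per iteration.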
8. vmap(jax.lax.cond) evaluates both branches
You can use jax.numpy.where(pred, a, b) to do an if statement between two arrays a and b. This is like NumPy, and both arrays a and b need to have been evaluated.
What if you don’t want to evaluate both branches? (Maybe they’re expensive to compute.) For this there is jax.lax.cond(pred, if_fn, else_fn), which is the runtime equivalent of a Python if statement.
There is one “gotcha” here. If your computation is batched due to a jax.vmap, then both if_fn and else_fn will be evaluated. Under the hood it is rewritten into a jax.numpy.where(batch_pred, if_fn(...), else_fn(...)). After all, some batch elements might need one branch, and some batch elements might need the other.
So this makes sense for JAX’s programming model, but it can also be a bit surprising. E.g. if if_fn produces an infinite loop for some inputs, and the jax.lax.cond is there to guard against that, then that infinite loop will still be hit when vmap’d! (In this case you could fix this by using another jax.lax.cond to replace the bad input with a dummy-but-safe input before the loop.)
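Here is a sketch of that fix. The Collatz-style loop is a hypothetical stand-in for a computation that never terminates on a bad input (here, 0), and returning -1 for bad inputs is an arbitrary choice made for this example.

```python
import jax
import jax.numpy as jnp

def collatz_steps(n):
    # Counts steps until n reaches 1. Loops forever for n == 0,
    # because 0 is even and 0 // 2 == 0.
    def cond_fn(carry):
        n, _ = carry
        return n != 1

    def body_fn(carry):
        n, steps = carry
        n = jnp.where(n % 2 == 0, n // 2, 3 * n + 1)
        return n, steps + 1

    _, steps = jax.lax.while_loop(cond_fn, body_fn, (n, jnp.int32(0)))
    return steps

def guarded_collatz_steps(n):
    ok = n > 0
    # First swap any bad input for a dummy-but-safe one. Under vmap this cond
    # also evaluates both branches, but both branches are cheap and terminate.
    safe_n = jax.lax.cond(ok, lambda v: v, lambda v: jnp.ones_like(v), n)
    steps = collatz_steps(safe_n)   # now safe for every batch element
    return jnp.where(ok, steps, -1)

result = jax.vmap(guarded_collatz_steps)(jnp.array([6, 0, 1, 7]))
```

Wrapping collatz_steps directly in a jax.lax.cond on n > 0 and then vmapping over a batch containing 0 would hang, since the loop would still run on that element.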
9. JAX libraries
JAX doesn’t try to bundle everything in one place. JAX itself only provides the underlying numerical operations of addition, matrix multiplication etc., and the transforms like autodiff, jit-compilation, etc. This is roughly equivalent to everything in torch.*. Everything else is handled through libraries built on top of JAX.
An equivalent of torch.optim is the Optax library.
An equivalent of torch.nn is the Equinox library. (There are a couple of other neural network libraries as well – namely Flax and Haiku. All three of these libraries are used at Google, and maintained by Googlers! I’m linking to the one I’m most involved with, which is also the one that is most similar to PyTorch.)
Use jaxtyping for type annotations for the shape and dtype of arrays. These look like

from jaxtyping import Array, Float

def some_function(image: Float[Array, "batch channels height width"]):
    ...
That’s it! If you know PyTorch and you’ve read this guide then you’re ready to get started using JAX.