Learning JAX as a PyTorch developer

Posted on November 9, 2023

A couple of years ago I made the jump from PyTorch to JAX.

Now, the skill of writing autodifferentiable code turns out to translate pretty smoothly between different frameworks. In this case, PyTorch and JAX really aren’t that different: replace torch.foo(...) with jax.numpy.foo(...) and you’re 95% of the way there! What about the other 5%? That’s the purpose of this article!

Assuming you already know PyTorch, this is what I’ve found that you need to know to get up to speed with JAX.

(This isn’t really an article to try to sell you on JAX. My assumption is that you already know you want to use it for its cool autoparallel / advanced autodiff / ludicrous speed / scientific ecosystem / etcetc. And I completely acknowledge that these awesome features come at the cost of something that has a slightly higher learning curve than PyTorch.)

Let’s get started. We’ll cover 9 bullet points in total.

1. JIT compile everything in one go

Most autodifferentiable programs will have a “numerical bit” (e.g. a training step for your model) and a “normal programming bit” (e.g. saving models to disk).

JAX makes this difference explicit. All the numerical work should go inside a single big JIT region, within which all numerical operations are compiled. A common example is a training loop:

import jax

@jax.jit
def make_step(model, x, y):
    # Inside JIT region
    grads = jax.grad(compute_loss)(model, x, y)
    model = stochastic_gradient_descent(model, grads)
    return model

def compute_loss(model, x, y):
    # Called in `make_step`, so still inside the JIT region.
    ...

def stochastic_gradient_descent(model, grads):
    # Likewise, also inside the JIT region.
    ...

for step, (x, y) in zip(range(number_of_steps), dataloader):
    # Outside JIT region here.
    model = make_step(model, x, y)
    # Outside JIT region again here.

A common mistake would be to put jax.jit on the compute_loss function instead of the overall make_step function. This would mean doing numerical work (the stochastic_gradient_descent) outside of JIT. That would run, but without the JIT it would be unnecessarily slow.

See this RNN example for good practice: the whole make_step function is JIT compiled in one go.

Nested jax.jits do nothing. Only the top-level JIT matters.
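To make the pattern concrete, here is a minimal runnable sketch of the same structure. It assumes a hypothetical model that is just a `(weight, bias)` tuple doing linear regression, and a synthetic dataset generated on the fly in place of `dataloader`; the learning rate and number of steps are arbitrary choices.

```python
import jax
import jax.numpy as jnp

def compute_loss(model, x, y):
    # Hypothetical model: a (weight, bias) tuple doing linear regression.
    weight, bias = model
    pred = x @ weight + bias
    return jnp.mean((pred - y) ** 2)

def stochastic_gradient_descent(model, grads, learning_rate=0.1):
    # tree_map applies the update to every leaf of the model PyTree.
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, model, grads)

@jax.jit  # One JIT region covering the loss, the gradients and the update.
def make_step(model, x, y):
    grads = jax.grad(compute_loss)(model, x, y)
    return stochastic_gradient_descent(model, grads)

data_key, init_key = jax.random.split(jax.random.key(0))
true_weight = jnp.array([2.0, -1.0])
model = (jax.random.normal(init_key, (2,)), jnp.array(0.0))

for step in range(200):
    # Outside the JIT region: ordinary Python, here generating a synthetic batch.
    x = jax.random.normal(jax.random.fold_in(data_key, step), (32, 2))
    y = x @ true_weight + 0.5
    model = make_step(model, x, y)
```

Only `make_step` is wrapped in `jax.jit`; `compute_loss` and `stochastic_gradient_descent` are compiled as part of it because they are called from inside it.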

2. Understand tracing

When you call a JIT’d function, then (a) JAX replaces all your arrays with special “tracer” objects, and (b) runs your Python program, recording all the computations applied to those tracers. The recorded computation graph is then passed to the backend, then compiled, and then run.
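One way to look at the recorded computation is `jax.make_jaxpr`, which traces a function and returns JAX's intermediate representation (a "jaxpr") instead of compiling it. A small sketch; the function `f` is only an illustration:

```python
import jax
import jax.numpy as jnp

def f(x):
    y = jnp.sin(x)
    return y * 2.0

# Trace `f` with an example input: only its shape and dtype matter, not its values.
jaxpr = jax.make_jaxpr(f)(jnp.ones(3))
print(jaxpr)  # a short listing of the recorded `sin` and `mul` operations
```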

This has an important implication. It means that you can’t do something like this:

@jax.jit
def f(x):
    if x > 0:  # Fails: x is a tracer here, so its value isn't known.
        ...

because x is being traced, so its value isn’t known at the time the Python code itself is being run. The Python code just sees a tracer object representing “any array with this shape and this dtype and this parallelism strategy”. (And JAX will recompile if any of these things change, e.g. using a different batch size changes the array shape.)
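A quick way to observe retracing is a Python side effect inside a JIT'd function: it only runs while tracing, so it counts how often JAX traces (and hence recompiles). A sketch using a made-up counter `trace_count`:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # Plain Python: runs only while tracing, not on every call.
    return jnp.sin(x) * 2.0

f(jnp.ones(3))   # First call: traces and compiles.
f(jnp.zeros(3))  # Same shape and dtype: reuses the compiled version.
f(jnp.ones(4))   # New shape: traces and compiles again.
print(trace_count)  # 2
```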

So, use Python-level statements like if when you want to control the computation graph that is being built, and JAX-level operations like jax.numpy.where when you want to branch on values inside the computation graph. The two do different things, by design.
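A small sketch of the two kinds of branching side by side. The function `activation` and its `use_relu` flag are hypothetical; marking the flag as static via `static_argnames` is what lets a plain Python `if` inspect it under `jax.jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="use_relu")
def activation(x, use_relu):
    # `use_relu` is a static Python bool, so this `if` runs at trace time
    # and decides which operations end up in the graph.
    if use_relu:
        # Branching on array *values* has to happen inside the graph.
        return jnp.where(x > 0, x, 0.0)
    return jnp.tanh(x)

x = jnp.array([-1.0, 2.0])
print(activation(x, use_relu=True))  # [0. 2.]
```

Changing `use_relu` triggers a new trace and compile, whereas changing the values in `x` does not.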

As another example: if you have any print statements, then these will print things out whilst building the computation graph (i.e. running your Python code). If you want a print statement to happen inside the computation graph, use jax.debug.print. Once again, these two operations are designed to do different things.
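A minimal sketch of the difference (the function name `double` is made up):

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    print("print sees:", x)  # Trace time only: shows a tracer, not numbers.
    jax.debug.print("jax.debug.print sees: {}", x)  # Every call: shows values.
    return x * 2

double(jnp.arange(3.0))  # First call: both lines print.
double(jnp.arange(3.0))  # Cached: only the jax.debug.print line appears.
```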

Remember, the Python code is just specifying a way to build a computation graph.
(Fun fact: as of PyTorch 2.0, torch.compile does something very similar under the hood.)

3. Remember vmap

Remember that JAX has jax.vmap! This simplifies a lot of things.

For example, instead of writing neural network layers that accept an array of shape (batch, channels), you can write an easier-to-understand layer that accepts an array of shape (channels,) – and then vmap it.
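A minimal sketch: a hypothetical `linear` function written for a single unbatched input, then batched with `jax.vmap`. The `in_axes` argument says which inputs carry the batch dimension (`None` means "shared across the batch").

```python
import jax
import jax.numpy as jnp

def linear(weight, bias, x):
    # Written for a single example: x has shape (in_channels,).
    return weight @ x + bias

# Map over axis 0 of x only; weight and bias are shared across the batch.
batched_linear = jax.vmap(linear, in_axes=(None, None, 0))

weight = jnp.ones((4, 3))
bias = jnp.zeros(4)
xs = jnp.ones((8, 3))  # a batch of 8 examples
out = batched_linear(weight, bias, xs)
print(out.shape)  # (8, 4)
```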

4. PyTrees

When you do something like this:

x = [array1, array2]

@jax.jit
def f(x):
    ...

f(x)

then JAX knows to unpack the x argument (which is a list) in order to find the arrays array1 and array2. This is how it can replace arrays with tracers when JIT compiling – and this unpacking is also how JAX can find the arrays to create gradients for when using jax.grad, or the arrays to vectorise when using jax.vmap, and so on.

The objects that JAX knows how to unpack are lists, tuples, dictionaries, user-registered custom nodes (we’ll come back to these), and any arbitrarily-nested collection of these. Overall, these are called “PyTrees”. It’s typical to represent a model as some PyTree of its parameters.
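A small sketch of a nested PyTree of parameters (the layout of `params` is just an example). `jax.tree_util.tree_leaves` shows which arrays JAX finds, and `jax.grad` returns gradients with the same structure as its input:

```python
import jax
import jax.numpy as jnp

# A nested PyTree: a dict containing a tuple and a list.
params = {"linear": (jnp.ones((2, 3)), jnp.zeros(2)), "scale": [jnp.array(2.0)]}
print(len(jax.tree_util.tree_leaves(params)))  # 3

def loss(params, x):
    weight, bias = params["linear"]
    return params["scale"][0] * jnp.sum(weight @ x + bias)

# The gradient is a PyTree with exactly the same structure as `params`.
grads = jax.grad(loss)(params, jnp.ones(3))
```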

It’s almost always a mistake to use raw Python classes with JAX. These aren’t PyTrees, so JAX can’t know how you want to handle these unless you tell it how to. As such, we generally register classes as custom pytree nodes. (One way to do this is by subclassing equinox.Module, see below.)

This is directly analogous to torch.nn.Module, and how you must use self.foo = torch.nn.ModuleList(...) rather than self.foo = [...]. You can’t use raw Python classes with PyTorch either. The only important difference is that PyTorch Modules are treated as directed acyclic graphs (=the same Module can appear in multiple places), whilst JAX PyTrees are treated as, well, trees (=multiple appearances of the same object are treated as independent copies).
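As a sketch of what registering a custom PyTree node involves without Equinox, here is a hypothetical `Linear` class registered with `jax.tree_util.register_pytree_node_class`. That decorator expects a `tree_flatten` method returning `(children, aux_data)` and a `tree_unflatten` classmethod that rebuilds the object:

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Linear:
    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # Children are what JAX looks inside; aux_data is static metadata (none here).
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

model = Linear(jnp.ones((2, 3)), jnp.zeros(2))
grads = jax.grad(lambda m, x: jnp.sum(m(x)))(model, jnp.ones(3))
print(type(grads).__name__)  # Linear
```

Without the decorator, `jax.grad` would reject `model` because it does not know how to find the arrays inside it.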

To get started with the practicalities of building your models as PyTrees, you might like to check out the Equinox library on GitHub. Note that in general, the leaves of a PyTree can be things other than JAX arrays: you might also have activation functions, boolean flags, etc. Each Equinox model is a single big PyTree containing all of the above.

5. Gotcha: in-place updates

JAX uses a functional style. Instead of performing in-place updates like you might with PyTorch or NumPy:

some_array[2:5] = 0

you would write the following:

some_array = some_array.at[2:5].set(0)
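A quick check of the functional semantics: `.at[...].set(...)` returns a new array and leaves the original untouched, while NumPy-style item assignment raises an error.

```python
import jax.numpy as jnp

some_array = jnp.arange(8.0)
updated = some_array.at[2:5].set(0)

print(some_array)  # [0. 1. 2. 3. 4. 5. 6. 7.]  (unchanged)
print(updated)     # [0. 1. 0. 0. 0. 5. 6. 7.]

try:
    some_array[2:5] = 0  # PyTorch/NumPy style
except TypeError as e:
    print("TypeError:", e)  # JAX arrays are immutable.
```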

Be careful if you’re using this together with autodifferentiation! The above might actually copy all of some_array, then apply the update, and so not make the update in-place. The reason is that sometimes an in-place update would overwrite information needed for the backward pass. To avoid this, JAX makes a copy… but this can be a silent performance footgun. (The compiler does do some analysis to try and avoid this if possible, but it isn’t perfect.)

You may sometimes prefer to write e.g.

some_array = jax.numpy.concatenate([some_array[:2], jax.numpy.zeros(3), some_array[5:]])
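A small sketch confirming that the two spellings compute the same values and the same gradients (the helper names are made up). Whether either one actually avoids a copy is decided by the compiler, so it is worth profiling rather than assuming.

```python
import jax
import jax.numpy as jnp

def zero_middle_at(a):
    return a.at[2:5].set(0)

def zero_middle_concat(a):
    return jnp.concatenate([a[:2], jnp.zeros(3), a[5:]])

x = jnp.arange(8.0)
assert (zero_middle_at(x) == zero_middle_concat(x)).all()

# Gradients agree too: the zeroed entries get zero gradient in both versions.
grad_at = jax.grad(lambda a: jnp.sum(zero_middle_at(a) ** 2))(x)
grad_concat = jax.grad(lambda a: jnp.sum(zero_middle_concat(a) ** 2))(x)
```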


(As a comparison: PyTorch does allow you to overwrite information, but will throw an error on the backward pass if it thinks you’ve overwritten needed information. And this will sometimes throw false positives: it also uses an imperfect analysis, tracked per-array with a tensor._version counter.)

6. Randomness

JAX uses an explicit approach to generating random numbers: you have to provide a PRNG key. So instead of torch.randn(shape), you have jax.random.normal(key, shape). As such, initialising the weights of a linear layer would look something like this:

import jax

def make_linear(in_size, out_size, key):
    wkey, bkey = jax.random.split(key, 2)
    weight = jax.random.normal(wkey, (out_size, in_size))
    bias = jax.random.normal(bkey, (out_size,))
    return (weight, bias)

seed = 1234
key = jax.random.key(seed)
model_key, some_other_key, another_key = jax.random.split(key, 3)
parameters = make_linear(6, 32, model_key)
# use the other two keys for whatever other purposes you have

(Of course in practice you’d use a neural network library, and not write this all out yourself.)

Don’t use the same random key twice unless you explicitly want to generate the same random numbers twice.

This is actually a JAX superpower! Explicitly threading the random state like this means you get trivially reproducible behaviour. (Something you have to work very hard for normally.)
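A minimal sketch of both points: reusing a key reproduces exactly the same numbers, while keys produced by `jax.random.split` give independent streams.

```python
import jax

key = jax.random.key(0)

# Reusing a key reproduces exactly the same numbers...
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# ...whereas keys produced by split give independent streams.
key1, key2 = jax.random.split(key)
c = jax.random.normal(key1, (3,))
d = jax.random.normal(key2, (3,))
```

Rerunning the script with the same seed gives the same `a`, `c` and `d` every time, which is where the reproducibility comes from.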

7. Gotcha: call expensive functions as few times as possible

Something like this:

def f(x):
    for i in range(100):
        x = my_complicated_function(x)
    return x

can take a long time to compile.

The reason is that when JAX traces through this, it can’t see the for loop. (All it sees are the operations applied to the tracers!) As a result you’ll get 100 independent copies of my_complicated_function, which all get compiled separately.

If you ever get long compile times, then 99% of the time you’ve probably made some version of this mistake.

(In this case, a jax.lax.scan is probably what you want. Likewise it’s usually also preferable to rewrite even simple stuff like

x2 = f(x1)
x3 = f(x2)

as a little length-2 scan.)
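A sketch of the scan rewrite, with a hypothetical cheap stand-in for `my_complicated_function`. Counting the top-level operations in each jaxpr shows the difference in what the compiler receives: the Python loop produces a long chain of copies, while `jax.lax.scan` traces the body once.

```python
import jax
import jax.numpy as jnp

def my_complicated_function(x):
    # Hypothetical stand-in for something expensive.
    return jnp.tanh(x) + 0.1 * jnp.cos(x)

def f_unrolled(x):
    for _ in range(100):  # Traced 100 times: 100 copies end up in the graph.
        x = my_complicated_function(x)
    return x

def f_scan(x):
    def body(carry, _):
        return my_complicated_function(carry), None  # Traced only once.
    x, _ = jax.lax.scan(body, x, xs=None, length=100)
    return x

x = jnp.linspace(-2.0, 2.0, 5)
n_unrolled = len(jax.make_jaxpr(f_unrolled)(x).jaxpr.eqns)
n_scan = len(jax.make_jaxpr(f_scan)(x).jaxpr.eqns)
print(n_unrolled, n_scan)  # hundreds of operations vs a single `scan`
```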

8. Gotcha: vmap(jax.lax.cond) evaluates both branches

JAX has jax.numpy.where(pred, a, b) to do an if statement between two arrays a and b. This is like NumPy, and both arrays a and b need to have been evaluated.

What if you don’t want to evaluate both branches? (Maybe they’re expensive to compute.) For this there is jax.lax.cond(pred, if_fn, else_fn), which is the runtime equivalent of a Python if statement.

There is one “gotcha” here. If your computation is batched due to a jax.vmap, and the predicate is one of the batched values, then both if_fn and else_fn will be evaluated. Under the hood it is rewritten into a jax.numpy.where(batch_pred, if_fn(...), else_fn(...)). After all, some batch elements might need one branch, and some batch elements might need the other.

So this makes sense for JAX’s programming model, but it can also be a bit surprising. E.g. if if_fn loops forever on some inputs, and the jax.lax.cond is there to guard against that, then that infinite loop will still be hit when vmap’d! (In this case you could fix this by using another jax.lax.cond to replace the bad input with a dummy-but-safe input before the loop.)
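One way to see the rewrite is to compare jaxprs. In this sketch (a hypothetical `f` choosing between `jnp.sin` and `jnp.cos`), the unbatched trace contains a `cond`, while under `jax.vmap` with a batched predicate the trace computes both branches and combines them with a select.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Unbatched: only the branch selected by `x > 0` runs.
    return jax.lax.cond(x > 0, jnp.sin, jnp.cos, x)

unbatched = jax.make_jaxpr(f)(1.0)
batched = jax.make_jaxpr(jax.vmap(f))(jnp.array([-1.0, 1.0]))
print(unbatched)  # contains a `cond` primitive
print(batched)    # rewritten: both `sin` and `cos` are computed, then `select_n` picks per element
```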

9. JAX libraries

JAX doesn’t try to bundle everything in one place. JAX itself only provides the underlying numerical operations of addition, matrix multiplication etc., and the transforms like autodiff, jit-compilation, etc. This is roughly equivalent to everything in torch.*. Everything else is handled through libraries built on top of JAX.

An equivalent of torch.optim is the Optax library.

An equivalent of torch.nn is the Equinox library. (There are a couple of other neural network libraries as well – namely Flax and Haiku. All three of these libraries are used at Google, and maintained by Googlers! I’m linking to the one I’m most involved with, which is also the one that is most similar to PyTorch.)

Use jaxtyping to annotate the shape and dtype of arrays. The annotations look like this:

from jaxtyping import Array, Float

def some_function(image: Float[Array, "batch channels height width"]):
    ...

Next steps

That’s it! If you know PyTorch and you’ve read this guide then you’re ready to get started using JAX.

You can find the JAX documentation here. Here is a CNN on MNIST walkthrough that introduces both JAX and Equinox, and here’s another introductory RNN example.