No more shape errors! Type annotations for the shape+dtype of tensors/arrays.

Posted on December 18, 2023

TL;DR: you can explicitly use type annotations of the form

def f(x: Float[Tensor, "channels"],
      y: Float[Tensor, "channels"]):
    ...

to

  1. specify the shape+dtype of tensors/arrays;
  2. declare that these shapes are consistent across multiple arguments;
  3. use runtime type-checking to enforce that these are correct.

See the (now quite popular!) jaxtyping library on GitHub. And note that the name is now historical – it also supports PyTorch/TensorFlow/NumPy, and has no JAX dependency.

For those who’ve seen jaxtyping before, I’m writing this blog post now as we just had a new release upgrading the runtime type-checking substantially (adding some carefully-constructed error messages – see below), so this felt like a great time to say a few words about it.

The story so far

Right now, a lot of folks write code that looks like this:

from torch import Tensor

def f(x: Tensor, y: Tensor):
    # x and y must be one dimensional
    # x and y must have the same size
    ...

with extra information about the arguments encoded in a comment.

If you’ve ever written such a program in PyTorch or JAX or TensorFlow or NumPy, then at some point you’ve probably made a “shape error”. A tensor or array had the wrong shape, but the code silently ran anyway. Probably due to broadcasting or something. And because this happens silently, this sometimes leads to frustrating downstream results like “my model is failing to train and I don’t know why”.

At the very least, if you come back to the code after 6 months… you’ll have found it much harder to reason about what it’s doing. (Why is that .transpose(1, 2) in there again?)

Type annotations and runtime type checking

jaxtyping is designed to fix this. By instead writing the following:

from jaxtyping import Float
from torch import Tensor

def f(x: Float[Tensor, "channels"],
      y: Float[Tensor, "channels"]):
    ...

we can explicitly encode that

  1. the tensors are one dimensional, specifically that they have shape (channels,). (And for example, a two-dimensional shape of (size1, size2) would be annotated as Float[Tensor, "size1 size2"].)
  2. the tensors have the same size as each other – the "channels" annotation is used across both arguments.
  3. along the way, we also specify the dtype – in this case that it’s floating-point. Int, Bool, etc. are available as well.
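
To illustrate (this is a made-up signature just for this post, not something from the jaxtyping docs), the Int and Bool dtypes slot in alongside Float in exactly the same way:

from jaxtyping import Bool, Float, Int
from torch import Tensor

def masked_mean(values: Float[Tensor, "batch channels"],
                mask: Bool[Tensor, "batch channels"],
                counts: Int[Tensor, "batch"]
                ) -> Float[Tensor, "batch"]:
    # values and mask must share the same two-dimensional (batch, channels) shape;
    # counts must be one-dimensional, with the same batch size as the others.
    ...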

Now, even better, we can add some runtime type-checking:

# https://github.com/beartype/beartype
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

@jaxtyped(typechecker=beartype)
def f(x: Float[Tensor, "channels"],
      y: Float[Tensor, "channels"]):
    ...

and in doing so can explicitly validate at runtime that we’ve done the correct thing:

from torch import zeros

f(zeros(3), zeros(4))

# jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of f.
# The problem arose whilst typechecking parameter 'y'.
# Actual value: f32[4](torch)
# Expected type: <class 'Float[Tensor, 'channels']'>.
# ----------------------
# Called with parameters: {'x': f32[3](torch), 'y': f32[4](torch)}
# Parameter annotations: (x: Float[Tensor, 'channels'], y: Float[Tensor, 'channels']).
# The current values for each jaxtyping axis annotation are as follows.
# channels=3

in which we get a helpful error message pointing out that our second argument y seems to have the wrong shape – because we’ve already got channels=3 from our first argument x.
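
For comparison, a call with consistent shapes passes the check silently:

f(zeros(3), zeros(3))  # fine: channels=3 for both arguments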

Pretty useful!

The above is actually a fairly simple example of jaxtyping. But for the power user, quite a few different shapes can be specified. For example,

  1. variadic numbers of dimensions can be allowed with *, e.g. Float[Tensor, "*batch channels"] allows arbitrarily many batch dimensions before the channel dimension.
  2. fixed-size dimensions can just be written as a number, e.g. 3, used directly in the annotation.
  3. math is allowed:
    def remove_last(x: Float[Tensor, "dim"]
                    ) -> Float[Tensor, "dim-1"]:
        ...

    (The math is just evaluated as a mini Python program, using the values of the dimensions found so far.)

  4. broadcasting can be allowed with #, e.g. a "#channels" dimension will also accept size 1.
  5. documentation-only names can be added to the specifier by using an =, e.g. Float[Array, "rows=4 cols=3"].

This produces a flexible DSL for specifying basically every way the shapes of your tensors/arrays might change through your program.
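
To make that concrete, here’s a toy example of my own (rgb_to_gray isn’t from the jaxtyping docs) combining a variadic *batch with a fixed-size channel dimension:

from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

@jaxtyped(typechecker=beartype)
def rgb_to_gray(image: Float[Tensor, "*batch 3 height width"]
                ) -> Float[Tensor, "*batch height width"]:
    # Any number of leading batch dimensions is allowed; the channel
    # dimension must be exactly 3.
    weights = image.new_tensor([0.299, 0.587, 0.114])
    return (image * weights[:, None, None]).sum(dim=-3)

Calling this with a tensor of shape (8, 3, 32, 32) type-checks; anything whose third-from-last dimension isn’t 3 is rejected at call time.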

JAX vs PyTorch

jaxtyping works with both. Use whichever tech stack you prefer!

For those using JAX, note that the runtime type-checking happens during JIT tracing, and so it adds no overhead at runtime. You might also like to use JAX’s environment variable JAX_NUMPY_RANK_PROMOTION=raise, which turns implicit rank promotion (broadcasting between arrays with different numbers of dimensions) into an error.
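
As a quick sketch of how this composes with jax.jit (again a toy function of my own, not from the docs):

import jax
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped

@jax.jit
@jaxtyped(typechecker=beartype)
def add(x: Float[Array, "channels"],
        y: Float[Array, "channels"]) -> Float[Array, "channels"]:
    return x + y

add(jnp.zeros(3), jnp.ones(3))  # shapes checked whilst tracing; no cost on later calls with the same shapes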

For those using PyTorch, you may have seen my previous library “TorchTyping”. If you have, I strongly advise switching to jaxtyping instead – it’s much (much) easier to use, and doesn’t do any awful monkey-patching of typeguard under-the-hood. (Oops.) TorchTyping was the prototype version; jaxtyping is the polished version.

jaxtyping is available on GitHub. I’ve heard on the grapevine that it’s now pretty widely used across quite a few companies. :)

Appendix

1. Beyond jaxtyping

I think jaxtyping is probably close-to-optimal for solving this problem in Python and with existing tech stacks (JAX, PyTorch, …). But FWIW, this kind of shape-checking could theoretically be done much more carefully if the language or framework was built from the ground up to support this.

For example, named tensors move away from having tuple-shapes like (64, 3, 32, 32) in favour of dictionary-like shapes like {"batch": 64, "channels": 3, "height": 32, "width": 32}. As another example, Dex encodes the index set (the allowable values for i when writing array[i]) directly into the type of the array, which allows for distinguishing dimensions that are similarly-sized but different-in-meaning, or for indexing with things other than integers (e.g. named-tensor-style dictionaries).
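
For a concrete taste of the named-tensor idea, PyTorch ships a (prototype) named tensor API:

import torch

# Each dimension carries a name as well as a size.
x = torch.zeros(64, 3, 32, 32, names=("batch", "channels", "height", "width"))
print(x.names)  # ('batch', 'channels', 'height', 'width')
print(x.shape)  # torch.Size([64, 3, 32, 32])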

2. Internals

To give a peek under the hood, how does jaxtyping work?

It’s honestly pretty simple. Every time you do an isinstance check, e.g.

isinstance(some_tensor, Float[Tensor, "batch channels height width"])

then an internal dictionary of sizes, tracking things like batch=64, channels=3, is checked and updated. If a size is inconsistent with one already stored, then the isinstance check returns False. A new dictionary is used for every jaxtyped decorator.

Then a runtime type-checker is used to call isinstance for every argument.
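
You can see this machinery directly by calling isinstance yourself:

import torch
from jaxtyping import Float
from torch import Tensor

x = torch.zeros(64, 3, 32, 32)
print(isinstance(x, Float[Tensor, "batch channels height width"]))  # True
print(isinstance(x, Float[Tensor, "batch channels"]))               # False: wrong number of dimensions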

3. JAX ecosystem

Whilst you’re here, some of you might be interested to know about the rest of the JAX ecosystem:

  1. Equinox, a PyTorch-like neural network library;
  2. Diffrax, high-performance numerical differential equation solvers;
  3. Levanter, for training foundation models.

Collectively these form the foundation of a state-of-the-art machine learning stack, including both neural networks and classical scientific problems, a kind of “differentiable GPU-capable scipy”.

Broadly speaking these are faster than PyTorch, offer autodiff over scipy, and have fewer correctness issues than Julia. The downside is that JAX is also slightly harder to use.