JAX vs Julia (vs PyTorch)
Posted on May 3, 2022

A while ago there was an interesting thread on the Julia Discourse about the “state of machine learning in Julia”. I posted a response discussing the differences between Julia and Python (both JAX and PyTorch), and it seemed to be really well received!
Since then this topic seems to keep coming up, so I thought I’d tidy up that post and put it somewhere I could link to easily. Rather than telling all the people who ask for my opinion to go searching through the Julia Discourse until they find that one post… :D
To my mind, JAX and Julia are unquestionably the current state-of-the-art frameworks for autodifferentiation, scientific computing, and machine learning. So let’s dig into the differences.
What about PyTorch?
I used PyTorch throughout my whole PhD. But about a year ago I switched to JAX and haven’t looked back. At least for my use cases it’s faster and more feature-complete. I’ll discuss PyTorch a little bit in this post, but it won’t be the focus.
TL;DR? Julia is amazing, but I’m using JAX. Read on to find out why.
Similarities
First of all: JAX and Julia are really – I mean really – similar.
Both represent models (such as neural networks) in the same way: as a tree of modules and parameters. What Equinox calls equinox.Module is what Flux calls Flux.@functor. What JAX calls jax.tree_map is what Julia calls Flux.fmap.
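As a rough sketch of that shared view (using a plain nested dict of arrays rather than an actual Equinox or Flux model, purely for illustration):

```python
import jax
import jax.numpy as jnp

# A "model" is just a tree of parameters; jax.tree_map applies a function to
# every array leaf. (Flux.fmap plays the analogous role on the Julia side.)
params = {"linear": {"weight": jnp.ones((3, 3)), "bias": jnp.zeros(3)}}

# e.g. a toy "shrink every parameter" update:
shrunk = jax.tree_map(lambda p: 0.9 * p, params)
```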
Both of them are based around just-in-time (JIT) compilers. That is, your code is compiled and optimised down to efficient machine code the first time it is run.
As such, both JAX and Julia have excellent speed at runtime. This is particularly noticeable when compared to PyTorch for scientific computing. Differential equation solvers are quite complicated, and here the overhead of PyTorch’s use of the Python interpreter starts to bite. If you want to know more, check out Horace He’s excellent post on whether your operation is limited by compute, memory, or overhead: Making Deep Learning Go Brrrr From First Principles.
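For anyone who hasn’t used either framework, the JAX side of this looks something like the following minimal sketch (the function and array are just placeholders):

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

x = jnp.arange(1024.0)
step(x)  # first call: traced, compiled, and optimised down to machine code
step(x)  # subsequent calls: dispatch straight to the compiled executable
```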
Both JAX and Julia ubiquitously perform program transforms via homoiconicity. For example, backpropagation is performed by transforming your code so that it is evaluated in reverse.
Homoiconicity
If you haven’t heard of homoiconicity before: this is the property of being able to represent code as data. For example, you could represent your code as a list of strings; one per line.
There are better ways of representing code than that though, of course. JAX has jaxprs; Julia has macros and metaprogramming.
Either way, this allows you to write code that modifies other code. Once again using backpropagation as an example: rewriting your code to be evaluated in reverse order, so as to compute a gradient.
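In JAX the code-as-data representation is visible directly: you can ask for the jaxpr of a function, and for the jaxpr of its transformed (e.g. differentiated) version. A small sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

# The jaxpr is a little IR describing the computation; transformations like
# jax.grad produce a rewritten program (roughly, evaluated in reverse).
print(jax.make_jaxpr(f)(1.0))
print(jax.make_jaxpr(jax.grad(f))(1.0))
```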
Where does Julia shine / where is JAX lacklustre?
Compilation speed.
Julia is substantially faster than JAX on this front. JAX is a lovely framework, but a substantial part of it – the part that computes its program transformations like jax.vmap, jax.grad, etc. – is written in Python. These tracing and transformation mechanisms are pretty complicated (more precisely, they have very deep call stacks), and these can simply take quite a long time for the Python interpreter to chug through.

Meanwhile, the Julia folks have put in a lot of effort to really speed up compile times, and it shows.
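If you want to see this for yourself, here’s one way to separate trace/compile time from run time in JAX (illustrative only; the exact numbers will vary a lot by machine and program size):

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # A deliberately large trace: plenty of ops for the Python-level
    # tracing machinery to chug through before XLA even sees the program.
    for _ in range(100):
        x = jnp.tanh(x @ x)
    return x

x = jnp.eye(100)
t0 = time.perf_counter()
f(x).block_until_ready()  # traced and compiled on the first call
t1 = time.perf_counter()
f(x).block_until_ready()  # already compiled
t2 = time.perf_counter()
print(f"first call: {t1 - t0:.2f}s, second call: {t2 - t1:.4f}s")
```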
Introspection.
Julia offers tools like @code_warntype and @code_native, which you can use to see how your code is being compiled, what the generated assembly looks like, etc.

The closest you can get in JAX is the rather-wordy
print(
    jax.jit(your_function)
    .lower(*args, **kwargs)
    .compile()
    .runtime_executable()
    .hlo_modules()[0]
    .to_string()
)
…which in any case only tends to print out an inscrutable mess of XLA. Eurgh.
For the uninitiated: XLA is the backend that JAX translates your code into, to efficiently compile it.
In turn this represents a lack of control over the XLA compiler, as it becomes difficult to verify whether any desired optimisations are really being compiled in. For example, it’s not possible to check whether an array was updated in-place or out-of-place, and, very occasionally, odd performance bugs in the XLA compiler mean that adding dead code can actually improve runtime performance. (!)
Julia is a programming language; JAX is a DSL.
JAX basically uses Python as a “metaprogramming language” that specifies how to build an XLA program. This means, for example, that we get things like jax.lax.fori_loop instead of native syntax for for loops. Unless you’re relatively practiced at JAX, it can be a bit tricky to read.
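To show what I mean, here is the same toy reduction written with a native Python loop (which JAX unrolls at trace time) and with the JAX-native loop primitive (a sketch, not a recommendation of either style):

```python
import jax
import jax.numpy as jnp

# Native Python for-loop: unrolled into a long straight-line trace.
def total_unrolled(x):
    total = 0.0
    for i in range(x.shape[0]):
        total = total + x[i]
    return total

# The JAX spelling: a structured loop primitive, with the body as a function.
def total_fori(x):
    return jax.lax.fori_loop(0, x.shape[0], lambda i, acc: acc + x[i], 0.0)

x = jnp.arange(5.0)
print(total_unrolled(x), total_fori(x))  # both 10.0
```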
Where does JAX shine / where is Julia lacklustre?
Documentation.
If I want to do the equivalent of detach in PyTorch or jax.lax.stop_gradient in JAX, how should I do that in Julia/Flux?

First of all, it’s not in the Flux documentation. Instead it’s in the documentation for a separate component library – Zygote. So you have to check both, and you have to know to check both.

Once you’ve determined which set of documentation you need to look in, there are the entirely separate Zygote.dropgrad and Zygote.ignore. What’s the difference? Not specified. Will they sometimes throw mysterious errors? Yes. Do I actually know which to use at this point? Nope.
Meanwhile JAX’s documentation is “fine”. The clear winner here is actually PyTorch, which has much better documentation than either of the others.
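For reference, here’s the JAX version of the operation in question – one documented function, with unambiguous behaviour (a tiny sketch):

```python
import jax

def loss(x):
    # Gradients flow through the first term; the second is treated as a constant.
    return x ** 2 + jax.lax.stop_gradient(x ** 2)

print(jax.grad(loss)(3.0))  # 6.0, not 12.0
```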
Gradient reliability.
I remember all too un-fondly a time in which one of my Julia models was failing to train. I spent multiple months on-and-off trying to get it working, trying every trick I could think of.
Eventually – eventually! – I found the error: Julia/Flux/Zygote was returning incorrect gradients. After having spent so much energy wrestling with the documentation and gradient issues just described, this was the point where I simply gave up. Two hours of development work later, I had the model successfully training… in PyTorch. (And as you can probably guess, these days I’d use JAX.)
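These days, whatever the framework, I sanity-check autodiff against finite differences whenever a model mysteriously refuses to train. A quick hand-rolled version of that check, sketched in JAX:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) * x)

x = jnp.linspace(0.0, 1.0, 5)
v = jnp.ones_like(x)  # direction along which to compare
eps = 1e-4

# Directional derivative via central finite differences vs. via autodiff.
fd = (f(x + eps * v) - f(x - eps * v)) / (2 * eps)
ad = jnp.vdot(jax.grad(f)(x), v)
print(jnp.allclose(fd, ad, rtol=1e-2))  # loose tolerance: float32 + finite differences
```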
Code quality.
Okay – this is the big one. The fundamental problem here is that most Julia packages are written by academics, not professional software developers.
Academic code quality is famously poor, and the Julia ecosystem is no exception. Here’s a small sample of the issues plaguing the ecosystem:
(A) Taking compatibility seriously.
It’s pretty common to see posts on the Julia Discourse saying “XYZ library doesn’t work”, followed by a reply from one of the library maintainers stating something like “This is an upstream bug in the new version a.b.c of the ABC library, which XYZ depends upon. We’ll get a fix pushed ASAP.”
Getting fixes pushed ASAP is great, of course. What’s bad is that the error happened in the first place. In contrast this experience has essentially never cropped up for me as an end user of PyTorch or JAX.
(B) Dead code, unused local variables, …
Even in the major well-known well-respected Julia packages – I’ll avoid naming names – the source code has very obvious cases of unused local variables, dead code branches that can never be reached, etc.
In Python these are things that a linter (or code review!) would catch. And the use of such linters is ubiquitous. (Moreover in something like Rust, the compiler would catch these errors as well.) Meanwhile Julia simply hasn’t reached the same level of professionalism.
(C) Math variable names.
Many Julia APIs look like Optimiser(η=...) rather than Optimiser(learning_rate=...). This is a pretty unreadable convention.

(D) Inscrutable errors.
When misusing a Julia library, the errors tend to be pretty unhelpful. For example it’s pretty common to get errors about missing methods – essentially, that you called a function with inputs of the wrong type – somewhere internal to a library.
This isn’t a problem with Julia so much as a problem with library authors failing to write out readable errors for common misuses. The result is that the end-user really has to understand the internals of the library to understand what went wrong.
Moreover, a few times I’ve had cases where what I did was theoretically correct, and the error was actually reflective of a bug in the library! (Incidentally, Julia provides very few tools to help library authors verify the correctness of their work.)
Put simply, the trial-and-error development process in Julia is slow.
Array syntax.
In Julia, A[1] and A[1, :] do different things: the first one will implicitly flatten the array before performing the indexing. This is not functionality I have needed so frequently that it really needed special syntax.

When slicing, Julia makes copies by default, and you need to use the @view macro to avoid this. This is a silent performance footgun.

Julia doesn’t have an equivalent of Python’s ..., which means something NumPy-like A[..., 1, :] instead becomes selectdim(A, ndims(A) - 1, 1).

Array manipulation is such an important part of ML, but collectively this kind of thing really hinders usability and readability. One gets there in the end, of course, but I find my PyTorch/JAX code to simply be prettier to read and easier to understand.
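For concreteness, here’s the ellipsis-style indexing referred to above, as it looks in JAX/NumPy (a toy array purely for illustration):

```python
import jax.numpy as jnp

A = jnp.arange(24).reshape(2, 3, 4)

# NumPy-style ellipsis: index along the second-to-last axis, however many
# leading axes A happens to have.
print(A[..., 1, :].shape)  # (2, 4)
```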
Conclusion
JAX is great. Whilst I’ve not discussed this in this post, if you ever dig into the internals of JAX then you’ll find that it’s a technical marvel. (See Autodidax for the advanced user.)
The Julia language is great. My criticisms above are primarily of its ML ecosystem (not the language itself), so with some effort I could see them being fixed. In a few years, perhaps I’ll be using Julia instead.
And yes, also: PyTorch is great. It has a good deployment story, and it has a mature ecosystem. Nonetheless I do find it to be noticeably too slow for the kinds of workloads (mostly based around differential equations) that I tend to put on it.
I maintain libraries for all of PyTorch, JAX, and Julia. But of these, I definitely enjoy maintaining and working with my JAX libraries the most!
And so to conclude on that note: if you’re using JAX, or thinking of switching, then do look at Equinox and Diffrax, which are my libraries for neural networks and differential equation solvers respectively. So far they’ve been pretty well received – roughly 1000 GitHub stars between them at the time of writing – so give them a try and let me know what you think.