JAX
A Python library from Google for high-performance numerical computing with autodiff, JIT compilation, and vectorization via a NumPy-like API.
In one line
NumPy plus autodiff plus XLA JIT plus vmap/pmap — the array library of choice for researchers who want control and speed.
What it actually means
JAX gives you a NumPy-compatible API on top of XLA, Google’s machine-learning compiler. You write pure functions; jit compiles them to fused GPU/TPU kernels, grad gives you gradients, vmap auto-batches over a new axis, and pmap / shard_map parallelize across devices. It’s the backend for Flax, Equinox, and most of DeepMind’s research code. The mental model is functional: no in-place mutation, explicit PRNG keys, and everything traces through pure functions. That constraint is what makes the compiler so effective.
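A minimal sketch of those functional constraints in practice — explicit PRNG keys, immutable updates via .at[].set(), and vmap turning a single-example function into a batched one (the array shapes here are arbitrary illustrative choices):

```python
import jax
import jax.numpy as jnp

# Explicit PRNG: a key is a value you split and pass around,
# not hidden global state.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
w = jax.random.normal(subkey, (3,))

# No in-place mutation: .at[...].set(...) returns a new array,
# leaving the original untouched.
w2 = w.at[0].set(0.0)

# vmap: write the per-example function, auto-batch over axis 0.
def dot(x):
    return x @ w2

batched = jax.vmap(dot)(jnp.ones((5, 3)))  # shape (5,): one dot product per row
```

The same code works unchanged under jax.jit, because nothing in it mutates state or relies on Python side effects.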
Why it matters
If you’re doing research on new architectures, training dynamics, or anything that needs fine-grained control over parallelism, JAX gets you closer to the metal than PyTorch without giving up ergonomics. Production ML in industry is still mostly PyTorch, but a lot of frontier-lab work — Gemini, parts of the Claude training stack, AlphaFold — lives in JAX.
Example
import jax, jax.numpy as jnp
@jax.jit
def loss_fn(params, x, y):
    pred = x @ params
    return jnp.mean((pred - y) ** 2)
grad_fn = jax.grad(loss_fn)
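To see the pieces working together, here is a hedged sketch of a full gradient-descent loop around that loss — the synthetic data, learning rate, and step count are arbitrary illustrative choices, not prescriptions:

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss_fn(params, x, y):
    pred = x @ params
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.grad(loss_fn)

# Synthetic linear-regression data (illustrative values only).
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (32, 4))
true_params = jnp.array([1.0, -2.0, 0.5, 3.0])
y = x @ true_params

# Plain gradient descent: each step is a pure function of its inputs,
# so the whole update could itself be wrapped in jax.jit.
params = jnp.zeros(4)
for _ in range(200):
    params = params - 0.1 * grad_fn(params, x, y)
```

Because loss_fn is pure, jax.grad composes freely with jax.jit and jax.vmap — you can differentiate the jitted function or batch the gradient without rewriting anything.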
You’ll hear it when
- Reading DeepMind, Google Research, or Anthropic research code.
- Discussing TPU training and sharding strategies.
- Comparing PyTorch and JAX for a new research project.
- Reviewing AlphaFold or similar scientific ML codebases.