JAX

Deep Learning

A Python library from Google for high-performance numerical computing with autodiff, JIT compilation, and vectorization via a NumPy-like API.


In one line

NumPy plus autodiff plus XLA JIT plus vmap/pmap — the array library of choice for researchers who want control and speed.

What it actually means

JAX gives you a NumPy-compatible API on top of XLA, Google’s machine-learning compiler. You write pure functions; jit compiles them to fused GPU/TPU kernels, grad gives you gradients, vmap auto-batches over a new axis, and pmap / shard_map parallelizes across devices. It’s the backend for Flax, Equinox, and most of DeepMind’s research code. The mental model is functional: no in-place mutation, explicit PRNG keys, and everything traces through pure functions. That constraint is what makes the compiler so effective.
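The functional constraints above can be sketched in a few lines. This is an illustrative example (the array shapes are arbitrary), showing vmap auto-batching a pure function and the explicit PRNG-key discipline:

```python
import jax
import jax.numpy as jnp

# A pure function: output depends only on inputs, no mutation.
def predict(w, x):
    return jnp.tanh(x @ w)

# vmap auto-batches over a new leading axis of x, without a Python loop.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

# Explicit PRNG: keys are split, never mutated in place.
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
w = jax.random.normal(k1, (3, 2))    # weights: 3 features -> 2 outputs
xs = jax.random.normal(k2, (5, 3))   # a batch of 5 inputs

out = batched_predict(w, xs)         # shape (5, 2)
```

Because predict is pure, the same function can be wrapped in jax.jit or jax.grad without modification; composing these transformations is the core workflow.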

Why it matters

If you’re doing research on new architectures, training dynamics, or anything that needs fine-grained control over parallelism, JAX gets you closer to the metal than PyTorch without giving up ergonomics. Production ML in industry is still mostly PyTorch, but a lot of frontier-lab work — Gemini, parts of the Claude training stack, AlphaFold — lives in JAX.

Example

import jax
import jax.numpy as jnp

@jax.jit                                # compiled to a fused XLA kernel on first call
def loss_fn(params, x, y):              # pure function: no side effects
    pred = x @ params                   # linear model prediction
    return jnp.mean((pred - y) ** 2)    # mean squared error

grad_fn = jax.grad(loss_fn)             # gradient w.r.t. the first argument (params)
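Calling the resulting functions is ordinary Python. A minimal usage sketch, with toy shapes and data chosen for illustration (not from the source):

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss_fn(params, x, y):
    pred = x @ params
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.grad(loss_fn)   # differentiates w.r.t. the first argument

params = jnp.zeros(3)         # toy linear weights
x = jnp.ones((4, 3))          # 4 examples, 3 features
y = jnp.ones(4)               # targets

g = grad_fn(params, x, y)     # gradient has the same shape as params: (3,)
```

Note that grad_fn is itself a plain function, so it can be jit-compiled or vmapped in turn; transformations compose freely.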

You’ll hear it when

  • Reading DeepMind, Google Research, or Anthropic research code.
  • Discussing TPU training and sharding strategies.
  • Comparing PyTorch and JAX for a new research project.
  • Reviewing AlphaFold or similar scientific ML codebases.
