PyTorch
The dominant deep learning framework. Dynamic graphs, great debugging, and the de facto standard for research and most production ML.
At a glance
| Field | Value |
|---|---|
| Category | Deep learning framework |
| Difficulty | Intermediate |
| When to use | Any deep learning work: training from scratch, fine-tuning, research, or deploying custom models |
| When not to use | Classical tabular ML (use XGBoost or scikit-learn); TPU-centric research (JAX is often a better fit) |
| Alternatives | JAX, TensorFlow, MLX |
What it is
PyTorch is a Python tensor library with GPU acceleration and a tape-based autograd system. You write normal Python, autograd records the ops, and loss.backward() computes gradients. Since torch.compile (PyTorch 2.x), you can also opt into graph-mode optimizations without rewriting your code.
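A minimal sketch of the tape in action: autograd records each operation as it runs, and `backward()` replays the tape in reverse to fill in `.grad`.

```python
import torch

# Autograd records ops on tensors that require gradients,
# then backward() replays the tape in reverse.
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2 + 2 * x   # y = x^2 + 2x
y.backward()         # dy/dx = 2x + 2 = 8 at x = 3
print(x.grad)        # tensor(8.)
```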
When we reach for it at Ephizen
- Training any custom model — classification, embeddings, small language models.
- Fine-tuning HuggingFace models (they ship as PyTorch by default).
- Writing custom loss functions or sampling procedures that would be painful in a static graph.
- Running inference behind FastAPI for models we control.
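On the custom-loss point above: any differentiable Python function over tensors can serve as a loss, with no special API. `weighted_mse` here is a made-up illustration, not a torch builtin.

```python
import torch

def weighted_mse(pred, target, weight):
    # Plain Python math; autograd differentiates it automatically.
    return (weight * (pred - target) ** 2).mean()

pred = torch.randn(4, 1, requires_grad=True)
target = torch.zeros(4, 1)
loss = weighted_mse(pred, target, weight=2.0)
loss.backward()
print(pred.grad.shape)  # torch.Size([4, 1])
```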
Getting started
```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1)).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

model.train()
for x, y in loader:  # loader: any iterable of (input, target) batches
    x, y = x.to(device), y.to(device)
    pred = model(x)
    loss = loss_fn(pred, y)
    opt.zero_grad()
    loss.backward()
    opt.step()
```
Gotchas
- Forgetting `model.eval()` and `torch.inference_mode()` at inference leaves dropout on and wastes memory.
- CUDA OOM messages are often misleading — the real culprit is a tensor you didn't `.detach()`.
- Pin `torch`, `torchvision`, and CUDA versions. Mismatches cause cryptic errors.
- For distributed training, reach for Lightning, Accelerate, or DeepSpeed before writing DDP from scratch.
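A small sketch of the first gotcha: with dropout in the model, `model.eval()` switches stochastic layers to their inference behavior, and `torch.inference_mode()` skips autograd bookkeeping entirely.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 32), nn.Dropout(0.5), nn.Linear(32, 1))

model.eval()                   # dropout becomes a no-op
with torch.inference_mode():   # no autograd graph is recorded
    out = model(torch.randn(2, 10))

print(out.requires_grad)  # False
```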
Related tools
- scikit-learn: The classical ML library for Python. Consistent API over dozens of algorithms for regression, classification, clustering, and preprocessing.
- TensorFlow: Google's deep learning framework. Still widely deployed in production, especially via TF Serving, TFLite, and TF.js.
- XGBoost: High-performance gradient-boosted decision tree library. The default strong baseline for tabular data.