Aditya Mehra

I Built a Diagnostic Toolkit for PyTorch Because I Was Tired of Guessing Why Models Fail

Every time a PyTorch model refuses to learn, the debugging process looks the same:

  1. Stare at the loss curve
  2. Wonder if gradients are flowing
  3. Add print statements everywhere
  4. Delete them all when it works
  5. Repeat next week

After 17 years in distributed systems and SRE, I know this pattern — it is monitoring by vibes. In production infrastructure, we would never accept "the service seems slow" as a diagnostic. We measure. We trace. We verify.

So I built torchdiag — five diagnostic commands that answer the actual questions.

Install

pip install torchdiag

AddyM / torchdiag

PyTorch model health diagnostics — gradient checks, dead neuron detection, training verification. Built from an SRE perspective.

torchdiag

CI PyPI PyTorch License: MIT Python 3.8+

PyTorch model health diagnostics — built from an SRE perspective.

Stop guessing why your model isn't learning. torchdiag gives you five diagnostic commands that answer the questions that matter: Are gradients flowing? Are neurons alive? Did the optimizer actually update weights?

Installation

pip install torchdiag

Quick Start

import torch
import torch.nn as nn
import torchdiag
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

# 1. Model overview
torchdiag.summary(model)

# 2. Check for dead neurons
x = torch.randn(100, 784)
torchdiag.check_dead_neurons(model, x)

# 3. Verify a full training step works
torchdiag.verify_step(
    model,
    torch.optim.Adam

1. What does my model actually look like?

import torch
import torch.nn as nn
import torchdiag

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

torchdiag.summary(model)

Prints parameter count per layer, total/trainable/frozen breakdown, memory footprint, device placement, and dtype distribution. Flags frozen parameters, split-device models, and dtype mismatches.
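
For a sense of where numbers like these come from, here is a minimal sketch in plain PyTorch. It is not torchdiag's implementation: the helper name `param_breakdown` and its return keys are hypothetical, and it only covers parameter counts, parameter memory, devices, and dtypes.

```python
import torch.nn as nn


def param_breakdown(model: nn.Module) -> dict:
    """Hypothetical helper: count parameters and estimate their memory."""
    total = trainable = n_bytes = 0
    devices, dtypes = set(), set()
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
        n_bytes += n * p.element_size()
        devices.add(str(p.device))
        dtypes.add(str(p.dtype))
    return {
        "total": total,
        "trainable": trainable,
        "frozen": total - trainable,
        "param_bytes": n_bytes,
        "devices": sorted(devices),  # more than one entry: split-device model
        "dtypes": sorted(dtypes),    # more than one entry: mixed dtypes
    }


model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
# Freeze one weight tensor so the frozen count is non-zero.
model[0].weight.requires_grad_(False)

report = param_breakdown(model)
print(report)  # total 218058, of which 200704 are frozen
```

A non-zero `frozen` count that you did not expect is usually the first thing worth checking, since frozen parameters never receive optimizer updates.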

2. Are gradients flowing?

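# Example batch: 32 flattened 28x28 inputs with integer labels for 10 classes
x = torch.randn(32, 784)
target = torch.randint(0, 10, (32,))

loss = nn.CrossEntropyLoss()(model(x), target)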
loss.backward()

torchdiag.check_gradients(model)

Reports gradient mean, max, and min per layer. Flags vanishing gradients (max below 1e-7), exploding gradients (max above 100), and disconnected parameters (None gradients).
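
The same classification can be sketched by hand in plain PyTorch. This is an illustration under assumptions, not torchdiag's code: `gradient_report` is a hypothetical helper, and it assumes "max" means the largest absolute gradient value per parameter tensor, using the 1e-7 and 100 thresholds quoted above.

```python
import torch
import torch.nn as nn


def gradient_report(model, vanish_below=1e-7, explode_above=100.0):
    """Hypothetical helper: classify each parameter by gradient magnitude."""
    report = {}
    for name, p in model.named_parameters():
        if p.grad is None:
            report[name] = "disconnected"  # backward() never reached it
            continue
        max_abs = p.grad.abs().max().item()
        if max_abs < vanish_below:
            report[name] = "vanishing"
        elif max_abs > explode_above:
            report[name] = "exploding"
        else:
            report[name] = "ok"
    return report


class TwoHeads(nn.Module):
    """Toy model with one layer that forward() never calls."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(8, 3)
        self.unused = nn.Linear(8, 3)

    def forward(self, x):
        return self.used(x)


torch.manual_seed(0)
net = TwoHeads()
loss = nn.CrossEntropyLoss()(net(torch.randn(16, 8)), torch.randint(0, 3, (16,)))
loss.backward()

grad_status = gradient_report(net)
print(grad_status)
```

In this toy model the `unused` layer shows up as disconnected because it never takes part in the forward pass, which is the kind of wiring mistake that otherwise produces no error at all.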

3. Are neurons alive?

x = torch.randn(100, 784)
torchdiag.check_dead_neurons(model, x)

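A dead ReLU neuron outputs zero for every input it sees. Wherever its output is zero, its gradient is zero too, so its incoming weights stop updating and it rarely recovers. This command tells you how many you have and where, measured on the batch you pass in. Flags a layer as critical when more than 50% of its neurons are dead.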
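
One common way to measure this is with forward hooks on each ReLU. The sketch below is a plain-PyTorch illustration, not torchdiag's implementation; `dead_relu_fraction` is a hypothetical helper, and "dead" here means zero for every sample in the given batch.

```python
import torch
import torch.nn as nn


def dead_relu_fraction(model, x):
    """Hypothetical helper: per ReLU, fraction of units that are zero for all of x."""
    fractions, handles = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            # Flatten everything except the batch dimension, then mark a unit
            # as dead if it is exactly zero for every sample in the batch.
            flat = output.detach().reshape(output.shape[0], -1)
            dead = (flat == 0).all(dim=0)
            fractions[name] = dead.float().mean().item()
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()  # always remove hooks so the model is left unchanged
    return fractions


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 6), nn.ReLU(), nn.Linear(6, 2))
with torch.no_grad():
    # Force three of the six hidden units to be dead: with zero weights the
    # pre-activation equals the bias, and a negative bias gives ReLU output 0.
    net[0].weight.zero_()
    net[0].bias.copy_(torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]))

fractions = dead_relu_fraction(net, torch.randn(100, 4))
print(fractions)  # {'1': 0.5}
```

The result depends on the batch: a unit that is silent on 100 random inputs may still fire on real data, so it is worth running the check on representative samples.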

4. Does one training step actually work?

torchdiag.verify_step(
    model,
    torch.optim.Adam(model.parameters()),
    nn.CrossEntropyLoss(),
    torch.randn(32, 784),
    torch.randint(0, 10, (32,)),
)

Runs one complete training step — forward, loss, backward, optimizer step — and verifies each stage works. Confirms output shape is correct, loss is finite, gradients are computed, and parameters actually change.

Run this before your training loop. If something is broken, you will know in 1 step instead of 100 epochs.
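
The stages described above can be sketched in a few lines of plain PyTorch. This is an illustration, not torchdiag's implementation: `check_one_step` and its result keys are hypothetical, and the shape check here only compares the batch dimension.

```python
import torch
import torch.nn as nn


def check_one_step(model, optimizer, loss_fn, x, y):
    """Hypothetical helper: run one training step and report pass/fail checks."""
    # Snapshot the parameters so we can tell whether the optimizer moved them.
    before = [p.detach().clone() for p in model.parameters()]

    optimizer.zero_grad()
    out = model(x)
    loss = loss_fn(out, y)
    loss.backward()
    has_grads = all(
        p.grad is not None for p in model.parameters() if p.requires_grad
    )
    optimizer.step()

    changed = any(
        not torch.equal(b, p.detach())
        for b, p in zip(before, model.parameters())
    )
    return {
        "batch_dim_ok": out.shape[0] == x.shape[0],
        "loss_finite": bool(torch.isfinite(loss).item()),
        "grads_computed": has_grads,
        "params_changed": changed,
    }


torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
checks = check_one_step(
    model,
    torch.optim.Adam(model.parameters()),
    nn.CrossEntropyLoss(),
    torch.randn(32, 784),
    torch.randint(0, 10, (32,)),
)
print(checks)
```

A `params_changed` value of `False` typically points at an optimizer built from the wrong parameter list or a zero learning rate, both of which leave the loss flat without raising anything.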

5. How much memory am I using?

torchdiag.memory_report()

Reports CPU RSS, GPU allocated/cached/peak memory per device, and MPS memory on Apple Silicon. Flags when GPU memory utilization exceeds 90%.
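
Most of these figures are available from the standard library and `torch.cuda`. The sketch below is an assumption-laden illustration rather than torchdiag's code: `memory_snapshot` is a hypothetical helper, it reports peak RSS rather than current RSS, it relies on the Unix-only `resource` module, and it skips MPS.

```python
import resource
import sys

import torch


def memory_snapshot():
    """Hypothetical helper: peak CPU RSS plus per-GPU allocator statistics."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    peak_bytes = peak if sys.platform == "darwin" else peak * 1024
    snapshot = {"cpu_peak_rss_bytes": peak_bytes, "gpus": {}}

    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            total = torch.cuda.get_device_properties(i).total_memory
            allocated = torch.cuda.memory_allocated(i)
            snapshot["gpus"][i] = {
                "allocated": allocated,
                "reserved": torch.cuda.memory_reserved(i),  # held by the caching allocator
                "peak_allocated": torch.cuda.max_memory_allocated(i),
                "over_90_percent": allocated / total > 0.9,
            }
    return snapshot


snap = memory_snapshot()
print(snap)
```

On a machine without CUDA the `gpus` entry is simply empty, so the same call works on a laptop and on a training node.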

Why I Built This

I spent 11 years at VMware working on distributed systems observability. The first thing you learn in SRE: never trust a system you cannot measure.

PyTorch models are systems. They have inputs, internal state, and outputs. When they fail, they fail silently — the loss just stays flat. No error. No exception. Just a number that does not move.

torchdiag makes the internal state visible. Five commands. No configuration. No dependencies beyond PyTorch.

PyPI: pypi.org/project/torchdiag
GitHub: github.com/AddyM/torchdiag
CI: Tests pass across Python 3.9 to 3.12

Contributions welcome. If you have a debugging pattern you use repeatedly, open an issue — it probably belongs in the toolkit.

Top comments (1)

Harjot Singh

"Training fails silently" is the exact phrase, and it's why this kind of toolkit is worth more than it looks. NaN losses, vanishing gradients, and shape mismatches that only surface deep in a run all share one property: they don't throw when they happen. They corrupt quietly, and you find out an hour and a lot of compute later.

Adding assertions and early checks is the same move that makes any complex system debuggable: turn a silent corruption into a loud, located failure at the moment it occurs, so you stop guessing and start seeing. Fail-fast beats fail-silent, and a check that fires at step 3 saves you from forensically reconstructing what went wrong at step 3000. The thing I like is that assertions also double as documentation of your assumptions (this tensor should be this shape, this loss should be finite), which is exactly the stuff that's implicit and therefore silently violated.

It generalizes well beyond training too. The same instinct (validate at the boundary, make the invariant explicit, crash loudly when it breaks) is what separates a debuggable AI system from a black box you pray at, and it's core to how I think about reliability in Moonshift.

Of the failure modes you cover, which one was hardest to catch early: the NaN propagation, or the silent shape broadcasting that technically runs but trains the wrong thing?