DEV Community

Tova A

Turning Any Model into an XAI-Ready Model: Formats and Gradient Flow

This post is based on work done during a joint Applied Materials and Extra-Tech bootcamp, where I built an XAI platform.
I’d like to thank Shmuel Fine (team leader) and Odeliah Movadat (mentor) for their guidance and support throughout the project.


Why Gradient-Based XAI Sometimes “Works” but Tells You Nothing

Gradient-based explainability methods (Grad-CAM, Guided Backprop, Integrated Gradients, etc.) are everywhere.
In tutorials, you call a function, get a pretty heatmap, and move on.
In a real project, it’s different.
I was building an internal XAI platform that needed to work across:

  • Different ML frameworks (PyTorch, TensorFlow)
  • Different vision tasks (classification, segmentation, regression, and custom industrial models)
  • Different stored model formats and exports collected over time

In theory, any gradient-based method should “just work” on top of these models.
In practice, once we started running them, things got messy. Sometimes we got blank or obviously wrong heatmaps with no warning. Other times the explanation call failed loudly with:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

The forward pass was correct and the predictions made sense, but the value we were backpropagating from was no longer connected to the computation graph.
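A minimal reproduction of that loud failure, assuming a toy `nn.Sequential` classifier as a stand-in for a real model (the architecture, input size, and `target_class` are hypothetical). The forward pass runs under `torch.no_grad()`, so the predictions are fine but `backward()` has nothing to follow; the same call without `no_grad` produces an input gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy classifier standing in for a real model.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4)).eval()
x = torch.rand(1, 3, 8, 8, requires_grad=True)
target_class = 2  # hypothetical class to explain

# Broken: no graph is recorded under no_grad, although predictions still work.
with torch.no_grad():
    logits = model(x)
print("prediction:", logits.argmax(dim=1).item())
score = logits[0, target_class]

try:
    score.backward()
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print("backward failed:", error_message)

# Fixed: the same forward pass, recorded by autograd.
logits = model(x)
score = logits[0, target_class]
score.backward()
fixed_grad_norm = x.grad.norm().item()
print("input grad norm:", fixed_grad_norm)
```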

That’s when it became clear: the main problem wasn’t any specific XAI algorithm, but the combination of model formats, conversions, and gradient flow.
In other words, not every model file we could load was actually explainable or measurable.

What an XAI-Ready Model Actually Needs

In short, here’s the minimum a model needs in order to produce meaningful, measurable gradient-based explanations:

  • Gradients that actually exist: when we choose a scalar score (for example, a class logit) and call backward(), the gradient must flow back to what we care about, either the input image or some internal layer. If that path is broken, the explainer can still return a heatmap-shaped tensor, but it’s not telling us anything real.
  • A score that stays inside the graph: the value we backpropagate from has to be a tensor that’s still part of the computation graph. If it was turned into a Python number, passed through NumPy, or detached along the way, we’ve already lost the information XAI needs.
  • Access to internal features for CAM-style methods: for methods like Grad-CAM, we also need a way to read activations and gradients from a chosen internal layer, but that comes after the basic gradient path is in place.
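The first two requirements can be checked before any explainer runs. The guard below is a hypothetical helper (`check_score_in_graph` is not from the platform or any library), sketched under the assumption that a valid XAI score is a single-element PyTorch tensor with a `grad_fn`; the toy `x @ w` product stands in for `model(x)`.

```python
import torch


def check_score_in_graph(score) -> None:
    """Hypothetical guard: fail early if `score` cannot drive a gradient explanation."""
    if not isinstance(score, torch.Tensor):
        raise TypeError(f"XAI score must be a tensor, got {type(score).__name__}")
    if score.numel() != 1:
        raise ValueError(f"XAI score must be a scalar, got shape {tuple(score.shape)}")
    if score.grad_fn is None:
        raise RuntimeError("XAI score has no grad_fn: it is not connected to the autograd graph")


x = torch.rand(1, 4, requires_grad=True)
logits = x @ torch.rand(4, 3)  # toy stand-in for model(x)

check_score_in_graph(logits[0, 1])  # still in the graph: passes silently

rejected = []
for bad in (logits[0, 1].detach(), logits[0, 1].item(), logits[0]):
    try:
        check_score_in_graph(bad)
    except (TypeError, ValueError, RuntimeError) as err:
        rejected.append(type(err).__name__)
print(rejected)  # ['RuntimeError', 'TypeError', 'ValueError']
```

Raising instead of returning an empty heatmap turns the "quietly wrong" case into a loud one.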

This post is about how to work with real-world models without breaking these requirements.

How Common Formats Behaved in Our Platform

Once we knew what an XAI-ready model needs, we looked at what we actually had: ONNX exports, TorchScript files, and some legacy TensorFlow models. All of them were fine for inference. For gradient-based XAI, the picture was very different.

This complexity really shows up when you’re not just loading your own training code, but building a platform that has to accept unknown model architectures from different teams.
If you control the architecture, you can usually rebuild a clean eager model and just load a state_dict. If all you get is a stored artifact (ONNX, TorchScript, legacy TF graph), then the format itself decides how much structure and gradient information you still have. That’s exactly the situation we were in.

ONNX – Great for Inference, Not for Gradients

We had models already deployed as ONNX. It was tempting to reuse them for Grad-CAM and Integrated Gradients. In practice, ONNX runtimes are optimised for forward passes, not for autograd:

  • You get fast, correct predictions.
  • You don’t get a PyTorch/TF-style gradient graph or easy hooks into internal layers.

So our conclusion on ONNX was: perfect for deployment, but not a reliable source for gradient-based explanations or metrics. For XAI, we need the original framework model, not just the ONNX file.

TorchScript – Fine for Simple Gradients, Fragile for CAM-Style Methods

In our experience, TorchScript models do support gradients: if the export wasn’t heavily frozen or over-optimised, we can reliably backpropagate from a scalar score to the input and obtain meaningful gradient-based heatmaps.
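A small sketch of that input-gradient path, assuming a toy convolutional model exported with `torch.jit.trace` and round-tripped through `torch.jit.save` / `torch.jit.load` in memory (the architecture is hypothetical, and heavily frozen or optimised exports may behave differently):

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical eager model that gets exported to TorchScript.
eager = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 5),
).eval()

# Export by tracing, save to an in-memory "file", and load it back.
traced = torch.jit.trace(eager, torch.rand(1, 3, 16, 16))
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

# The gradient of a class score w.r.t. the input still flows through TorchScript.
x = torch.rand(1, 3, 16, 16, requires_grad=True)
score = loaded(x)[0, 1]
score.backward()
grad_norm = x.grad.norm().item()
print("input grad norm through TorchScript:", grad_norm)
```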

The problems appear when CAM-style methods require access to internal convolutional features. Some TorchScript exports fuse layers, inline modules, or alter module boundaries, so the convolutional blocks we want to hook are no longer explicitly exposed. In these cases, forward and backward hooks become fragile, and optimisation steps can effectively make internal activations inaccessible even though gradients still exist.

Because of this, we treat TorchScript as acceptable for methods that only need gradients with respect to the input, but for CAM-style explainers we require the original eager nn.Module, where internal layers remain cleanly and reliably accessible.
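To show why the eager module matters for CAM-style methods, here is a minimal Grad-CAM sketch. `TinyCNN` and `grad_cam` are hypothetical names for illustration, not the platform's implementation. It uses a forward hook on a named submodule to capture activations, a tensor hook to capture their gradients, and the standard Grad-CAM weighting (channel-wise mean of gradients). Upsampling to the input size and normalisation are left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCNN(nn.Module):
    """Hypothetical toy model; `features` is the internal layer we explain."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.head(self.pool(self.features(x)).flatten(1))


def grad_cam(model: nn.Module, layer: nn.Module, x: torch.Tensor, target_class: int) -> torch.Tensor:
    store = {}

    def save_grad(grad):
        store["gradients"] = grad  # d(score)/d(activations)

    def forward_hook(module, inputs, output):
        store["activations"] = output
        output.register_hook(save_grad)  # fires when backward reaches this tensor

    handle = layer.register_forward_hook(forward_hook)
    try:
        model.zero_grad()
        score = model(x)[0, target_class]
        score.backward()
    finally:
        handle.remove()  # never leave hooks attached to a shared model

    acts, grads = store["activations"], store["gradients"]
    weights = grads.mean(dim=(2, 3), keepdim=True)  # per-channel importance
    cam = F.relu((weights * acts).sum(dim=1))        # shape (N, H, W)
    return cam.detach()


torch.manual_seed(0)
model = TinyCNN().eval()
cam = grad_cam(model, model.features, torch.rand(1, 3, 16, 16), target_class=1)
print(cam.shape)  # torch.Size([1, 16, 16])
```

The hook target is just `model.features`; in a fused or inlined TorchScript export, that submodule boundary may not exist anymore.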

Native PyTorch / TF – Our XAI-Ready Baseline

After this, we decided that for explainability the “ground truth” formats are:

  • PyTorch nn.Module in eager mode
  • TF2/Keras models or SavedModels that work cleanly with GradientTape

All other artifacts (ONNX, unknown TorchScript, legacy TF graphs) are welcome for inference, but we don’t assume they are explainable until proven otherwise.

We also learned that simply “converting back” from a non-differentiation-friendly format doesn’t magically fix things.
You can end up with a PyTorch nn.Module or a TF2 SavedModel that looks clean, but was rebuilt from ONNX or an old TF1 graph using a script full of .numpy() calls and manual tensor operations.
On paper the format is now “good”, but the gradient path is still broken.
For a deeper dive into how we converted legacy models without breaking gradients, see the companion post.
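A hypothetical example of that trap: `RebuiltFromLegacy` below is an ordinary `nn.Module`, but its forward pass does one step in NumPy, the way some conversion scripts do. Both classes are illustrative stand-ins, not actual converted models; the second keeps the same computation inside torch.

```python
import torch
import torch.nn as nn


class RebuiltFromLegacy(nn.Module):
    """Hypothetical converted model: a clean nn.Module on the outside, but the
    forward pass round-trips through NumPy."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        h = self.linear(x)
        h_np = h.detach().cpu().numpy() * 2.0  # "manual" op in NumPy: leaves autograd
        return torch.from_numpy(h_np)           # new tensor with no history


class RebuiltInTorch(RebuiltFromLegacy):
    """Same computation, kept inside torch so autograd records it."""

    def forward(self, x):
        return self.linear(x) * 2.0


x = torch.rand(1, 4, requires_grad=True)
broken_out = RebuiltFromLegacy()(x)
fixed_out = RebuiltInTorch()(x)
print(broken_out.requires_grad, broken_out.grad_fn)            # False None
print(fixed_out.requires_grad, fixed_out.grad_fn is not None)  # True True
```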

| Model format | Gradient support | CAM compatibility | Recommended usage |
|---|---|---|---|
| ONNX | ❌ No autograd graph | ❌ Not supported | Inference / deployment only |
| TorchScript | ✅ Full (when loaded correctly) | ⚠️ Fragile (layers may be fused or hidden) | Simple gradient methods when the eager model is unavailable |
| Native PyTorch (eager) | ✅ Full | ✅ Full | Gradient-based XAI and quantitative metrics |
| Native TF2 / Keras | ✅ Full (GradientTape) | ✅ Full | Gradient-based XAI and quantitative metrics |
| Converted models (from ONNX / TF1) | ⚠️ Depends on conversion | ⚠️ Works if gradients are supported | Treat as inference-only unless gradients are explicitly verified |

Don’t Break Gradients in Your Own Code

Even with an XAI-ready format, it’s still easy to kill gradients in the parts we control: preprocessing, adapters, and forward passes. We saw a few recurring “self-inflicted” problems:

  • Wrapping the whole explanation call in torch.no_grad()
  • Calling .detach() on tensors that are still needed for the XAI score
  • Converting tensors to NumPy (.cpu().numpy()) and then using those values to compute the score
  • Using .item() on logits and doing the rest of the logic in pure Python

All of these are harmless for plain inference, but they quietly break the gradient path that explanations rely on.
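A compact way to see the four patterns side by side, assuming a toy `nn.Linear` as the model and a hypothetical `target_class`. Each `score_*` function builds the "same" score, and `backprop_ready` (also hypothetical) reports whether it is still a tensor with a `grad_fn`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 3).eval()          # toy stand-in model
x = torch.rand(1, 4, requires_grad=True)
target_class = 1                         # hypothetical class to explain


def score_ok():
    return model(x)[0, target_class]


def score_no_grad():
    with torch.no_grad():                # pattern 1: no graph recorded
        return model(x)[0, target_class]


def score_detach():
    return model(x).detach()[0, target_class]  # pattern 2: history cut


def score_numpy():
    logits_np = model(x).detach().cpu().numpy()  # pattern 3: NumPy round trip
    return torch.tensor(logits_np[0, target_class])


def score_item():
    return model(x)[0, target_class].item()  # pattern 4: plain Python float


def backprop_ready(score) -> bool:
    return isinstance(score, torch.Tensor) and score.grad_fn is not None


results = {fn.__name__: backprop_ready(fn())
           for fn in (score_ok, score_no_grad, score_detach, score_numpy, score_item)}
print(results)
```

Only `score_ok` survives; every other variant still produces a plausible number, which is exactly why the failure is quiet.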
Our rule of thumb became:

The main path from input → model → XAI score must stay inside the framework’s autograd system.

If we really need to log or serialize something, we do it after we’ve computed the score we’ll backpropagate from.
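A sketch of that ordering, assuming a toy `nn.Linear` model; the `record` fields are hypothetical. The score is computed and backpropagated first, and only afterwards converted with `.item()` / `.tolist()` for logging.

```python
import json

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 3).eval()  # toy stand-in model
x = torch.rand(1, 4, requires_grad=True)
target_class = 0                # hypothetical class to explain

# 1. Stay inside autograd: forward pass and score are tensors with history.
logits = model(x)
score = logits[0, target_class]

# 2. Backpropagate while the graph is intact.
score.backward()

# 3. Only now leave the framework for logging / serialization.
record = {
    "target_class": target_class,
    "score": score.item(),
    "logits": logits.detach().cpu().tolist(),
    "input_grad_norm": x.grad.norm().item(),
}
print(json.dumps(record))
```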

Bonus: A Tiny Sanity Check for New Models/Adapters

Whenever we plug a new model or adapter into the platform, we run a quick check:

  1. Load the model in an XAI-ready format (eager PyTorch / TF2).
  2. Pick a simple score (for example, one class logit).
  3. Call backward() on that score.
  4. Verify that gradients at the input (or at the adapter boundary) are non-zero.

In PyTorch, this is just a few lines of code. If this fails with “does not have a grad_fn” or always gives zero gradients, we usually don’t look at the explainer first; we look at the model format or the forward path we’ve built around it.

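import torch

# Full pickled nn.Module (not a state_dict). weights_only=False lets torch.load
# unpickle arbitrary objects, so only use it on files you trust.
model = torch.load("model.pt", map_location="cpu", weights_only=False)
model.eval()

# Placeholders: replace with a real preprocessed (1, C, H, W) image and the class to explain.
sample_image = torch.rand(1, 3, 224, 224)
target_class = 0

# Fresh leaf tensor, so x.grad is populated after backward().
x = sample_image.detach().clone().to("cpu").requires_grad_(True)

logits = model(x)
score = logits[:, target_class].mean()

model.zero_grad()
score.backward()

if x.grad is None:
    print("No gradient reached the input: the gradient path is broken")
else:
    print("Grad norm on input:", x.grad.norm().item())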

If this prints a reasonable, non-zero gradient norm, the model is at least technically explainable for methods that use gradients with respect to the input.

In practice, this became our first filter. If the model is fully differentiable, we keep going with gradient-based explanations and metrics. If not, we still allow inference, but we deliberately treat that model as inference-only, not explainable.
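The same check can be wrapped as a reusable yes/no filter. `is_gradient_explainable` and the `NumpyHead` adapter below are hypothetical illustrations, not the platform's code. The sketch uses `torch.autograd.grad` instead of `backward()` so it does not accumulate gradients on the model's parameters, and it treats a missing, zero, or error-raising input gradient as "inference-only".

```python
import torch
import torch.nn as nn


def is_gradient_explainable(model: nn.Module, x: torch.Tensor, target_class: int,
                            eps: float = 0.0) -> bool:
    """Hypothetical filter: True if d(score)/d(input) exists and its norm exceeds eps."""
    model.eval()  # note: switches the model to eval mode in place
    x = x.detach().clone().requires_grad_(True)
    try:
        with torch.enable_grad():
            out = model(x)
            if not isinstance(out, torch.Tensor):
                return False
            score = out[:, target_class].sum()
            (grad,) = torch.autograd.grad(score, x, allow_unused=True)
    except RuntimeError:
        return False  # e.g. "does not require grad and does not have a grad_fn"
    return grad is not None and grad.norm().item() > eps


class NumpyHead(nn.Module):
    """Hypothetical adapter that breaks the gradient path via NumPy."""

    def __init__(self):
        super().__init__()
        self.inner = nn.Sequential(nn.Flatten(), nn.Linear(12, 3))

    def forward(self, x):
        return torch.from_numpy(self.inner(x).detach().numpy())


torch.manual_seed(0)
good = nn.Sequential(nn.Flatten(), nn.Linear(12, 3))
x = torch.rand(2, 3, 2, 2)
print(is_gradient_explainable(good, x, target_class=1))         # True
print(is_gradient_explainable(NumpyHead(), x, target_class=1))  # False
```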
