Guide
How it works
nantrack uses Functors.fmap_with_path to walk the model tree. Each leaf layer that satisfies trackable gets wrapped in a NaNCheck node.
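As a rough illustration of that walk, here is a minimal sketch using only Functors.jl. The names Leaf, Wrapped, sketch_trackable, stop_here, wrap, and sketch_track are hypothetical stand-ins (Leaf for a layer such as Dense, Wrapped for NaNCheck, sketch_trackable for trackable), not the package's actual definitions. It assumes a Functors.jl version that provides fmap_with_path and passes the KeyPath as the first argument to both the mapped function and exclude.

```julia
using Functors

# Hypothetical stand-in for a leaf layer such as Dense.
struct Leaf
    w::Vector{Float64}
end

# Hypothetical stand-in for the NaNCheck wrapper: it remembers where the layer lives.
struct Wrapped{P,L}
    path::P
    layer::L
end

# Dispatch-based predicate: nothing is wrapped unless a method opts in.
sketch_trackable(::KeyPath, ::Any) = false
sketch_trackable(::KeyPath, ::Leaf) = true

# Stop recursing at trackable nodes and at ordinary leaves (numbers, arrays, ...).
stop_here(kp, x) = sketch_trackable(kp, x) || Functors.isleaf(x)

# Wrap trackable nodes; pass every other leaf through unchanged.
wrap(kp, x) = sketch_trackable(kp, x) ? Wrapped(kp, x) : x

sketch_track(model) = fmap_with_path(wrap, model; exclude = stop_here)

# A nested NamedTuple plays the role of a composite model.
model = (encoder = (a = Leaf([1.0]), b = Leaf([2.0])), head = Leaf([3.0]))
tracked = sketch_track(model)
```

In this sketch the containers (the NamedTuples) are rebuilt unchanged, while each Leaf is replaced by a Wrapped that carries its KeyPath, for example KeyPath(:encoder, :a).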
The wrapper does two things:
Forward pass: before calling the wrapped layer, it checks any(hasnan, args). After the call, it checks hasnan(y). If either check fires, a DomainError is thrown with the layer's KeyPath.
Backward pass: a custom ChainRulesCore.rrule intercepts automatic differentiation. It checks the incoming gradient Δ for NaN and throws with the same path information.
Because NaNCheck is parametric (NaNCheck{P,L}), it is fully type-stable and works transparently on CPU and GPU arrays.
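The two checks can be pictured with the sketch below, which depends only on ChainRulesCore. It is not the package's implementation: SketchCheck, sketch_hasnan, and grad_probe are hypothetical names. How the package attaches its rule is not shown in this guide, so the sketch uses a simplified variant in which the backward check lives on an identity function applied to the layer output.

```julia
using ChainRulesCore

# Hypothetical NaN test: only floating-point data can hold a NaN.
sketch_hasnan(x::AbstractArray{<:AbstractFloat}) = any(isnan, x)
sketch_hasnan(x::AbstractFloat) = isnan(x)
sketch_hasnan(x) = false

# Identity in the forward pass; it exists only to carry a backward rule.
grad_probe(path, y) = y

function ChainRulesCore.rrule(::typeof(grad_probe), path, y)
    function grad_probe_pullback(Δ)
        # Inspect the gradient arriving from later layers before passing it on.
        sketch_hasnan(unthunk(Δ)) &&
            throw(DomainError(path, "NaN in incoming gradient"))
        # No tangent for the function or the path; Δ flows through unchanged.
        return NoTangent(), NoTangent(), Δ
    end
    return y, grad_probe_pullback
end

# Hypothetical wrapper, parametric in the path type and the layer type.
struct SketchCheck{P,L}
    path::P
    layer::L
end

function (c::SketchCheck)(args...)
    any(sketch_hasnan, args) &&
        throw(DomainError(c.path, "NaN in forward input"))
    y = grad_probe(c.path, c.layer(args...))
    sketch_hasnan(y) &&
        throw(DomainError(c.path, "NaN in forward output"))
    return y
end

# Usage: wrap any callable and call it as before.
checked = SketchCheck("layers.1", x -> 2 .* x)
out = checked([1.0, 2.0])
```

With a ChainRules-compatible AD backend, grad_probe_pullback would run during the backward pass, so a NaN gradient is reported with the same path as a forward failure.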
Composite models
Composite structs (anything registered with Flux.@layer) are not wrapped themselves — fmap recurses into them so only the leaf computational layers get checked. This means you define your model exactly as before:
struct Encoder
    embedding::Embedding
    mha::MultiHeadAttention
    norm::LayerNorm
end
Flux.@layer Encoder
function (e::Encoder)(x; mask=nothing)
    z = e.embedding(x)
    z = e.norm(first(e.mha(z, mask=mask)) + z)
    return z
end
model = Encoder(Embedding(100 => 64), MultiHeadAttention(64 => 32 => 64, nheads=4), LayerNorm(64))
tracked = nantrack(model) # Dense, Embedding, LayerNorm inside get wrapped
Custom leaf layers
By default the following are tracked: Dense, Embedding, LayerNorm, Scale, and Conv.
To add your own leaf layer, define a single dispatch method:
struct MyAttention
    W::Dense
end
Flux.@layer MyAttention
(m::MyAttention)(x) = softmax(m.W(x))
# Option A: track MyAttention as a whole
NaNTracker.trackable(::KeyPath, ::MyAttention) = true
# Option B: leave it as false (default) so fmap recurses into W,
# and Dense is already tracked.
Unwrapping
Call nanuntrack to strip all NaNCheck nodes and restore the original model. This is useful when you are done debugging and want to deploy without overhead:
clean = nanuntrack(tracked)
GPU support
NaNCheck does not constrain element types or array types, so it works transparently with CUDA arrays:
using CUDA
gpu_model = nantrack(model) |> gpu
Stats tracking
When a NaN is detected, DomainError tells you which layer but not how values grew. Enable stats tracking to record norm and maxabs at every checked layer — both forward and backward:
enable_stats!() # ring buffer of 1000 entries (default)
loss, grads = Flux.withgradient(tracked) do m
    sum(m(x))
end
# Inspect magnitudes (filter to specific layers)
dump_stats(path_contains="attention")
clear_stats!() # reset for next step
disable_stats!() # turn off when done (zero overhead)
When a NaN is detected with stats enabled, the recent trajectory is dumped to stderr before the DomainError is thrown, showing the explosion path leading up to the failure.
Note: On GPU, stats recording triggers scalar transfers (sync points) at every checked layer. Use for debugging only.
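To make the ring-buffer idea concrete, here is a minimal sketch of how such a recorder could be organised. StatEntry, StatRing, record!, and recent are hypothetical names chosen for illustration; the package's actual data layout is not described in this guide.

```julia
using LinearAlgebra: norm

# Hypothetical record: one entry per checked layer and pass.
struct StatEntry
    path::String
    phase::Symbol     # :forward or :backward
    norm::Float64
    maxabs::Float64
end

# Fixed-capacity ring buffer: once full, the oldest entry is overwritten.
mutable struct StatRing
    entries::Vector{StatEntry}
    capacity::Int
    next::Int         # slot the next record will be written to
end
StatRing(capacity::Integer = 1000) = StatRing(StatEntry[], capacity, 1)

function record!(ring::StatRing, path, phase::Symbol, x::AbstractArray)
    # Both reductions collapse the array to a single scalar.
    entry = StatEntry(string(path), phase, norm(x), maximum(abs, x))
    if length(ring.entries) < ring.capacity
        push!(ring.entries, entry)
    else
        ring.entries[ring.next] = entry
    end
    ring.next = mod1(ring.next + 1, ring.capacity)
    return ring
end

# Entries ordered from oldest to newest.
function recent(ring::StatRing)
    length(ring.entries) < ring.capacity && return copy(ring.entries)
    return vcat(ring.entries[ring.next:end], ring.entries[1:ring.next-1])
end

# Usage: with capacity 3, the fourth record evicts the first.
ring = StatRing(3)
for (i, v) in enumerate(([1.0], [3.0, 4.0], [-2.0], [10.0]))
    record!(ring, "layer$i", :forward, v)
end
history = recent(ring)
```

Because norm and maximum each reduce an array to one scalar, a recorder of this shape has to bring a value back to the host for every entry. That is the kind of per-layer sync point the note above warns about for GPU arrays.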