NaNTracker.jl
Lightweight NaN detection for Flux.jl models.
NaNTracker wraps leaf layers to check forward inputs, forward outputs, and incoming gradients, and throws a DomainError naming the exact layer path at the first NaN.
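The wrapping idea can be illustrated with a minimal sketch in plain Julia. This is not NaNTracker's actual implementation: `CheckedLayer` and its `path` field are hypothetical names invented for this example, and it covers only the forward checks (checking incoming gradients would additionally need a custom differentiation rule, which is omitted here).

```julia
# Hypothetical stand-in for a tracked layer; NOT NaNTracker's real code.
struct CheckedLayer{F}
    layer::F        # any callable, e.g. a Flux layer or a plain function
    path::String    # label reported when a NaN is found
end

function (c::CheckedLayer)(x)
    # Check the input before running the wrapped layer
    any(isnan, x) && throw(DomainError(c.path, "NaN in forward input"))
    y = c.layer(x)
    # Check the output before handing it to the next layer
    any(isnan, y) && throw(DomainError(c.path, "NaN in forward output"))
    return y
end

# x ./ x yields NaN wherever x is zero (0.0 / 0.0)
ratio = CheckedLayer(x -> x ./ x, "layers[1]")

ok = ratio([1.0, 2.0])      # no NaN, returns normally

err = try
    ratio([0.0, 1.0])       # produces a NaN in the output
    nothing
catch e
    e
end
```

Because the check runs at each wrapped layer, the error points at the first layer where a NaN shows up rather than at the final loss.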
Installation
using Pkg
Pkg.add(url="https://github.com/mashu/NaNTracker.jl")

Quick start
using NaNTracker, Flux
model = Chain(Dense(10 => 20, relu), Dense(20 => 5))
# Wrap — every forward and backward pass is checked for NaN
tracked = nantrack(model)
x = randn(Float32, 10, 8)
loss, grads = Flux.withgradient(tracked) do m
    sum(m(x))
end
# Remove tracking when done
clean = nanuntrack(tracked)

If a NaN appears anywhere in the computation, you get:
DomainError with KeyPath(:layers, 2):
NaN in forward output
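
To handle the error in code rather than let it end the run, the DomainError can be caught like any other Julia exception. The sketch below feeds a deliberately NaN-filled input through a tracked model. It assumes the layer path is stored in the exception's standard `val` field and the description in `msg`, which is inferred from the printed message above rather than confirmed by the package.

```julia
using NaNTracker, Flux

model = Chain(Dense(10 => 20, relu), Dense(20 => 5))
tracked = nantrack(model)

# Deliberately poisoned input: every entry is NaN
x_bad = fill(NaN32, 10, 8)

caught = try
    tracked(x_bad)
    nothing
catch err
    # Only handle the NaN report; let anything else propagate
    err isa DomainError || rethrow()
    err
end

if caught !== nothing
    # Assumption: `val` holds the layer path, `msg` the description
    @warn "NaN detected" location = caught.val reason = caught.msg
end
```

In a training loop, the same pattern could be used to log the offending layer and stop early instead of continuing with corrupted parameters.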