jNO: A JAX Library for Neural Operator and Foundation Model Training

2026-05-11 · Machine Learning
AI summary

The authors present jNO, a JAX-based software library for building and training neural operator models. It expresses different tasks, such as data-driven learning and physics-based training, in a single symbolic language, so users can switch between these methods with minimal code changes. jNO also allows composing multiple models and controlling optimization details (model, optimizer, and learning rate) at the parameter level in one place. The library supports advanced workflows for solving partial differential equations arising in physics and other fields.

JAX, neural operators, physics-informed learning, operator regression, PDE-constrained training, symbolic language, mesh-aware residuals, optimization pipeline, foundation models, hyperparameter tuning
Authors
Leon Armbruster, Rathan Ramesh, Georg Kruse, Christopher Straub
Abstract
jNO (jax Neural Operators) is a JAX-native library for neural operators and foundation models with unified support for both data-driven and physics-informed training. Its core design is a tracing system in which domains, model calls, residuals, supervised losses, and diagnostics are written in one symbolic language and compiled into a single optimization pipeline. This allows users to move between operator regression, mesh-aware residual evaluation, and PDE-constrained training without restructuring the surrounding code. jNO also supports multi-model compositions, fine-grained per-parameter control over model, optimizer, and learning-rate settings, hyperparameter tuning, and JAX-native workflows for PDE foundation-model families translated to JAX. The source repository is available at https://github.com/FhG-IISB/jNO.
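To illustrate the idea the abstract describes, the sketch below shows, in plain JAX (this is not jNO's actual API, and all function names here are hypothetical), how a supervised data loss and a physics residual can be written against the same model and combined into one differentiable objective that a single optimization loop minimizes:

```python
# Minimal plain-JAX sketch (not jNO's API): one model, a data-fit loss and a
# PDE residual loss, combined into a single optimization objective.
import jax
import jax.numpy as jnp

def init_params(key):
    # tiny MLP u_theta: R -> R
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (1, 16)) * 0.1, "b1": jnp.zeros(16),
        "W2": jax.random.normal(k2, (16, 1)) * 0.1, "b2": jnp.zeros(1),
    }

def model(params, x):
    # x has shape (1,); returns a scalar
    h = jnp.tanh(x @ params["W1"] + params["b1"])
    return (h @ params["W2"] + params["b2"])[0]

def data_loss(params, xs, ys):
    # supervised term: mean squared error against observed values
    preds = jax.vmap(lambda x: model(params, x))(xs)
    return jnp.mean((preds - ys) ** 2)

def residual_loss(params, xs):
    # physics term: residual of the toy PDE u''(x) = -sin(x)
    # at scattered collocation points, via nested autodiff
    u = lambda x: model(params, jnp.array([x]))
    u_xx = jax.vmap(jax.grad(jax.grad(u)))(xs)
    return jnp.mean((u_xx + jnp.sin(xs)) ** 2)

def total_loss(params, xs_d, ys_d, xs_c):
    # both terms share one parameter set and one objective
    return data_loss(params, xs_d, ys_d) + residual_loss(params, xs_c)

key = jax.random.PRNGKey(0)
params = init_params(key)
xs_d = jax.random.uniform(key, (8, 1))       # supervised inputs
ys_d = jnp.sin(xs_d[:, 0])                   # supervised targets
xs_c = jnp.linspace(0.0, 1.0, 16)            # collocation points

# one gradient-descent step over the combined objective
loss0, grads = jax.value_and_grad(total_loss)(params, xs_d, ys_d, xs_c)
params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
loss1 = total_loss(params, xs_d, ys_d, xs_c)
```

In jNO, by the abstract's account, these pieces (domains, model calls, residuals, supervised losses) are instead declared in its symbolic language and compiled into the pipeline automatically, so swapping a data-driven loss for a mesh-aware residual does not require rewriting the surrounding training code.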