JAX-AMG: A GPU-Accelerated Differentiable Sparse Linear Solver Library for JAX

2026-06-08 · Mathematical Software

AI summary

The authors created JAX-AMG, a library that efficiently solves, on GPUs, the large sparse linear systems that arise in physics simulations. It wraps Nvidia's AmgX solver for use in JAX, giving easy access to advanced solving methods such as algebraic multigrid, along with automatic differentiation and execution on multiple GPUs at once. This makes it useful for tasks that require solving such systems many times, such as optimizing designs or inferring unknown parameters from data (inverse problems). The tool fits well into modern scientific computing workflows that combine machine learning and simulation.

Sparse linear systems; Partial differential equations (PDE); Algebraic multigrid (AMG); JAX; Automatic differentiation (AD); GPU acceleration; Krylov methods; MPI (Message Passing Interface); PDE-constrained optimization; Scientific machine learning
Authors
Yi Liu, Xiantao Fan, Jian-Xun Wang
Abstract
Sparse linear systems from PDE discretizations are central to scientific computing, yet no existing JAX-ecosystem solver simultaneously provides GPU-accelerated algebraic multigrid (AMG), automatic differentiation (AD), and distributed multi-GPU execution. JAX-AMG fills this gap by wrapping the Nvidia AmgX solver suite as a native JAX primitive, exposing AMG and Krylov methods with configurable preconditioners through a unified interface compatible with JIT compilation, reverse-mode AD via adjoint methods, batched solves, and MPI-based distributed execution. Solver caching amortizes setup costs across repeated solves, making JAX-AMG practical for PDE-constrained optimization and inverse problems. The result is a robust, scalable sparse linear algebra layer that integrates seamlessly into differentiable simulation and scientific machine learning pipelines.
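The abstract's "reverse-mode AD via adjoint methods" can be illustrated in plain JAX. For x = A^{-1} b and a scalar loss L(x), the backward pass needs one extra solve with the transpose, A^T lam = dL/dx, after which dL/db = lam and dL/dA = -lam x^T, restricted to the stored nonzeros of A. The sketch below is a minimal illustration under stated assumptions, not the JAX-AMG API: `make_sparse_solve` is a hypothetical helper, the matrix is held in COO form (values, rows, cols), and jax.scipy.sparse.linalg.cg (unpreconditioned conjugate gradients, so A is assumed symmetric positive definite) stands in for the AmgX-backed AMG/Krylov solve.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Double precision keeps the CG error far below typical comparison tolerances.
jax.config.update("jax_enable_x64", True)


def make_sparse_solve(rows, cols, n, tol=1e-10):
    """Hypothetical helper (not part of JAX-AMG): returns solve(values, b) -> x.

    The sparsity pattern (rows, cols) is fixed and closed over; only the
    nonzero values and the right-hand side b are differentiable inputs.
    """

    def matvec(values, x):
        # (A x)_i = sum of values[k] * x[cols[k]] over entries with rows[k] == i
        return jax.ops.segment_sum(values * x[cols], rows, num_segments=n)

    def rmatvec(values, y):
        # (A^T y)_j: same entries with the roles of rows and cols swapped
        return jax.ops.segment_sum(values * y[rows], cols, num_segments=n)

    def _solve(values, b):
        # CG is a stand-in for an AMG-preconditioned Krylov solver (assumes SPD A).
        x, _ = cg(lambda v: matvec(values, v), b, tol=tol)
        return x

    @jax.custom_vjp
    def solve(values, b):
        return _solve(values, b)

    def solve_fwd(values, b):
        x = _solve(values, b)
        return x, (values, x)  # save the solution for the backward pass

    def solve_bwd(residuals, g):
        values, x = residuals
        # Adjoint solve A^T lam = g: one extra linear solve per backward pass.
        lam, _ = cg(lambda v: rmatvec(values, v), g, tol=tol)
        # dL/dA_ij = -lam_i * x_j, evaluated only at the stored (i, j) entries.
        grad_values = -lam[rows] * x[cols]
        return grad_values, lam  # cotangents for (values, b)

    solve.defvjp(solve_fwd, solve_bwd)
    return solve


# Usage: 1D Poisson matrix tridiag(-1, 2, -1), which is symmetric positive definite.
n = 5
idx = jnp.arange(n)
rows = jnp.concatenate([idx, idx[:-1], idx[1:]])
cols = jnp.concatenate([idx, idx[1:], idx[:-1]])
values = jnp.concatenate([2.0 * jnp.ones(n), -jnp.ones(n - 1), -jnp.ones(n - 1)])
b = jnp.ones(n)

solve = make_sparse_solve(rows, cols, n)


def loss(v, rhs):
    return jnp.sum(solve(v, rhs) ** 2)


# Gradients with respect to both the matrix values and the right-hand side.
grad_values, grad_b = jax.jit(jax.grad(loss, argnums=(0, 1)))(values, b)
```

Because the custom rule never differentiates through the CG iterations, the backward pass costs one transposed solve regardless of how many iterations the forward solve took, and no intermediate iterates are stored. Whether JAX-AMG reuses its cached solver setup for the adjoint solve is not stated in the abstract.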