Related papers: JAXbind: Bind any function to JAX
Among numerical libraries capable of computing gradient descent optimization, JAX stands out by offering more features, accelerated by an intermediate representation known as the Jaxpr language. However, editing the Jaxpr code is not directly…
We extend JAX with the capability to automatically differentiate higher-order functions (functionals and operators). By representing functions as a generalization of arrays, we seamlessly use JAX's existing primitive system to implement…
We implement a trust region method on the GPU for nonlinear least squares curve fitting problems using a new deep learning Python library called JAX. Our open source package, JAXFit, works for both unconstrained and constrained curve…
BlackJAX is a library implementing sampling and variational inference algorithms commonly used in Bayesian computation. It is designed for ease of use, speed, and modularity by taking a functional approach to the algorithms' implementation.…
Genetic programming is an optimization algorithm inspired by evolution which automatically evolves the structure of interpretable computer programs. The fitness evaluation in genetic programming suffers from high computational requirements,…
In robot control, planning, and learning, there is a need for rigid-body dynamics libraries that are highly performant, easy to use, and compatible with CPUs and accelerators. While existing libraries often excel at either low-latency CPU…
jinns is an open-source Python library for physics-informed neural networks, built to tackle both forward and inverse problems, as well as meta-model learning. Rooted in the JAX ecosystem, it provides a versatile framework for efficiently…
We introduce JAX MD, a software package for performing differentiable physics simulations with a focus on molecular dynamics. JAX MD includes a number of physics simulation environments, as well as interaction potentials and neural networks…
Context. Inferring spectral parameters from X-ray data is one of the cornerstones of high-energy astrophysics, and is achieved using software stacks that have been developed over the last twenty years and more. However, as models get more…
We introduce JPC, a JAX library for training neural networks with Predictive Coding. JPC provides a simple, fast and flexible interface to train a variety of PC networks (PCNs) including discriminative, generative and hybrid models. Unlike…
Time-series forecasting is central to many scientific and industrial domains, such as energy systems, climate modeling, finance, and retail. While forecasting methods have evolved from classical statistical models to automated, and neural…
We present an implementation of interval analysis and mixed monotone interval reachability analysis as function transforms in Python, fully composable with the computational framework JAX. The resulting toolbox inherits several key features…
This paper introduces JaxPruner, an open-source JAX-based pruning and sparse training library for machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular…
We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and…
The purpose of this paper is to show how existing scientific software can be parallelized using a separate thin layer of Python code where all parallel communication is implemented. We provide specific examples of such layers of code, and…
jq is a widely used tool that provides a programming language to manipulate JSON data. However, the jq language is currently only specified by its implementation, making it difficult to reason about its behaviour. To this end, we provide a…
unxt is a Python package for unit-aware computing with JAX. unxt is built on top of quax, which provides a framework for building array-like objects that can be used with JAX. unxt extends quax to provide support for unit-aware computing…
Context-oriented programming is an emerging paradigm addressing at the language level the issue of dynamic software adaptation and modularization of context-specific concerns. In this paper we propose JavaCtx, a tool which employs coding…
JaxoDraw is a Feynman graph plotting tool written in Java. It has a complete graphical user interface that allows all actions to be carried out via mouse click-and-drag operations in a WYSIWYG fashion. Graphs may be exported to…
Awkward Array is a library for performing NumPy-like computations on nested, variable-sized data, enabling array-oriented programming on arbitrary data structures in Python. However, imperative (procedural) solutions can sometimes be easier…