
JAXbind: Bind any function to JAX

Subjects: Instrumentation and Methods for Astrophysics; Machine Learning; Computation. Version v2, 2024-06-28.

Abstract

JAX is widely used in machine learning and scientific computing; the latter often relies on existing high-performance code that we would ideally like to incorporate into JAX. Reimplementing the existing code in JAX is often impractical, and JAX's existing interface for binding custom code either limits the user to a single Jacobian product or requires deep knowledge of JAX and its C++ backend for general Jacobian products. With JAXbind, we drastically reduce the effort required to bind custom functions implemented in other programming languages to JAX, with full support for Jacobian-vector products (JVPs) and vector-Jacobian products (VJPs). Specifically, JAXbind provides an easy-to-use Python interface for defining custom, so-called JAX primitives. Via JAXbind, any function callable from Python can be exposed as a JAX primitive. JAXbind allows a user to interface the JAX function transformation engine with custom derivatives and batching rules, enabling all JAX transformations for the custom primitive.
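To make the "single Jacobian product" limitation concrete, here is a minimal sketch of the plain-JAX route, not JAXbind's API. A hypothetical NumPy function `external_sin` stands in for external high-performance code. It is wrapped with `jax.pure_callback` and given a reverse-mode rule via `jax.custom_vjp`. Reverse mode (`jax.grad`) then works, but forward mode (`jax.jvp`) is rejected, because a `custom_vjp` function supplies only the VJP.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-ins for external code that JAX cannot trace.
def external_sin(x):
    return np.sin(np.asarray(x))

def external_cos(x):
    return np.cos(np.asarray(x))

@jax.custom_vjp
def bound_sin(x):
    # pure_callback calls back into Python; JAX only knows the output shape/dtype.
    return jax.pure_callback(external_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

def bound_sin_fwd(x):
    # Forward pass: compute the output and keep x as the residual.
    return bound_sin(x), x

def bound_sin_bwd(x, g):
    # VJP of elementwise sin is g * cos(x), again evaluated externally.
    cos_x = jax.pure_callback(external_cos, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    return (g * cos_x,)

bound_sin.defvjp(bound_sin_fwd, bound_sin_bwd)

x = jnp.linspace(0.0, 1.0, 5)

# Reverse mode (VJP) works through the custom rule.
grad = jax.grad(lambda v: bound_sin(v).sum())(x)

# Forward mode (JVP) is not available for a custom_vjp function.
try:
    jax.jvp(bound_sin, (x,), (jnp.ones_like(x),))
    forward_mode_ok = True
except TypeError:
    forward_mode_ok = False

print(grad, forward_mode_ok)
```

Getting both products from opaque external code otherwise means defining a new JAX primitive by hand. The abstract describes JAXbind as simplifying exactly that step.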

Cite

@article{arxiv.2403.08847,
  title  = {JAXbind: Bind any function to JAX},
  author = {Jakob Roth and Martin Reinecke and Gordian Edenhofer},
  journal= {arXiv preprint arXiv:2403.08847},
  year   = {2024}
}

Comments

4 pages, GitHub: https://github.com/NIFTy-PPL/JAXbind