JAX reference documentation
JAX is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more.
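For orientation, here is a minimal sketch of those composable transformations; the `predict` and `loss` functions below are illustrative stand-ins, not part of JAX itself:

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # Toy linear model; `params` is a (weights, bias) pair.
    w, b = params
    return jnp.dot(x, w) + b

def loss(params, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((predict(params, x) - y) ** 2)

grad_loss = jax.grad(loss)                        # differentiate
batched = jax.vmap(predict, in_axes=(None, 0))    # vectorize over a batch axis
fast_grad = jax.jit(grad_loss)                    # JIT-compile via XLA

# Example call with made-up shapes:
w, b = jnp.ones(3), 0.0
x, y = jnp.ones((8, 3)), jnp.zeros(8)
g = fast_grad((w, b), x, y)  # gradients, same pytree structure as params
```

Because the transformations are composable, `jax.jit(jax.grad(loss))` and similar nestings work out of the box.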
Getting Started
Reference Documentation
Advanced JAX Tutorials
- The Autodiff Cookbook
- Autobatching log-densities example
- Training a Simple Neural Network, with tensorflow/datasets Data Loading
- Custom derivative rules for JAX-transformable Python functions
- How JAX primitives work
- Writing custom Jaxpr interpreters in JAX
- Training a Simple Neural Network, with PyTorch Data Loading
- Named axes and easy-to-revise parallelism
- Using JAX in multi-host and multi-process environments
Notes
Developer documentation
API documentation
- Public API: jax package
- Subpackages
- jax.numpy package
- jax.scipy package
- JAX configuration
- jax.dlpack module
- jax.distributed module
- jax.example_libraries package
- jax.experimental package
- jax.flatten_util package
- jax.image package
- jax.lax package
- jax.nn package
- jax.ops package
- jax.profiler module
- jax.random package
- jax.tree_util package
- Just-in-time compilation (jit)
- Automatic differentiation
- Vectorization (vmap)
- Parallelization (pmap) (sketched below)
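As a brief, hedged sketch of the last item: pmap maps a compiled function over the leading array axis, one slice per local device, and the snippet below runs even on a single CPU device (the example values are illustrative):

```python
import jax
import jax.numpy as jnp

# Number of locally visible devices (1 on a plain CPU install).
n = jax.local_device_count()

# The leading axis must match the device count; each device gets one row.
xs = jnp.arange(n * 4.0).reshape(n, 4)

# pmap compiles the function once with XLA and runs the replicas in parallel.
row_norms = jax.pmap(lambda x: jnp.sqrt(jnp.sum(x ** 2)))(xs)
print(row_norms)  # one value per device
```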