Python Module Index

j

jax
    jax.core
    jax.distributed
    jax.dlpack
    jax.example_libraries
    jax.example_libraries.optimizers
    jax.example_libraries.stax
    jax.experimental.ann
    jax.experimental.global_device_array
    jax.experimental.host_callback
    jax.experimental.loops
    jax.experimental.maps
    jax.experimental.pjit
    jax.experimental.sparse
    jax.flatten_util
    jax.image
    jax.lax
    jax.lax.linalg
    jax.nn
    jax.nn.initializers
    jax.numpy
    jax.numpy.fft
    jax.numpy.linalg
    jax.ops
    jax.profiler
    jax.random
    jax.scipy.fft
    jax.scipy.linalg
    jax.scipy.ndimage
    jax.scipy.optimize
    jax.scipy.signal
    jax.scipy.sparse.linalg
    jax.scipy.special
    jax.scipy.stats.bernoulli
    jax.scipy.stats.beta
    jax.scipy.stats.betabinom
    jax.scipy.stats.cauchy
    jax.scipy.stats.chi2
    jax.scipy.stats.dirichlet
    jax.scipy.stats.expon
    jax.scipy.stats.gamma
    jax.scipy.stats.geom
    jax.scipy.stats.laplace
    jax.scipy.stats.logistic
    jax.scipy.stats.multivariate_normal
    jax.scipy.stats.norm
    jax.scipy.stats.pareto
    jax.scipy.stats.poisson
    jax.scipy.stats.t
    jax.scipy.stats.uniform
    jax.tree_util
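Each dotted name in this index is an importable Python module. As a minimal sketch (assuming a working `jax` installation), the snippet below imports three of the listed modules, `jax.numpy`, `jax.random`, and `jax.scipy.stats.norm`, and calls one function from each. The variable names are illustrative only.

```python
# Entries in the module index map directly to Python import paths.
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# jax.random: JAX uses explicit PRNG keys rather than global random state.
key = jax.random.PRNGKey(0)
samples = jax.random.normal(key, shape=(3,))

# jax.numpy: NumPy-like array operations on JAX arrays.
mean = jnp.mean(samples)

# jax.scipy.stats.norm: standard normal density at 0, i.e. 1/sqrt(2*pi).
pdf0 = norm.pdf(0.0)

print(samples.shape, float(mean), float(pdf0))
```

The same pattern applies to the other entries, for example `import jax.numpy.linalg` or `from jax.scipy import special`.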

By The JAX authors
© Copyright 2020, Google LLC. NumPy and SciPy documentation are copyright the respective authors.