jax.make_jaxpr
Warning
This page was created from a pull request (#9655).
- jax.make_jaxpr(fun, static_argnums=(), axis_env=None, return_shape=False, abstracted_axes=None)
Creates a function that produces its jaxpr given example args.
- Parameters
fun (Callable) – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.
static_argnums (Union[int, Iterable[int]]) – See the jax.jit() docstring.
axis_env (Optional[Sequence[Tuple[Any, int]]]) – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap().
return_shape (bool) – Optional boolean, defaults to False. If True, the wrapped function returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree with the same structure as the output of fun and where the leaves are objects with shape, dtype, and named_shape attributes representing the corresponding types of the output leaves.
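The parameters above can be sketched in use as follows. This is a minimal illustration rather than part of the official docstring: the helper functions `repeat_mul`, `normalize`, and `summarize`, the axis name `"i"`, and the axis size 4 are all hypothetical choices made for the example.

```python
import jax
import jax.numpy as jnp

# static_argnums: treat the second argument as a static Python value, so it
# can drive ordinary Python control flow while the function is traced.
def repeat_mul(x, n):  # hypothetical helper
    out = x
    for _ in range(n - 1):
        out = out * x
    return out

static_jaxpr = jax.make_jaxpr(repeat_mul, static_argnums=(1,))(2.0, 3)
static_prims = [eqn.primitive.name for eqn in static_jaxpr.jaxpr.eqns]

# axis_env: declare a named axis "i" of size 4 so that a collective such as
# jax.lax.psum can be traced without wrapping the function in jax.pmap().
def normalize(x):  # hypothetical helper
    return x / jax.lax.psum(x, "i")

collective_jaxpr = jax.make_jaxpr(normalize, axis_env=[("i", 4)])(2.0)

# return_shape: additionally return a pytree describing the output leaves.
def summarize(x):  # hypothetical helper
    return {"total": x.sum(), "doubled": x * 2}

summary_jaxpr, out_shapes = jax.make_jaxpr(summarize, return_shape=True)(
    jnp.ones(3)
)
```

Here `static_jaxpr` has a single traced input, because `n` is consumed as a Python integer, and `out_shapes` mirrors the dictionary returned by `summarize`, with one shape/dtype description per leaf.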
- Return type
Callable[…, ClosedJaxpr]
- Returns
A wrapped version of fun that, when applied to example arguments, returns a ClosedJaxpr representation of fun on those arguments. If the argument return_shape is True, then the returned function instead returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of fun.
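As a rough sketch of what can be done with the returned ClosedJaxpr, the snippet below inspects its equations, inputs, outputs, and constants. The variable names are illustrative, and the attribute names used here (`jaxpr`, `consts`, `out_avals`, `eqns`, `invars`, `outvars`) are assumed from the object JAX returns at the time of writing; they are not described on this page.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(jnp.cos(x))

# Trace f on an example argument; only the argument's shape and dtype matter.
closed = jax.make_jaxpr(f)(3.0)

# The ClosedJaxpr wraps an inner jaxpr together with any captured constants.
inner = closed.jaxpr
primitive_names = [eqn.primitive.name for eqn in inner.eqns]
num_inputs = len(inner.invars)
num_outputs = len(inner.outvars)
consts = closed.consts

# Abstract value (shape and dtype) of the single output.
output_aval = closed.out_avals[0]
```

Walking `inner.eqns` this way gives programmatic access to the same information that printing the jaxpr shows as text.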
A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr() adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally. The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.

We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.

>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a:f32[]. let
    b:f32[] = cos a
    c:f32[] = sin a
    _:f32[] = sin b
    d:f32[] = cos b
    e:f32[] = mul 1.0 d
    f:f32[] = neg e
    g:f32[] = mul f c
  in (g,) }
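To illustrate what tracing at the ShapedArray level means in practice, the sketch below traces one function on inputs that differ in value but share a shape and dtype, and then on an input with a different shape. The function `g` and the variable names are hypothetical, and the `in_avals` and `out_avals` attributes are assumed from the ClosedJaxpr object rather than documented above.

```python
import jax
import jax.numpy as jnp

def g(x):  # hypothetical example function
    return jnp.tanh(x) * 2.0

# Same shape and dtype, different values: the trace depends only on the
# abstract (shape, dtype) description of the argument.
jaxpr_a = jax.make_jaxpr(g)(jnp.zeros((2, 3)))
jaxpr_b = jax.make_jaxpr(g)(jnp.arange(6.0).reshape(2, 3))

# A different shape yields a jaxpr with different abstract values.
jaxpr_c = jax.make_jaxpr(g)(jnp.zeros((4,)))

same_trace = str(jaxpr_a) == str(jaxpr_b)
shape_ab = jaxpr_a.out_avals[0].shape
shape_c = jaxpr_c.out_avals[0].shape
```

Because concrete values are abstracted away, the example arguments passed to the wrapped function serve only as carriers of shape and dtype information.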