jax.eval_shape
Warning
This page was created from a pull request (#9655).
jax.eval_shape(fun, *args, **kwargs)
Compute the shape/dtype of fun without any FLOPs.

This utility function is useful for performing shape inference. Its input/output behavior is defined by:
    import jax

    # Reference semantics only: unlike the real function, this version
    # actually calls fun and then keeps just each leaf's shape and dtype.
    def eval_shape(fun, *args, **kwargs):
      out = fun(*args, **kwargs)
      return jax.tree_util.tree_map(shape_dtype_struct, out)

    def shape_dtype_struct(x):
      return ShapeDtypeStruct(x.shape, x.dtype)

    class ShapeDtypeStruct:
      __slots__ = ["shape", "dtype"]
      def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype
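To see these semantics in action, the following sketch (assuming a recent JAX release; the function f and its dict keys are made-up illustrations) compares jax.eval_shape against the reference definition above. It runs f concretely, keeps only each leaf's (shape, dtype), and checks that the abstract result agrees.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function returning a pytree (here a dict) whose
    # leaves have different shapes and dtypes.
    return {"col_sums": x.sum(axis=0), "as_int": x.astype(jnp.int32)}

x = jnp.ones((3, 4), dtype=jnp.float32)

# Abstract evaluation: each leaf of `out` carries shape/dtype only.
out = jax.eval_shape(f, x)

# Reference semantics: run f for real, then keep (shape, dtype) per leaf.
concrete = f(x)
expected = jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), concrete)
abstract = jax.tree_util.tree_map(lambda s: (s.shape, s.dtype), out)

assert abstract == expected
```

The pytree structure of the output (the dict and its keys) is preserved; only the leaves are replaced by shape/dtype objects.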
In particular, the output is a pytree of objects that have shape and dtype attributes, but nothing else about them is guaranteed by the API.

However, instead of applying fun directly, which might be expensive, eval_shape uses JAX's abstract interpretation machinery to evaluate the shapes without doing any FLOPs.

Using eval_shape() can also catch shape errors: it raises the same shape errors as evaluating fun(*args, **kwargs).

Parameters
- fun (Callable) – The function whose output shape should be evaluated.
- *args – A positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the shape and dtype attributes are accessed, the values need only duck-type arrays; they do not have to be real ndarrays. The duck-typed objects cannot be namedtuples, because namedtuples are treated as standard Python containers. See the example below.
- **kwargs – A keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in args, array values need only be duck-typed to have shape and dtype attributes.
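The parameter rules above allow nested pytrees in both *args and **kwargs, with leaves that only need to carry shape and dtype. A minimal sketch, assuming a recent JAX release: jax.ShapeDtypeStruct serves as the duck-typed stand-in, while the function affine and its params layout are hypothetical.

```python
import jax
import jax.numpy as jnp

def affine(params, x, scale=1.0):
    # Hypothetical model: params is a nested dict (a pytree),
    # x is a batch of input rows.
    return scale * (x @ params["w"] + params["b"])

# jax.ShapeDtypeStruct only carries shape and dtype, so it satisfies the
# duck-typing requirement without allocating any array memory.
params = {
    "w": jax.ShapeDtypeStruct((512, 256), jnp.float32),
    "b": jax.ShapeDtypeStruct((256,), jnp.float32),
}
x = jax.ShapeDtypeStruct((64, 512), jnp.float32)

# Positional pytree arguments plus a keyword argument (a Python scalar).
out = jax.eval_shape(affine, params, x, scale=2.0)
print(out.shape, out.dtype)  # (64, 256) float32
```

Because the Python scalar passed as scale is weakly typed, it does not promote the float32 result.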
For example:
    >>> import jax
    >>> import jax.numpy as jnp
    >>>
    >>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
    >>> class MyArgArray(object):
    ...   def __init__(self, shape, dtype):
    ...     self.shape = shape
    ...     self.dtype = jnp.dtype(dtype)
    ...
    >>> A = MyArgArray((2000, 3000), jnp.float32)
    >>> x = MyArgArray((3000, 1000), jnp.float32)
    >>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
    >>> print(out.shape)
    (2000, 1000)
    >>> print(out.dtype)
    float32
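The claim that eval_shape() raises the same shape errors as evaluating fun(*args, **kwargs) can be checked with a small sketch. It redefines f from the example above but deliberately mismatches the inner dimensions; raised_type is a hypothetical helper. The exact exception class comes from the failing shape check inside JAX, so the sketch compares exception types instead of hard-coding one.

```python
import jax
import jax.numpy as jnp

f = lambda A, x: jnp.tanh(jnp.dot(A, x))

def raised_type(thunk):
    """Hypothetical helper: call thunk() and return the exception type, if any."""
    try:
        thunk()
    except Exception as err:
        return type(err)
    return None

# Inner dimensions deliberately mismatch: (2, 3) @ (4, 5).
A_struct = jax.ShapeDtypeStruct((2, 3), jnp.float32)
x_struct = jax.ShapeDtypeStruct((4, 5), jnp.float32)

# Abstract evaluation: no FLOPs, but the shape error is still reported.
abstract_err = raised_type(lambda: jax.eval_shape(f, A_struct, x_struct))

# Concrete evaluation on small real arrays with the same shapes.
concrete_err = raised_type(lambda: f(jnp.ones((2, 3)), jnp.ones((4, 5))))

print(abstract_err, concrete_err)
```

Small concrete arrays keep the comparison cheap; with the large shapes from the example above, only the eval_shape call would stay free of FLOPs.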