jax.lax.map
Warning
This page was created from a pull request (#9655).
- jax.lax.map(f, xs)
Map a function over leading array axes.
Like Python’s builtin map, except inputs and outputs are in the form of stacked arrays. Consider using the jax.vmap transform instead, unless you need to apply a function element by element for reduced memory usage or heterogeneous computation with other control flow primitives.
When xs is an array type, the semantics of map are given by this Python implementation:
def map(f, xs):
  return np.stack([f(x) for x in xs])
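As a rough check of the semantics described above, the following sketch compares jax.lax.map against the stacked Python loop and against jax.vmap on a small array. The function square_plus_one and the input values are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def square_plus_one(x):
    # Hypothetical per-element function; x is one row of xs, shape (2,).
    return x ** 2 + 1.0


# Three rows of two values each; the leading axis (size 3) is mapped over.
xs = jnp.arange(6.0).reshape(3, 2)

# Sequential map built from JAX primitives.
mapped = jax.lax.map(square_plus_one, xs)

# Reference semantics: a Python loop followed by a stack.
reference = np.stack([square_plus_one(x) for x in xs])

# Vectorized alternative, usually preferable when memory allows.
vmapped = jax.vmap(square_plus_one)(xs)

print(mapped.shape)  # (3, 2)
```

For a simple function like this one, all three results should agree; the difference lies in how the work is executed (one element at a time for jax.lax.map, batched for jax.vmap), not in the values produced.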
Like scan, map is implemented in terms of JAX primitives so many of the same advantages over a Python loop apply: xs may be an arbitrary nested pytree type, and the mapped computation is compiled only once.
- Parameters
f – a Python function to apply element-wise over the first axis or axes of xs.
xs – values over which to map along the leading axis.
- Returns
Mapped values.
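The description notes that xs may be a nested pytree rather than a single array. A minimal sketch of that case is shown here, assuming a dict of two equal-length arrays; the names affine, "w", and "b" are hypothetical and not part of the JAX API.

```python
import jax
import jax.numpy as jnp


def affine(params):
    # params is one slice of the pytree: a dict holding scalar "w" and "b".
    return params["w"] * 2.0 + params["b"]


# Every leaf shares the same leading axis length (3 here).
xs = {
    "w": jnp.array([1.0, 2.0, 3.0]),
    "b": jnp.array([10.0, 20.0, 30.0]),
}

# f receives a dict of scalars on each step; the results are stacked.
out = jax.lax.map(affine, xs)

print(out.shape)  # (3,)
```

Each call to affine sees a dict with the same structure as xs but with the leading axis removed from every leaf, which is why all leaves need matching leading dimensions.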