jax.experimental.maps.mesh
Warning
This page was created from a pull request (#9655).
jax.experimental.maps.mesh(devices, axis_names)
Declare the hardware resources available in the scope of this manager.

In particular, all axis_names become valid resource names inside the managed block and can be used, for example, in the axis_resources argument of xmap().

If you are compiling in multiple threads, make sure that the with mesh(...) context manager is inside the function that the threads will execute.

Parameters
    devices (ndarray) – A NumPy ndarray object containing JAX device objects (as obtained, for example, from jax.devices()).
    axis_names (Sequence[Hashable]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
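The parameter contract can be illustrated with a small helper that builds the devices array and pairs each axis name with one of its dimensions. This is a minimal sketch under assumptions: make_device_grid is a hypothetical name, not part of JAX. It does not call mesh itself. It only reproduces the rank check described above (raising on a mismatch is this helper's own choice) and reports how many devices lie along each named axis.

```python
import jax
import numpy as np


def make_device_grid(shape, axis_names):
    """Hypothetical helper: arrange the first prod(shape) JAX devices into an
    ndarray of the given shape and pair each axis name with a dimension."""
    n = int(np.prod(shape))
    available = jax.devices()
    if n > len(available):
        raise ValueError(f"need {n} devices, only {len(available)} available")
    # Object ndarray of Device objects, reshaped into the requested grid.
    devices = np.array(available[:n]).reshape(shape)
    # One resource axis name per dimension of the devices array.
    if len(axis_names) != devices.ndim:
        raise ValueError(
            f"len(axis_names)={len(axis_names)} does not match "
            f"devices.ndim={devices.ndim}")
    # Map each resource axis name to the number of devices along it.
    sizes = dict(zip(axis_names, devices.shape))
    return devices, sizes
```

On a single-device CPU machine, make_device_grid((1, 1), ('x', 'y')) returns a 1x1 grid together with {'x': 1, 'y': 1}. With four devices, make_device_grid((2, 2), ('x', 'y')) produces a 2x2 grid of the same shape as the one in the example, which could then be passed to mesh.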
Example:
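    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental.maps import mesh, xmap
    # Needs at least 4 devices. x is an assumed input; it must be square so x and x.T agree on 'left'/'right' sizes.
    x = jnp.arange(16.0).reshape((4, 4))
    devices = np.array(jax.devices())[:4].reshape((2, 2))
    with mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
        distributed_out = xmap(
            jnp.vdot,
            in_axes=({0: 'left', 1: 'right'}),
            out_axes=['left', 'right', ...],
            axis_resources={'left': 'x', 'right': 'y'})(x, x.T)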
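The advice on compiling in multiple threads is easiest to see with a context manager whose state is stored per thread. The sketch below is a plain-Python illustration, not JAX code: fake_mesh, current_mesh and run_in_thread are hypothetical stand-ins. It assumes, as the advice suggests, that the active mesh is tracked per thread, so a with block entered in the main thread is not visible to worker threads.

```python
import contextlib
import threading

# Hypothetical stand-in for the mesh context: the active value lives in
# thread-local storage, so each thread only sees what it entered itself.
_state = threading.local()


@contextlib.contextmanager
def fake_mesh(name):
    prev = getattr(_state, "mesh", None)
    _state.mesh = name
    try:
        yield
    finally:
        _state.mesh = prev  # restore whatever was active before


def current_mesh():
    # None means "no mesh entered in this thread".
    return getattr(_state, "mesh", None)


def run_in_thread(fn):
    """Run fn in a new thread, wait for it, and return its result."""
    result = []
    t = threading.Thread(target=lambda: result.append(fn()))
    t.start()
    t.join()
    return result[0]


# Wrong: the context is entered in the main thread only.
with fake_mesh("2x2 mesh"):
    seen_outside = run_in_thread(current_mesh)  # worker sees None


# Right: the worker enters the context inside the function it executes.
def worker():
    with fake_mesh("2x2 mesh"):
        return current_mesh()


seen_inside = run_in_thread(worker)  # worker sees "2x2 mesh"
```

Under that assumption, the same pattern applies to the real API: put the with mesh(devices, axis_names) block inside the function passed to each thread, not around the code that starts the threads.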