jax.experimental.maps.mesh

Warning

This page was created from a pull request (#9655).

jax.experimental.maps.mesh(devices, axis_names)

Declare the hardware resources available in the scope of this manager.

In particular, all axis_names become valid resource names inside the managed block and can be used e.g. in the axis_resources argument of xmap().

If you are compiling in multiple threads, make sure that the with mesh(...) context manager is entered inside the function that each thread executes, not in the code that spawns the threads.
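
The threading advice can be illustrated with a small standard-library sketch. It assumes that the active mesh is tracked in thread-local state, so a context entered in one thread is not visible in another. fake_mesh and active_axes are hypothetical stand-ins written for this illustration; they are not JAX functions and do not reproduce JAX's implementation.

```python
import threading
from contextlib import contextmanager

# Hypothetical stand-in for the mesh context manager. It records the active
# axis names in thread-local storage (an assumption made for illustration).
_state = threading.local()

@contextmanager
def fake_mesh(axis_names):
    previous = getattr(_state, "axis_names", None)
    _state.axis_names = tuple(axis_names)
    try:
        yield
    finally:
        _state.axis_names = previous

def active_axes():
    # Axis names visible to the calling thread, or None if no context is active.
    return getattr(_state, "axis_names", None)

def worker_without_mesh(results, i):
    # Relies on a context entered by the spawning thread.
    results[i] = active_axes()

def worker_with_mesh(results, i):
    # Enters the context inside the function the thread executes.
    with fake_mesh(("x", "y")):
        results[i] = active_axes()

def run(worker, n=2):
    results = [None] * n
    threads = [threading.Thread(target=worker, args=(results, i)) for i in range(n)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

with fake_mesh(("x", "y")):
    outside = run(worker_without_mesh)  # context entered around the threads
inside = run(worker_with_mesh)          # context entered inside each thread
```

In this sketch, only the workers that enter the context themselves see the axis names; the workers started under a context held by the main thread see none.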

Parameters
  • devices (ndarray) – A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).

  • axis_names (Sequence[Hashable]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
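
The relationship between the two parameters can be sketched without any devices: each entry of axis_names labels one dimension of the devices array, so the number of names has to equal the array's rank. mesh_axis_sizes below is a hypothetical helper written for this illustration, not a JAX function, and the way it reports a mismatch is an assumption rather than JAX's own behavior.

```python
def mesh_axis_sizes(device_shape, axis_names):
    """Hypothetical helper (not part of JAX): pair each axis name with the
    size of the devices-array dimension it labels."""
    if len(axis_names) != len(device_shape):
        raise ValueError(
            f"got {len(axis_names)} axis names for a rank-{len(device_shape)} "
            "device array")
    return dict(zip(axis_names, device_shape))

# A (2, 2) device array with names ('x', 'y'): two devices along each axis.
sizes = mesh_axis_sizes((2, 2), ("x", "y"))
```

Passing a single name for the same (2, 2) shape makes the helper raise ValueError, since one of the two dimensions would be left unnamed.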

Example:

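import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import mesh, xmap

x = jnp.arange(16.).reshape((4, 4))  # illustrative input, not from the original example
devices = np.array(jax.devices())[:4].reshape((2, 2))  # needs at least 4 devices
with mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
  distributed_out = xmap(
    jnp.vdot,
    in_axes=({0: 'left'}, {1: 'right'}),  # one entry per positional argument
    out_axes=['left', 'right', ...],
    axis_resources={'left': 'x', 'right': 'y'})(x, x.T)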