jax.experimental.global_device_array.GlobalDeviceArray

Warning

This page was created from a pull request (#9655).


class jax.experimental.global_device_array.GlobalDeviceArray(global_shape, global_mesh, mesh_axes, device_buffers, _gda_fast_path_args=None)

A logical array with data sharded across multiple devices and processes.

If you’re not already familiar with JAX’s multi-process programming model, please read https://jax.readthedocs.io/en/latest/multi_process.html.

A GlobalDeviceArray (GDA) can be thought of as a view into a single logical array sharded across processes. The logical array is the “global” array, and each process has a GlobalDeviceArray object referring to the same global array (similar to how each process runs the same multi-process pmap or pjit). Through the GDA, each process can access the shape, dtype, and other metadata of the global array, pass the GDA into multi-process pjits, and receive GDAs as pjit outputs (coming soon: xmap and pmap). However, each process can directly access only the shards of the global array data that are stored on its local devices.

GDAs can help manage the inputs and outputs of multi-process computations. A GDA keeps track of which shard of the global array belongs to which device, and provides callback-based APIs to materialize the correct shard of the data needed for each local device of each process.
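The callback contract can be modeled in plain NumPy. The sketch below is not the JAX implementation. `toy_from_callback`, `ToyShard`, and the hard-coded device-to-index map are hypothetical stand-ins. They illustrate the idea described above: the array knows which slice of the global array each local device owns, and it calls the user's callback once per local device with that slice to materialize the shard.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

import numpy as np

Index = Tuple[slice, ...]


@dataclass
class ToyShard:
    device: str   # hypothetical device label, e.g. "cpu:0"
    index: Index  # slice of the global array held by this device
    data: Any     # data materialized for this device


def toy_from_callback(local_indices: Dict[str, Index],
                      data_callback: Callable[[Index], np.ndarray]) -> List[ToyShard]:
    """Call data_callback once per local device with that device's index."""
    return [ToyShard(dev, idx, data_callback(idx))
            for dev, idx in local_indices.items()]


# A (4, 2) global array split row-wise across two local devices.
global_data = np.arange(8).reshape(4, 2)
local_indices = {"cpu:0": (slice(0, 2), slice(None)),
                 "cpu:1": (slice(2, 4), slice(None))}
calls = []


def get_local_data_slice(index):
    calls.append(index)  # record each request for illustration
    return global_data[index]


shards = toy_from_callback(local_indices, get_local_data_slice)
```

The callback only ever sees indices for local devices, so a process never has to produce data for shards that live on another host.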

A GDA consists of data shards, each stored on a different device. The GDA exposes two views of these shards: local shards and global shards. Local shards are those on local devices, and their data is visible to the current process. Global shards are those across all devices (including local devices); a global shard's data isn't visible if the shard is on a device that is non-local with respect to the current process. See the Shard class for the information stored in that data structure.
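The local/global distinction can be sketched with a toy model of two processes that own two devices each. `ToyShard` and its `process_index` field are hypothetical. They are not the JAX Shard class. The model shows that every process can enumerate all global shards and their indices. Shard data, however, is present only for devices owned by the current process; here it is represented as `None` otherwise.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np


@dataclass
class ToyShard:
    device_id: int
    process_index: int           # process that owns the device (hypothetical)
    index: Tuple[slice, ...]     # slice of the global array
    data: Optional[np.ndarray]   # None when the device is not local


global_data = np.arange(16).reshape(4, 4)
current_process = 0

# 4 devices, 2 per process; each device holds one row of the global array.
global_shards = []
for device_id in range(4):
    owner = device_id // 2
    index = (slice(device_id, device_id + 1), slice(None))
    data = global_data[index] if owner == current_process else None
    global_shards.append(ToyShard(device_id, owner, index, data))

# Local shards are the subset of global shards owned by this process.
local_shards = [s for s in global_shards if s.process_index == current_process]
```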

Note: to make pjit output GlobalDeviceArrays, set the environment variable JAX_PARALLEL_FUNCTIONS_OUTPUT_GDA=true or add the following to your code: jax.config.update('jax_parallel_functions_output_gda', True)

Parameters
  • global_shape (Tuple[int, …]) – The global shape of the array.

  • global_mesh (Mesh) – The global mesh representing devices across multiple processes.

  • mesh_axes (Sequence[Union[str, Tuple[str], None]]) –

    A sequence with length less than or equal to the rank of the global array (i.e. the length of the global shape). Each element can be:

    • An axis name of global_mesh, indicating that the corresponding global array axis is partitioned across the given device axis of global_mesh.

    • A tuple of axis names of global_mesh. This is like the above option except the global array axis is partitioned across the product of axes named in the tuple.

    • None, indicating that the corresponding global array axis is not partitioned.

    For more information, please see: https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#more-information-on-partitionspec

  • device_buffers (Sequence[DeviceArray]) – DeviceArrays that are on the local devices of global_mesh.
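The sketch below shows how a mesh_axes spec could turn into a per-device index into the global array. `shard_indices` is a hypothetical helper, not a JAX function. The mesh is modeled as a plain dict of axis name to size. One point is an assumption here: for a tuple of axis names, the first name is treated as the most significant when the axes are flattened into one partition number. Unlisted trailing dimensions and None entries are left unpartitioned, so devices that differ only along unused mesh axes hold replicas.

```python
import itertools
from typing import Dict, Sequence, Tuple, Union

AxisSpec = Union[str, Tuple[str, ...], None]


def shard_indices(global_shape: Tuple[int, ...],
                  mesh_shape: Dict[str, int],
                  mesh_axes: Sequence[AxisSpec]) -> Dict[Tuple[int, ...], Tuple[slice, ...]]:
    """Map every mesh coordinate to the slice of the global array it holds."""
    if len(mesh_axes) > len(global_shape):
        raise ValueError("mesh_axes may not be longer than the array rank")
    # Array dimensions without an entry are treated as unpartitioned (None).
    padded = list(mesh_axes) + [None] * (len(global_shape) - len(mesh_axes))
    axis_names = list(mesh_shape)
    result = {}
    for coord in itertools.product(*(range(mesh_shape[a]) for a in axis_names)):
        position = dict(zip(axis_names, coord))
        index = []
        for dim, spec in zip(global_shape, padded):
            if spec is None:
                index.append(slice(None))  # whole dimension on every device
                continue
            names = (spec,) if isinstance(spec, str) else tuple(spec)
            # Flatten the named mesh axes into one partition number,
            # first name most significant (assumed ordering).
            num_parts, part = 1, 0
            for name in names:
                num_parts *= mesh_shape[name]
                part = part * mesh_shape[name] + position[name]
            if dim % num_parts:
                raise ValueError(f"dimension {dim} not divisible by {num_parts}")
            chunk = dim // num_parts
            index.append(slice(part * chunk, (part + 1) * chunk))
        result[coord] = tuple(index)
    return result


indices = shard_indices((64, 32), {"x": 4, "y": 8}, ("x", "y"))
print(indices[(0, 0)])  # (slice(0, 16, None), slice(0, 4, None))
```

For the (64, 32) input used in the example on this page, device (0, 0) gets the same slices as the sample index shown in its get_local_data_slice comment.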

shape

Global shape of the array.

dtype

Dtype of the global array.

ndim

Number of array dimensions in the global shape.

size

Number of elements in the global array.

local_shards

List of Shards on the local devices of the current process. Data is materialized for all local shards.

global_shards

List of all Shards of the global array. Data isn’t available if a shard is on a non-local device with respect to the current process.

is_fully_replicated

True if the full array value is present on all devices of the global mesh.
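The size-related attributes follow directly from the global shape. is_fully_replicated amounts to every shard covering the whole global array. The sketch below uses a hypothetical helper, `derived_attributes`, which is not part of JAX and takes shard indices as plain tuples of slices. It is a model of the definitions above, not of how the GDA computes them internally.

```python
import math
from typing import Sequence, Tuple


def derived_attributes(global_shape: Tuple[int, ...],
                       shard_indices: Sequence[Tuple[slice, ...]]):
    """Return (ndim, size, is_fully_replicated) for a toy sharded array."""
    ndim = len(global_shape)
    size = math.prod(global_shape)
    # slice.indices(d) normalizes a slice to (start, stop, step) for length d;
    # a shard covers a dimension fully when that is (0, d, 1).
    is_fully_replicated = all(
        all(s.indices(d) == (0, d, 1) for s, d in zip(index, global_shape))
        for index in shard_indices)
    return ndim, size, is_fully_replicated


replicated = [(slice(None), slice(None))] * 4
row_split = [(slice(0, 2), slice(None)), (slice(2, 4), slice(None))]
```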

Example:

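import numpy as np

import jax
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.maps import Mesh, mesh
from jax.experimental.pjit import PartitionSpec as P
from jax.experimental.pjit import pjit

# Make pjit return GlobalDeviceArrays (see the note above).
jax.config.update('jax_parallel_functions_output_gda', True)

# Logical mesh is (hosts, devices). This assumes 4 processes with 8 local
# devices each, and that jax.devices() lists each process's devices together.
global_mesh = Mesh(np.array(jax.devices()).reshape(4, 8), ('x', 'y'))
assert global_mesh.shape == {'x': 4, 'y': 8}

global_input_shape = (64, 32)
mesh_axes = P('x', 'y')

# Dummy example data; in practice we wouldn't necessarily materialize global data
# in a single process.
global_input_data = np.arange(
    np.prod(global_input_shape)).reshape(global_input_shape)

def get_local_data_slice(index):
  # index will be a tuple of slice objects, e.g. (slice(0, 16), slice(0, 4))
  # This function is called once per local device by the GDA constructor.
  return global_input_data[index]

gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, get_local_data_slice)

# in_axis_resources must match the partitioning of the input GDA.
f = pjit(lambda x: x @ x.T, in_axis_resources=mesh_axes,
         out_axis_resources=P('y', 'x'))

# Run f under the global mesh (device array plus axis names).
with mesh(global_mesh.devices, global_mesh.axis_names):
  out = f(gda)

print(type(out))  # GlobalDeviceArray
print(out.shape)  # global shape == (64, 64)
print(out.local_shards[0].data)  # Access the data on a single local device,
                                # e.g. for checkpointing
print(out.local_shards[0].data.shape)  # per-device shape == (8, 16)
print(out.local_shards[0].index) # Numpy-style index into the global array that
                                # this data shard corresponds to

# `out` can be passed to another pjit call, out.local_shards can be used to
# export the data to non-jax systems (e.g. for checkpointing or logging), etc.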

Parameters
  • _gda_fast_path_args (Optional[_GdaFastPathArgs]) – Private argument (note the leading underscore).

__init__(global_shape, global_mesh, mesh_axes, device_buffers, _gda_fast_path_args=None)

Parameters are the same as those listed for the class above.

Methods

  • __init__(global_shape, global_mesh, ...[, ...]) – Initializes the array from per-device buffers; see the Parameters above.

  • from_batched_callback(global_shape, ...) – Constructs a GlobalDeviceArray via batched data fetched from data_callback.

  • from_batched_callback_with_devices(...) – Constructs a GlobalDeviceArray via batched DeviceArrays fetched from data_callback.

  • from_callback(global_shape, global_mesh, ...) – Constructs a GlobalDeviceArray via data fetched from data_callback.

  • local_data(index) – Returns a DeviceArray.
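The difference between the per-shard and batched constructors can be modeled in plain NumPy. This sketch assumes that the batched callback is invoked once with the indices of all local devices and returns one array per index. The name suggests this, but the method summaries above do not spell it out. `per_device_fetch` and `batched_fetch` are hypothetical helpers, not JAX APIs. A batched callback is useful when one bulk read, such as a single file or storage request, can serve every local device at once.

```python
from typing import Callable, List, Sequence, Tuple

import numpy as np

Index = Tuple[slice, ...]


def per_device_fetch(indices: Sequence[Index],
                     callback: Callable[[Index], np.ndarray]) -> List[np.ndarray]:
    # from_callback-style: one callback invocation per local device.
    return [callback(idx) for idx in indices]


def batched_fetch(indices: Sequence[Index],
                  batched_callback: Callable[[Sequence[Index]], Sequence[np.ndarray]]
                  ) -> List[np.ndarray]:
    # from_batched_callback-style (assumed): one invocation for all local devices.
    arrays = list(batched_callback(indices))
    if len(arrays) != len(indices):
        raise ValueError("expected one array per index")
    return arrays


global_data = np.arange(12).reshape(6, 2)
local_indices = [(slice(0, 3), slice(None)), (slice(3, 6), slice(None))]
call_count = {"single": 0, "batched": 0}


def single(idx):
    call_count["single"] += 1
    return global_data[idx]


def batched(idxs):
    call_count["batched"] += 1
    return [global_data[i] for i in idxs]


a = per_device_fetch(local_indices, single)
b = batched_fetch(local_indices, batched)
```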

Attributes

  • global_shards (Sequence[Shard])

  • is_fully_replicated (bool)

  • local_shards (Sequence[Shard])

  • ndim

  • shape (Tuple[int, ...])

  • size