jax.experimental.global_device_array.Shard

Warning

This page was created from a pull request (#9655).


class jax.experimental.global_device_array.Shard(device, index, replica_id, data=None)

A single data shard of a GlobalDeviceArray.

Parameters
  • device (Device) – Which device this shard resides on.

  • index (Tuple[slice, …]) – The index into the global array that this shard covers, given as one slice per array dimension.

  • replica_id (int) – Integer id indicating which replica of the global array this shard is part of. Always 0 for fully sharded data (i.e. when there’s only 1 replica).

  • data (Optional[DeviceArray]) – The data of this shard. None if device is non-local, i.e., not addressable by the current process.
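To make the roles of index and replica_id concrete, here is a minimal, stdlib-only sketch. It does not use the JAX class: ShardSketch and split_rows are hypothetical names, devices are stood in for by integer ids, and the layout (rows split evenly into blocks, each block replicated num_replicas times) is just one assumed sharding chosen for illustration.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class ShardSketch:
    """Hypothetical stand-in for Shard: placement metadata plus optional data."""
    device: int                 # integer id standing in for a Device (assumption)
    index: Tuple[slice, ...]    # one slice per dimension of the global array
    replica_id: int             # 0 when the data is fully sharded
    data: Optional[Any] = None  # None when the device is non-local


def split_rows(global_shape: Tuple[int, int], num_blocks: int,
               num_replicas: int) -> List[ShardSketch]:
    """Split rows evenly into num_blocks and replicate each block num_replicas times."""
    rows, cols = global_shape
    assert rows % num_blocks == 0, "this sketch assumes an even split"
    block = rows // num_blocks
    shards = []
    device = 0
    for b in range(num_blocks):
        # All replicas of a block share the same index into the global array.
        index = (slice(b * block, (b + 1) * block), slice(0, cols))
        for r in range(num_replicas):
            shards.append(ShardSketch(device=device, index=index, replica_id=r))
            device += 1
    return shards


# Fully sharded: 4 blocks with 1 replica each, so every replica_id is 0.
print([s.replica_id for s in split_rows((8, 3), num_blocks=4, num_replicas=1)])  # [0, 0, 0, 0]
# Replicated: 2 blocks with 2 replicas each.
print([s.replica_id for s in split_rows((8, 3), num_blocks=2, num_replicas=2)])  # [0, 1, 0, 1]
```

In the replicated case, devices 0 and 1 hold identical slices and differ only in replica_id, which is what lets code tell copies of the same data apart.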

__init__(device, index, replica_id, data=None)

Methods

__init__(device, index, replica_id[, data]) – Construct a shard; the parameters are described above.

Attributes

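data – The data of this shard, or None if device is non-local.

device – The device this shard resides on.

index – The index (a tuple of slices) into the global array that this shard covers.

replica_id – Which replica of the global array this shard is part of; always 0 for fully sharded data.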
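One way these attributes can be combined, sketched below under stated assumptions, is writing a sharded array out without duplicates: each process copies only shards whose data is present locally (data is not None) and whose replica_id is 0, so every slice is written by exactly one process. This is a plain-Python simulation, not JAX code. ShardSketch, process_view, and write_once are hypothetical names, devices are integer ids, and the layout (devices 0 and 1 hold replicas of the first half, devices 2 and 3 of the second half) is assumed only for illustration.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Set, Tuple


@dataclass
class ShardSketch:
    """Hypothetical stand-in for Shard over a 1-D global array."""
    device: int
    index: Tuple[slice, ...]
    replica_id: int
    data: Optional[List[Any]] = None  # None when the device is non-local


def process_view(global_values: List[Any], num_blocks: int, num_replicas: int,
                 local_devices: Set[int]) -> List[ShardSketch]:
    """All shards as one process sees them; data is filled only for local devices."""
    block = len(global_values) // num_blocks
    shards = []
    for b in range(num_blocks):
        index = (slice(b * block, (b + 1) * block),)
        for r in range(num_replicas):
            device = b * num_replicas + r
            data = global_values[index[0]] if device in local_devices else None
            shards.append(ShardSketch(device, index, r, data))
    return shards


def write_once(shards: List[ShardSketch], out: List[Any]) -> None:
    """Copy only local, replica-0 shards so each slice is written by a single process."""
    for s in shards:
        if s.data is not None and s.replica_id == 0:
            out[s.index[0]] = s.data


values = list(range(6))
out = [None] * len(values)
# Two simulated processes; each one owns exactly one replica-0 device.
for local in ({0, 3}, {1, 2}):
    write_once(process_view(values, num_blocks=2, num_replicas=2, local_devices=local), out)
print(out)  # [0, 1, 2, 3, 4, 5]
```

Filtering on replica_id alone is not enough: a replica-0 shard on another process has data None here, so the data check keeps each process to the shards it can actually read.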