jax.experimental.global_device_array.Shard
Warning
This page was created from a pull request (#9655).
- class jax.experimental.global_device_array.Shard(device, index, replica_id, data=None)
A single data shard of a GlobalDeviceArray.
- Parameters
  - device (Device) – Which device this shard resides on.
  - index (Tuple[slice, …]) – The index into the global array of this shard.
  - replica_id (int) – Integer id indicating which replica of the global array this shard is part of. Always 0 for fully sharded data (i.e. when there’s only 1 replica).
  - data (Optional[DeviceArray]) – The data of this shard. None if device is non-local.
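To make the four parameters concrete, here is a minimal sketch that does not import the real class. It assumes a hypothetical `ShardSketch` dataclass that only mirrors the fields above, uses plain strings such as "dev0" in place of real `Device` objects, and uses a made-up layout in which devices are numbered replica by replica. It shows how `index` selects a block of the global array, how `replica_id` distinguishes copies of the same block, and how `data` stays `None` for devices that are not local. It is an illustration under those assumptions, not how GlobalDeviceArray itself assigns shards to devices.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np


# Hypothetical stand-in mirroring the four Shard fields; `device` is a plain
# string here instead of a real JAX Device.
@dataclass
class ShardSketch:
    device: str
    index: Tuple[slice, ...]
    replica_id: int
    data: Optional[np.ndarray] = None


def make_shards(global_array, num_shards, num_replicas, local_devices):
    """Split `global_array` along axis 0 into `num_shards` row blocks and
    place each block on `num_replicas` devices.

    Devices are named "dev0", "dev1", ... (a hypothetical layout: replica 0
    first, then replica 1, and so on). Only devices in `local_devices` get
    their data filled in; all others keep data=None, mimicking non-local
    devices.
    """
    rows = global_array.shape[0]
    assert rows % num_shards == 0, "axis 0 must divide evenly into shards"
    block = rows // num_shards
    # Untouched trailing dimensions are selected in full.
    rest = (slice(None),) * (global_array.ndim - 1)
    shards = []
    for replica_id in range(num_replicas):
        for s in range(num_shards):
            device = f"dev{replica_id * num_shards + s}"
            # Index of this shard's block within the global array.
            index = (slice(s * block, (s + 1) * block),) + rest
            data = global_array[index] if device in local_devices else None
            shards.append(ShardSketch(device, index, replica_id, data))
    return shards


x = np.arange(16).reshape(4, 4)
# Two row blocks, each replicated twice; only dev0 and dev1 are "local".
for sh in make_shards(x, num_shards=2, num_replicas=2, local_devices={"dev0", "dev1"}):
    print(sh.device, sh.index, sh.replica_id, sh.data is None)
```

With `num_replicas=1` every shard gets `replica_id` 0, which corresponds to the fully sharded case described for `replica_id`; with more replicas, shards that share the same `index` differ only in `replica_id` and `device`.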
- __init__(device, index, replica_id, data=None)
Methods
- __init__(device, index, replica_id[, data])
Attributes
- data
- device
- index
- replica_id