jax.experimental.global_device_array.Shard
Warning
This page was created from a pull request (#9655).
- class jax.experimental.global_device_array.Shard(device, index, replica_id, data=None)
A single data shard of a GlobalDeviceArray.
- Parameters
  - device (Device) – Which device this shard resides on.
  - index (Tuple[slice, …]) – The index into the global array of this shard.
  - replica_id (int) – Integer id indicating which replica of the global array this shard is part of. Always 0 for fully sharded data (i.e. when there is only one replica).
  - data (Optional[DeviceArray]) – The data of this shard. None if device is non-local.
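To make the relationship between index and replica_id concrete, here is a minimal sketch built on a hypothetical ShardSketch dataclass that mirrors the four constructor parameters (it is not the JAX class) and a hypothetical make_shards helper. It assumes each axis is split into equal blocks and that devices are plain integers, with the copies of a block placed on consecutive devices; the device ordering JAX actually uses may differ.

```python
from dataclasses import dataclass
from itertools import product
from typing import Any, List, Optional, Tuple


@dataclass
class ShardSketch:
    """Hypothetical stand-in mirroring Shard's constructor parameters."""
    device: Any                   # a plain int here, standing in for a Device
    index: Tuple[slice, ...]      # where this shard sits in the global array
    replica_id: int               # which copy of this block; 0 if fully sharded
    data: Optional[Any] = None    # None when the device is non-local


def make_shards(global_shape: Tuple[int, ...],
                splits: Tuple[int, ...],
                num_replicas: int) -> List[ShardSketch]:
    """Hypothetical helper: cut axis i into splits[i] equal blocks and place
    num_replicas copies of every block on consecutive integer devices."""
    per_axis = []
    for dim, n in zip(global_shape, splits):
        assert dim % n == 0, "this sketch only handles even splits"
        size = dim // n
        per_axis.append([slice(k * size, (k + 1) * size) for k in range(n)])

    shards = []
    device = 0
    for index in product(*per_axis):  # one tuple of slices per block
        for replica_id in range(num_replicas):
            shards.append(ShardSketch(device, index, replica_id))
            device += 1
    return shards


# Fully sharded: 6 distinct blocks, so every replica_id is 0.
fully_sharded = make_shards((4, 6), splits=(2, 3), num_replicas=1)
print(fully_sharded[0].index)                 # (slice(0, 2, None), slice(0, 2, None))
print({s.replica_id for s in fully_sharded})  # {0}

# Rows split in two, each half copied to two devices: shared index, ids 0 and 1.
replicated = make_shards((4, 6), splits=(2, 1), num_replicas=2)
print([(s.device, s.replica_id) for s in replicated])  # [(0, 0), (1, 1), (2, 0), (3, 1)]
```

Shards that share an index hold identical data and differ only in replica_id, which is why replica_id is always 0 when no block is duplicated.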
- __init__(device, index, replica_id, data=None)
Methods
__init__(device, index, replica_id[, data])
Attributes
data
device
index
replica_id
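Because data is None for shards on non-local devices, and several shards can share one index when the array is replicated, code that consumes shards usually filters on both attributes. The sketch below shows that pattern under stated assumptions: ShardSketch and assemble are hypothetical names (not JAX APIs), the global array is 2-D, and nested Python lists stand in for device arrays. Only replica 0 of each block is written, so duplicated data is copied once.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class ShardSketch:
    """Hypothetical stand-in with the same attributes as Shard."""
    device: Any
    index: Tuple[slice, ...]
    replica_id: int
    data: Optional[List[List[float]]] = None  # None for non-local devices


def assemble(global_shape: Tuple[int, int],
             shards: List[ShardSketch]) -> List[List[Optional[float]]]:
    """Write each local, replica-0 shard into its place in a 2-D global
    array. Cells not covered by any local shard stay None."""
    rows, cols = global_shape
    out: List[List[Optional[float]]] = [[None] * cols for _ in range(rows)]
    for shard in shards:
        if shard.data is None or shard.replica_id != 0:
            continue  # skip non-local shards and duplicate replicas
        row_slice, col_slice = shard.index
        # slice.indices clamps the slice to the axis length
        row_ids = range(*row_slice.indices(rows))
        col_ids = range(*col_slice.indices(cols))
        for i, r in enumerate(row_ids):
            for j, c in enumerate(col_ids):
                out[r][c] = shard.data[i][j]
    return out


shards = [
    ShardSketch(0, (slice(0, 1), slice(0, 2)), replica_id=0, data=[[1.0, 2.0]]),
    ShardSketch(1, (slice(0, 1), slice(0, 2)), replica_id=1, data=[[1.0, 2.0]]),
    ShardSketch(2, (slice(1, 2), slice(0, 2)), replica_id=0, data=None),  # non-local
]
print(assemble((2, 2), shards))  # [[1.0, 2.0], [None, None]]
```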