
jax.experimental.global_device_array module


Warning

This page was created from a pull request (#9655).


API

GlobalDeviceArray(global_shape, global_mesh, ...)

A logical array with data sharded across multiple devices and processes.

Shard(device, index, replica_id[, data])

A single data shard of a GlobalDeviceArray.
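The `Shard` fields are easier to picture with a small example. The following is a minimal sketch in plain NumPy of how a global array might be split over a 2-D device mesh into shards, each holding a `device`, an `index` (slices into the global array), a `replica_id`, and the local `data`. It does not call the JAX API. `SketchShard`, `make_shards`, the use of integer ids in place of real devices, and the `mesh_axes` convention (one mesh axis name, or `None` for "not partitioned", per array dimension) are all hypothetical stand-ins loosely modeled on the fields listed above. The sketch also assumes each partitioned dimension divides evenly by the size of its mesh axis.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple

import numpy as np


@dataclass
class SketchShard:
    # Hypothetical stand-in for a shard; not the JAX `Shard` class.
    device: int                  # integer id standing in for a real device
    index: Tuple[slice, ...]     # region of the global array this shard covers
    replica_id: int              # 0 for the first copy of a region, 1 for the next, ...
    data: np.ndarray             # the local piece of the global array


def make_shards(global_array: np.ndarray,
                mesh: np.ndarray,
                axis_names: Sequence[str],
                mesh_axes: Sequence[Optional[str]]):
    """Split `global_array` over the device grid `mesh` (hypothetical helper).

    `mesh_axes[d]` names the mesh axis that array dimension `d` is split
    over, or is None if that dimension is kept whole. Mesh axes that no
    array dimension uses lead to replication.
    """
    assert len(mesh_axes) == global_array.ndim
    shards = []
    copies_seen = {}  # region -> number of shards already holding it
    for pos in np.ndindex(mesh.shape):
        coords = dict(zip(axis_names, pos))  # this device's mesh coordinates
        index = []
        for dim, name in enumerate(mesh_axes):
            size = global_array.shape[dim]
            if name is None:
                index.append(slice(0, size))  # whole dimension
            else:
                n = mesh.shape[list(axis_names).index(name)]
                assert size % n == 0, "sketch assumes even division"
                chunk = size // n
                start = coords[name] * chunk
                index.append(slice(start, start + chunk))
        index = tuple(index)
        key = tuple((s.start, s.stop) for s in index)
        replica_id = copies_seen.get(key, 0)
        copies_seen[key] = replica_id + 1
        shards.append(SketchShard(int(mesh[pos]), index, replica_id,
                                  global_array[index]))
    return shards


# Usage: a 4x4 array on a 2x2 grid of four hypothetical devices,
# rows split over mesh axis "x", columns kept whole.
x = np.arange(16).reshape(4, 4)
mesh = np.arange(4).reshape(2, 2)
shards = make_shards(x, mesh, ("x", "y"), ("x", None))
for s in shards:
    print(s.device, s.index, s.replica_id)
```

With `mesh_axes=("x", None)`, the rows are split across mesh axis `x`, and each row block is copied across mesh axis `y`. Devices 0 and 1 therefore hold rows 0:2 with `replica_id` 0 and 1, and devices 2 and 3 hold rows 2:4. Keeping only the `replica_id == 0` shards is enough to reassemble the global array once.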