jax.device_put_sharded
Warning
This page was created from a pull request (#9655).
- jax.device_put_sharded(shards, devices)
Transfer array shards to specified devices and form ShardedDeviceArray(s).
- Parameters
  - shards (Sequence[Any]) – A sequence of arrays, scalars, or (nested) standard Python containers thereof, representing the shards to be stacked together to form the output. The length of shards must equal the length of devices.
  - devices (Sequence[Device]) – A sequence of Device instances representing the devices to which the corresponding shards in shards will be transferred.
- Returns
  A ShardedDeviceArray or (nested) Python container thereof representing the elements of shards stacked together, with each shard backed by physical device memory specified by the corresponding entry in devices.
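To make the shape of the return value concrete, the following sketch stacks one shard per local device and checks three things: the result gains a leading axis of size len(devices), entry i along that axis equals shards[i], and the data sits on the requested devices. It assumes a recent JAX release in which the result is a jax.Array exposing addressable_shards, rather than the ShardedDeviceArray named above; the shard values are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np

devices = jax.local_devices()

# One illustrative (2, 3) shard per device; shard i is filled with the value i.
shards = [jnp.full((2, 3), float(i)) for i in range(len(devices))]
y = jax.device_put_sharded(shards, devices)

# The shards are stacked along a new leading axis: y.shape == (len(devices), 2, 3).
print(y.shape)

# Entry i along the leading axis holds shard i.
for i, shard in enumerate(shards):
    assert np.allclose(y[i], shard)

# Assumes a jax.Array result: every addressable shard lives on one of the requested devices.
placed_on = {s.device for s in y.addressable_shards}
print(placed_on == set(devices))  # True
```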
Examples
Passing a list of arrays for shards results in a sharded array containing a stacked version of the inputs:

>>> import jax
>>> import numpy as np
>>> devices = jax.local_devices()
>>> x = [jax.numpy.ones(5) for device in devices]
>>> y = jax.device_put_sharded(x, devices)
>>> np.allclose(y, jax.numpy.stack(x))
True
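A common reason to build a list of shards is to split one batch across devices. The sketch below uses a hypothetical helper, shard_batch (not part of JAX), that reshapes a batch whose leading dimension is divisible by the device count into one slice per device, so that the length of shards equals the length of devices. The size of 8 rows per device is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

def shard_batch(batch, devices):
    """Hypothetical helper: split `batch` along axis 0 into one shard per device."""
    n = len(devices)
    if batch.shape[0] % n != 0:
        raise ValueError(f"batch size {batch.shape[0]} is not divisible by {n} devices")
    # Reshape to (n, batch_size // n, ...) and take one slice per device,
    # so that len(shards) == len(devices).
    per_device = batch.reshape((n, batch.shape[0] // n) + batch.shape[1:])
    shards = [per_device[i] for i in range(n)]
    return jax.device_put_sharded(shards, devices)

devices = jax.local_devices()
batch = jnp.arange(8 * len(devices) * 3, dtype=jnp.float32).reshape(8 * len(devices), 3)
sharded = shard_batch(batch, devices)

# The leading axis indexes devices: sharded.shape == (len(devices), 8, 3).
print(sharded.shape)
# Merging the device axis back into the batch axis recovers the original batch.
print(np.allclose(np.asarray(sharded).reshape(batch.shape), batch))  # True
```

Rejecting batches that do not divide evenly keeps every shard the same shape, which stacking requires.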
Passing a list of nested container objects with arrays at the leaves for shards corresponds to stacking the shards at each leaf. This requires all entries in the list to have the same tree structure:

>>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))]
>>> y = jax.device_put_sharded(x, devices)
>>> type(y)
<class 'tuple'>
>>> y0 = jax.device_put_sharded([a for a, b in x], devices)
>>> y1 = jax.device_put_sharded([b for a, b in x], devices)
>>> np.allclose(y[0], y0)
True
>>> np.allclose(y[1], y1)
True
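Per-leaf stacking can also be checked against an explicit reference. In this sketch each shard is a dict with the same keys and leaf shapes (the key names "w" and "b" are illustrative), and jax.tree_util.tree_map stacks the corresponding leaves of all shards on the host with NumPy. The device result should keep the dict structure and match that reference leaf by leaf.

```python
import jax
import jax.numpy as jnp
import numpy as np

devices = jax.local_devices()

# Every shard has the same tree structure: a dict with the same keys and leaf shapes.
shards = [{"w": jnp.ones(3) * i, "b": jnp.float32(i)} for i in range(len(devices))]
y = jax.device_put_sharded(shards, devices)

# Reference: stack corresponding leaves of all shards along a new leading axis.
expected = jax.tree_util.tree_map(lambda *leaves: np.stack(leaves), *shards)

# The container type and keys are preserved.
print(type(y).__name__, sorted(y))  # dict ['b', 'w']
# Each leaf gains a leading device axis: y["w"] is (len(devices), 3), y["b"] is (len(devices),).
for key in expected:
    assert np.allclose(y[key], expected[key])
```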
See also
device_put
device_put_replicated
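For comparison with device_put_replicated: that function places a single value on every device and stacks the copies along a leading axis of size len(devices). The sketch below checks that replicating x gives the same result as passing len(devices) identical copies of x to device_put_sharded; the value of x is arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np

devices = jax.local_devices()
x = jnp.arange(4.0)

# device_put_replicated transfers the same value to every device ...
replicated = jax.device_put_replicated(x, devices)
# ... which should match device_put_sharded with one identical shard per device.
sharded = jax.device_put_sharded([x] * len(devices), devices)

print(replicated.shape == sharded.shape == (len(devices), 4))  # True
print(np.allclose(replicated, sharded))  # True
```

When every device needs the same data, device_put_replicated avoids building the list of identical shards by hand.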