jax.lax.ppermute
Warning
This page was created from a pull request (#9655).
- jax.lax.ppermute(x, axis_name, perm)
Perform a collective permutation according to the permutation perm.

If x is a pytree, the result is equivalent to mapping this function over each leaf in the tree. This function is an analog of the CollectivePermute XLA HLO.
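The following is a minimal sketch of a ring shift under `jax.pmap`, passing a pytree (a dict) so that each leaf is permuted independently. It assumes a CPU-only host on which four devices are simulated with the XLA flag `--xla_force_host_platform_device_count=4` (which must be set before `jax` is imported); the names `ring`, `shift_right`, and `inputs` are illustrative only.

```python
import os

# Assumption: simulate 4 devices on a CPU-only host; this must happen before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()

# Hypothetical ring permutation: the shard at index i is sent to index (i + 1) % n.
ring = [(i, (i + 1) % n) for i in range(n)]

def shift_right(tree):
    # ppermute applies the same permutation to every leaf of the pytree.
    return jax.lax.ppermute(tree, axis_name="i", perm=ring)

inputs = {
    "a": jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3),
    "b": jnp.arange(n, dtype=jnp.int32) * 10,
}
outputs = jax.pmap(shift_right, axis_name="i")(inputs)

# For each leaf, position j now holds what position j - 1 held (wrapping around),
# i.e. the same result as np.roll(leaf, shift=1, axis=0).
```

Because every index appears exactly once as a destination, no position is zero-filled; the result matches a cyclic roll along the mapped axis.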
- Parameters
x – array(s) with a mapped axis named axis_name.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
perm – list of pairs of ints, representing (source_index, destination_index) pairs that encode how the mapped axis named axis_name should be shuffled. The integer values are treated as indices into the mapped axis axis_name. No two pairs may share the same source index or the same destination index. For each index of the axis axis_name that does not appear as a destination index in perm, the corresponding values in the result are filled with zeros of the appropriate type.
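To illustrate the zero-filling rule, here is a sketch of a non-wrapping shift in which index 0 is never a destination. As in a plain `jax.pmap` setup, it assumes at least two devices, simulated here on a CPU-only host via `--xla_force_host_platform_device_count=4`; the names `shift` and `shift_no_wrap` are hypothetical.

```python
import os

# Assumption: simulate 4 devices on a CPU-only host; this must happen before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()

# Index i sends to index i + 1; the last index sends nothing,
# so index 0 never appears as a destination.
shift = [(i, i + 1) for i in range(n - 1)]

def shift_no_wrap(x):
    return jax.lax.ppermute(x, axis_name="i", perm=shift)

xs = jnp.arange(1, n + 1, dtype=jnp.float32)  # [1., 2., ..., n]
out = jax.pmap(shift_no_wrap, axis_name="i")(xs)

# Position 0 receives nothing and is filled with 0.0;
# every other position j holds the value from position j - 1.
expected = np.concatenate([[0.0], np.asarray(xs)[:-1]])
```

The source and destination indices in `shift` are each unique, as the rule above requires; only the missing destination (index 0) produces zeros.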
- Returns
Array(s) with the same shape as x, with slices along the axis axis_name gathered from x according to the permutation perm.