jax.random.shuffle
jax.random.shuffle(key, x, axis=0)
Shuffle the elements of an array uniformly at random along an axis. Each one-dimensional slice of the array taken along that axis is permuted independently.
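To illustrate what "uniformly at random along an axis" means, here is a minimal sketch of one common technique: draw an independent random 32-bit key for every element, then sort each slice along `axis` by those keys. The helper name `shuffle_along_axis` is hypothetical and this is not presented as JAX's own implementation. It also ignores the tiny bias introduced when two random keys tie.

```python
import jax
import jax.numpy as jnp


def shuffle_along_axis(key, x, axis=0):
    """Hypothetical helper: shuffle every 1-D slice of `x` along `axis`.

    Sketch only: ties between random sort keys (rare with 32-bit keys)
    keep their original relative order, which adds a tiny bias.
    """
    x = jnp.asarray(x)
    # One random 32-bit sort key per element of x.
    sort_keys = jax.random.bits(key, x.shape, dtype=jnp.uint32)
    # Sorting each slice by its own random keys gives a random ordering of that slice.
    order = jnp.argsort(sort_keys, axis=axis)
    # Gather the elements of x in that random order.
    return jnp.take_along_axis(x, order, axis=axis)


key = jax.random.PRNGKey(42)
x = jnp.arange(10)
y = shuffle_along_axis(key, x)
print(y)
```

Because each slice gets its own keys, a 2-D input shuffled along `axis=0` has every column reordered independently rather than whole rows moved together.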
- Parameters
key (Union[Any, PRNGKeyArray]) – a PRNG key used as the random key.
x (Any) – the array to be shuffled.
axis (int) – optional, an int axis along which to shuffle (default 0).
- Return type
ndarray
- Returns
A shuffled version of x, with the same shape and dtype as x.
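A short usage sketch follows. Newer JAX releases deprecate `jax.random.shuffle` in favor of `jax.random.permutation(key, x, axis, independent=True)`, which likewise shuffles each slice along `axis` independently, so the example calls that function. On an older release that still provides `shuffle`, `jax.random.shuffle(key, x, axis=...)` is expected to behave the same way.

```python
import jax
import jax.numpy as jnp

# Split the root key so the two shuffles use independent randomness.
key = jax.random.PRNGKey(0)
key0, key1 = jax.random.split(key)
x = jnp.arange(12).reshape(3, 4)

# axis=0: each column of x is permuted independently of the others.
rows_shuffled = jax.random.permutation(key0, x, axis=0, independent=True)
# axis=1: each row of x is permuted independently of the others.
cols_shuffled = jax.random.permutation(key1, x, axis=1, independent=True)

# Shuffling only reorders elements, so shape and dtype are unchanged.
assert rows_shuffled.shape == x.shape and rows_shuffled.dtype == x.dtype
# Sorting along the shuffled axis recovers x, whose rows and columns are already ascending.
assert (jnp.sort(rows_shuffled, axis=0) == x).all()
assert (jnp.sort(cols_shuffled, axis=1) == x).all()
```

By contrast, `independent=False` (the default for `permutation`) moves whole slices together, for example reordering entire rows when `axis=0`.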