jax.random.permutation

Warning

This page was created from a pull request (#9655).


jax.random.permutation(key, x, axis=0, independent=False)

Returns a randomly permuted array or range.

Parameters
  • key (Union[Any, PRNGKeyArray]) – a PRNG key used as the random key.

  • x (Union[int, Any]) – int or array. If x is an integer, randomly shuffle np.arange(x). If x is an array, randomly permute it along axis; with independent=False, whole slices along axis are reordered as units.

  • axis (int) – int, optional. The axis along which x is shuffled. Default is 0.

  • independent (bool) – bool, optional. If set to True, each individual vector along the given axis is shuffled independently. Default is False.
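A minimal sketch of how axis and independent interact on a 2-D array. The seed and array values are arbitrary choices for illustration; the behavior shown follows the parameter descriptions above.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)           # arbitrary seed for illustration
x = jnp.arange(12).reshape(3, 4)      # rows: [0..3], [4..7], [8..11]

# independent=False (default), axis=0: whole rows are reordered,
# so every row keeps its original contents.
rows_shuffled = jax.random.permutation(key, x, axis=0)

# independent=False, axis=1: whole columns are reordered together,
# i.e. the same column permutation is applied to every row.
cols_shuffled = jax.random.permutation(key, x, axis=1)

# independent=True, axis=1: each row (a vector along axis 1)
# gets its own, independent permutation.
each_row_shuffled = jax.random.permutation(key, x, axis=1, independent=True)

print(rows_shuffled)
print(cols_shuffled)
print(each_row_shuffled)
```

Reusing one key for all three calls only keeps the sketch short; in real code, derive fresh keys with jax.random.split so the draws are statistically independent.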

Return type

ndarray

Returns

A shuffled copy of x, or a random permutation of np.arange(x) if x is an integer.
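The two return cases can be checked with a short sketch; the seed and input values here are illustrative assumptions. Because JAX's PRNG is deterministic given the key, calling the function twice with the same key yields the same result.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)          # arbitrary seed for illustration

# Integer input: a random ordering of jnp.arange(5).
perm = jax.random.permutation(key, 5)

# Array input: a shuffled copy with the same shape and dtype as x.
x = jnp.array([10.0, 20.0, 30.0, 40.0])
shuffled = jax.random.permutation(key, x)

# Same key, same input: identical output.
again = jax.random.permutation(key, x)

print(perm, shuffled)
```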