jax.random.categorical
Warning
This page was created from a pull request (#9655).
- jax.random.categorical(key, logits, axis=-1, shape=None)
Sample random values from categorical distributions.
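A minimal usage sketch, assuming a JAX version where `jax.random.PRNGKey` is available. It draws many samples from a single three-way distribution and compares the empirical frequencies with softmax(logits). The probabilities `[0.2, 0.3, 0.5]` and the sample count are illustrative choices.

```python
import jax
import jax.numpy as jnp

# Unnormalized log-probabilities for one categorical distribution over 3 categories.
# Taking the log of probabilities gives logits whose softmax recovers them.
logits = jnp.log(jnp.array([0.2, 0.3, 0.5]))

key = jax.random.PRNGKey(0)
# Draw 10,000 independent category indices in {0, 1, 2}.
samples = jax.random.categorical(key, logits, shape=(10_000,))

# Empirical frequency of each category versus the target probabilities.
freqs = jnp.bincount(samples, length=3) / samples.size
probs = jax.nn.softmax(logits)
print(freqs)  # close to probs, up to sampling noise
print(probs)
```

Because `logits` has shape `(3,)`, the default result shape is `()` (a single draw); passing `shape=(10_000,)` requests many independent draws from the same distribution.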
- Parameters
  - key (Union[Any, PRNGKeyArray]) – a PRNG key used as the random key.
  - logits (Any) – Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.
  - axis (int) – Axis along which logits belong to the same categorical distribution.
  - shape (Optional[Sequence[int]]) – Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with np.delete(logits.shape, axis). The default (None) produces a result shape equal to np.delete(logits.shape, axis).
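The roles of `axis` and `shape` can be sketched with two distributions that put almost all of their mass on a single category, which makes the draws effectively deterministic. This is an illustration under stated assumptions: the `-1e9` logits are a stand-in for "zero probability", and the key seeds are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)

# Two distributions over 4 categories, one per column, so categories run along axis 0.
# Column 0 puts (almost) all mass on category 3, column 1 on category 1.
logits_cols = jnp.full((4, 2), -1e9).at[3, 0].set(0.0).at[1, 1].set(0.0)

# axis=0: the default result shape is np.delete((4, 2), 0) == (2,), one draw per column.
per_column = jax.random.categorical(k1, logits_cols, axis=0)
print(per_column)  # [3 1]

# The same two distributions stored one per row (categories along the default axis=-1).
logits_rows = logits_cols.T  # shape (2, 4)

# shape=(5, 2) is broadcast-compatible with the batch shape (2,):
# it requests 5 independent draws from each of the two distributions.
draws = jax.random.categorical(k2, logits_rows, shape=(5, 2))
print(draws.shape)  # (5, 2); every row is [3 1]
```

The extra leading dimension in `shape` adds independent samples, while the trailing dimensions must match the batch shape left after removing `axis` from `logits.shape`.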
- Return type
- Returns
A random array with int dtype and shape given by shape if shape is not None, or else np.delete(logits.shape, axis).
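The return-shape rule can be checked directly. The sketch below assumes `documented_result_shape` is a hypothetical helper written only for this illustration (it is not part of JAX); it encodes the rule stated above and compares it with the shapes `jax.random.categorical` actually returns.

```python
import jax
import jax.numpy as jnp
import numpy as np

def documented_result_shape(logits_shape, axis=-1, shape=None):
    """Hypothetical helper: the result shape described in the Returns section."""
    if shape is not None:
        return tuple(shape)
    return tuple(int(d) for d in np.delete(logits_shape, axis))

key = jax.random.PRNGKey(0)
logits = jnp.zeros((3, 5, 7))

# (axis, shape) combinations to check; shape=None means "use the default".
cases = [(0, None), (1, None), (-1, None), (-1, (4, 3, 5))]

for axis, shape in cases:
    out = jax.random.categorical(key, logits, axis=axis, shape=shape)
    # The result shape follows the documented rule...
    assert out.shape == documented_result_shape(logits.shape, axis, shape)
    # ...and the samples are integer category indices.
    assert jnp.issubdtype(out.dtype, jnp.integer)
    print(axis, shape, out.shape, out.dtype)
```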