jax.random.split
jax.random.split(key, num=2)
Splits a PRNG key into num new keys, stacked along a new leading axis of length num.
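A minimal sketch of the leading axis in practice. It assumes the default threefry2x32 implementation, in which a raw key from PRNGKey is a uint32 array of shape (2,). Under that assumption, splitting into num keys yields shape (num, 2). Other PRNG implementations or typed keys give different trailing shapes.

```python
import jax
import jax.numpy as jnp

# Raw legacy key; with the default threefry2x32 implementation (an
# assumption here) this is a uint32 array of shape (2,).
key = jax.random.PRNGKey(0)

# Split into 3 keys: a new leading axis of length 3 is added.
keys = jax.random.split(key, num=3)
print(keys.shape)  # (3, 2)

# With the default num=2, the leading axis has length 2.
pair = jax.random.split(key)
print(pair.shape)  # (2, 2)

# Splitting is deterministic: the same key and num give the same keys.
print(bool(jnp.array_equal(jax.random.split(key, num=3), keys)))  # True
```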
- Parameters
key (Union[Any, PRNGKeyArray]) – a PRNG key, e.g. one produced by PRNGKey, split, or fold_in.
num (int) – optional; a positive integer giving the number of keys to produce (default 2).
- Return type
Union[Any, PRNGKeyArray]
- Returns
An array of num new PRNG keys, indexed along a new leading axis of length num.
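A sketch of a common usage pattern: unpack the default two-way split, keep one key for further splitting, and consume the other exactly once for sampling. The helper `sample_several` is hypothetical and only illustrates the pattern; it again assumes raw keys from PRNGKey with the default implementation.

```python
import jax
import jax.numpy as jnp

def sample_several(key, n_draws):
    """Hypothetical helper: draw n_draws standard-normal scalars.

    Each iteration splits the current key; `subkey` is consumed by the
    sampler and `key` is carried forward, so no key is reused.
    """
    draws = []
    for _ in range(n_draws):
        # Unpacking works because split adds a leading axis of length 2.
        key, subkey = jax.random.split(key)
        draws.append(jax.random.normal(subkey))
    return jnp.stack(draws), key

samples, final_key = sample_several(jax.random.PRNGKey(42), 4)
print(samples.shape)    # (4,)
print(final_key.shape)  # (2,)
```

When all draws are needed at once, an alternative is to split into n_draws keys in a single call and map the sampler over the leading axis, e.g. with jax.vmap.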