jax.random.split
jax.random.split(key, num=2)
Splits a PRNG key into num new keys by adding a leading axis.
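A minimal sketch of the leading-axis behavior. It assumes the raw-key API (`jax.random.PRNGKey`). With the default threefry implementation a single raw key has shape (2,), so splitting into three keys gives shape (3, 2). With typed keys from `jax.random.key`, a single key has shape (), and the result has shape (3,).

```python
import jax

# Assumes the raw-key API; the key's trailing shape depends on the
# configured PRNG implementation (threefry by default).
key = jax.random.PRNGKey(0)

# Produce three new keys, stacked along a new leading axis.
keys = jax.random.split(key, num=3)

# The result has shape (num, *key.shape).
assert keys.shape == (3,) + key.shape

# split is a pure function: the same key always yields the same new keys.
assert (jax.random.split(key, num=3) == keys).all()

# Each entry along the leading axis is a key that can be passed to a sampler.
first_sample = jax.random.uniform(keys[0])
```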
- Parameters
key (Union[Any, PRNGKeyArray]) – a PRNG key (from PRNGKey, split, or fold_in).
num (int) – optional, a positive integer indicating the number of keys to produce (default 2).
- Return type
Union[Any, PRNGKeyArray]
- Returns
An array of num new PRNG keys stacked along a new leading axis, i.e. with shape (num, *key.shape).
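The returned keys are normally consumed immediately by a sampling function. The sketch below shows two common usage patterns; the seed 42 and the sample shapes are arbitrary choices for illustration, and the raw-key API (`jax.random.PRNGKey`) is assumed. The first pattern splits one fresh subkey per draw and carries the other key forward, so that no key is reused. The second splits once into a batch of keys and maps a sampler over the leading axis with `jax.vmap`.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

# Sequential pattern: split off a fresh subkey for each draw and carry
# the other key forward, so no key is ever consumed twice.
samples = []
for _ in range(3):
    key, subkey = jax.random.split(key)  # num defaults to 2
    samples.append(jax.random.normal(subkey, shape=(2,)))
samples = jnp.stack(samples)  # shape (3, 2)

# Batched pattern: split once into many keys, then map a sampler over
# the leading axis of the key array with vmap.
batch_keys = jax.random.split(key, num=4)
batch = jax.vmap(lambda k: jax.random.normal(k, shape=(2,)))(batch_keys)  # shape (4, 2)
```

Reusing a key for two different draws produces identical (correlated) randomness, which is why both patterns give every sampling call its own key.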