jax.random.rademacher
jax.random.rademacher(key, shape, dtype=<class 'numpy.int64'>)
Sample from a Rademacher distribution.
- Parameters
key (Union[Any, PRNGKeyArray]) – A PRNG key.
shape (Sequence[int]) – The shape of the returned samples.
dtype (Any) – The dtype used for the samples.
- Return type
ndarray
- Returns
A jnp.array of samples with shape shape. Each element of the output
has a 50% chance of being 1 and a 50% chance of being -1.
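The following is a minimal usage sketch, not an official example. It draws Rademacher samples, checks that every entry is 1 or -1, requests a floating-point dtype, and confirms that the mean of many draws is close to 0, which is what you would expect when both values are equally likely. The seed (0), the shapes, and the use of jax.random.PRNGKey to create the key are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

# Arbitrary seed chosen for illustration.
key = jax.random.PRNGKey(0)

# Draw a 3x4 array of Rademacher samples; every entry is 1 or -1.
samples = jax.random.rademacher(key, (3, 4))
print(samples.shape)  # (3, 4)

# Request a floating-point dtype instead of the integer default.
float_samples = jax.random.rademacher(key, (5,), dtype=jnp.float32)

# With many draws the sample mean should be near 0, because
# 1 and -1 each occur with probability 0.5.
many = jax.random.rademacher(key, (100_000,))
print(jnp.mean(many))
```

Because the output is always symmetric around 0, Rademacher samples are often used as random sign vectors, for example when multiplying an array elementwise to flip signs at random.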