jax.random.multivariate_normal
Warning
This page was created from a pull request (#9655).
- jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=<class 'numpy.float64'>, method='cholesky')
Sample multivariate normal random values with given mean and covariance.
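The idea behind sampling from a multivariate normal can be sketched as an affine transformation: if z ~ N(0, I) and L is a factor with L @ L.T == cov, then mean + L @ z has mean `mean` and covariance `cov`. The sketch below assumes the Cholesky factor (matching the default `method='cholesky'`). It is an illustration of the standard technique, not a copy of JAX's internal code. The helper `mvn_via_cholesky` and the example `mean`/`cov` values are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def mvn_via_cholesky(key, mean, cov, num_samples):
    """Hypothetical helper (not part of JAX): draw num_samples from N(mean, cov).

    If z ~ N(0, I) and L is lower-triangular with L @ L.T == cov,
    then mean + L @ z has mean `mean` and covariance `cov`.
    """
    L = jnp.linalg.cholesky(cov)  # (n, n) lower-triangular factor of cov
    z = jax.random.normal(key, (num_samples, mean.shape[-1]), dtype=mean.dtype)
    # Each row z_i becomes mean + L @ z_i; in row form that is z @ L.T.
    return mean + z @ L.T

key = jax.random.PRNGKey(0)
mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.6],
                 [0.6, 1.0]])

manual = mvn_via_cholesky(key, mean, cov, 100_000)
builtin = jax.random.multivariate_normal(key, mean, cov, shape=(100_000,))

# Both sample sets should have empirical moments close to (mean, cov).
for samples in (manual, builtin):
    print(samples.mean(axis=0), jnp.cov(samples, rowvar=False))
```

With 100,000 draws the printed means and covariances should agree with `mean` and `cov` to roughly two decimal places; the exact digits depend on the key.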
- Parameters
  - key (Union[Any, PRNGKeyArray]) – a PRNG key used as the random key.
  - mean (Any) – a mean vector of shape (..., n).
  - cov (Any) – a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.
  - shape (Optional[Sequence[int]]) – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.
  - dtype (Any) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
  - method (str) – optional, the method used to compute the factor of cov. Must be one of ‘svd’, ‘eigh’, or ‘cholesky’. Default ‘cholesky’.
- Returns
  A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
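The shape rule for the result can be illustrated with a batch of means and a single shared covariance. This is a minimal sketch under illustrative assumptions (three zero means in two dimensions, an identity covariance); it only checks result shapes, which follow from the rule stated under Returns. With shape=None the batch shape is broadcast_shapes((3,), ()) == (3,); an explicit shape=(5, 3) is broadcast-compatible with (3,), so the result gains a leading axis of size 5.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

# Batch of 3 mean vectors in R^2: batch shape (3,), event size n = 2.
mean = jnp.zeros((3, 2))
# A single 2x2 covariance (batch shape ()), broadcast against mean's batch shape.
cov = jnp.eye(2)

# shape=None: batch shape is broadcast_shapes((3,), ()) == (3,), result (3, 2).
a = jax.random.multivariate_normal(key, mean, cov)

# Explicit shape (5, 3) is broadcast-compatible with (3,), result (5, 3, 2).
b = jax.random.multivariate_normal(key, mean, cov, shape=(5, 3))

# Choosing a different factorization method does not change the result shape.
c = jax.random.multivariate_normal(key, mean, cov, shape=(5, 3), method="eigh")

print(a.shape, b.shape, c.shape)
```

The last axis of every result always has size n = mean.shape[-1]; shape only controls the batch prefix in front of it.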