jax.random.multivariate_normal
Warning
This page was created from a pull request (#9655).
jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=<class 'numpy.float64'>, method='cholesky')
Sample multivariate normal random values with given mean and covariance.
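A minimal usage sketch follows. It assumes a JAX installation where jax.random.PRNGKey is available (newer releases also provide jax.random.key), and the mean and covariance values are arbitrary illustrations, not taken from this page. It draws many samples from a correlated 2-D Gaussian and compares the empirical moments with mean and cov.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Illustrative 2-D Gaussian with correlated components.
mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.8],
                 [0.8, 1.0]])

# shape=(10000,) is the batch shape; the last axis (size 2) comes from mean.
samples = jax.random.multivariate_normal(key, mean, cov, shape=(10000,))
print(samples.shape)  # (10000, 2)

# With many samples, the empirical moments should be close to mean and cov.
print(samples.mean(axis=0))
print(jnp.cov(samples, rowvar=False))
```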
Parameters:
- key (Union[Any, PRNGKeyArray]) – a PRNG key used as the random key.
- mean (Any) – a mean vector of shape (..., n).
- cov (Any) – a positive definite covariance matrix of shape (..., n, n). The batch shape (the leading ... axes) must be broadcast-compatible with that of mean.
- shape (Optional[Sequence[int]]) – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.
- dtype (Any) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- method (str) – optional, the method used to compute the factor of cov. Must be one of 'svd', 'eigh', or 'cholesky'. Default 'cholesky'.
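The method parameter chooses how a factor A of cov (a matrix with A @ A.T equal to cov) is computed. The sketch below assumes the standard factor-based sampling technique: draw z from a standard normal and return mean + A z. factor_of is a hypothetical helper written only for illustration; it is not part of jax.random, and JAX's internal implementation may differ in detail. The values of mean and cov are arbitrary.

```python
import jax
import jax.numpy as jnp

mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.8],
                 [0.8, 1.0]])

def factor_of(cov, method):
    """Hypothetical helper (not part of jax.random): return A with A @ A.T ~= cov.

    Illustrates the idea behind each `method` option; JAX's own code may differ.
    """
    if method == "cholesky":
        # Lower-triangular factor; needs cov to be positive definite.
        return jnp.linalg.cholesky(cov)
    if method == "eigh":
        # cov = V diag(w) V^T, so A = V diag(sqrt(w)): scale each column of V.
        w, v = jnp.linalg.eigh(cov)
        return v * jnp.sqrt(w)
    if method == "svd":
        # For a symmetric positive definite cov, cov = U diag(s) U^T.
        u, s, _ = jnp.linalg.svd(cov)
        return u * jnp.sqrt(s)
    raise ValueError(f"unknown method: {method!r}")

key = jax.random.PRNGKey(0)
k_api, k_manual = jax.random.split(key)

for method in ("cholesky", "eigh", "svd"):
    a = factor_of(cov, method)
    # Each factor reproduces cov up to float32 rounding.
    print(method, bool(jnp.allclose(a @ a.T, cov, atol=1e-5)))
    # The library call with the same method name.
    s = jax.random.multivariate_normal(k_api, mean, cov, shape=(5000,), method=method)
    print(method, s.shape)

# Manual sampling with a factor: each row is mean + A z, where z ~ N(0, I).
z = jax.random.normal(k_manual, (5000, 2))
x = mean + z @ factor_of(cov, "eigh").T
```

All three factors are valid, so the methods should agree in distribution; they differ in cost and numerical behavior, and Cholesky in particular requires cov to be positive definite.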
Returns:
A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
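The shape rule above can be checked directly. This sketch uses arbitrary illustrative means and covariances and shows three cases: the default batch shape (shape=None), an explicit shape that broadcasts with mean.shape[:-1], and a batch of covariance matrices sharing one mean.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
k1, k2, k3 = jax.random.split(key, 3)

# A batch of 3 means in R^2 with one shared covariance.
means = jnp.array([[0.0, 0.0], [5.0, 5.0], [-5.0, 0.0]])  # shape (3, 2)
cov = jnp.eye(2)                                           # shape (2, 2)

# shape=None: batch shape is broadcast_shapes((3,), ()) == (3,).
a = jax.random.multivariate_normal(k1, means, cov)
print(a.shape)  # (3, 2)

# Explicit shape=(4, 3) broadcasts with means.shape[:-1] == (3,).
b = jax.random.multivariate_normal(k2, means, cov, shape=(4, 3))
print(b.shape)  # (4, 3, 2)

# A batch of 5 scaled identity covariances with one shared mean.
covs = jnp.stack([(i + 1.0) * jnp.eye(2) for i in range(5)])  # (5, 2, 2)
c = jax.random.multivariate_normal(k3, jnp.zeros(2), covs)
print(c.shape)  # (5, 2)

# The default rule computed by hand for the first case.
print(jnp.broadcast_shapes(means.shape[:-1], cov.shape[:-2]) + means.shape[-1:])  # (3, 2)
```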