Warning: This page was created from a pull request (#9655).
Normalizes an array by subtracting the mean and dividing by sqrt(var).
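Written out as a formula, this is a sketch that assumes `epsilon` is added to the variance inside the square root and that μ and σ² are computed along `axis` when `mean` and `variance` are not passed in. The notation A (the set of elements reduced along `axis`, restricted by `where` if given) and N = |A| is introduced here only for illustration.

```latex
\[
y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad
\mu = \frac{1}{N} \sum_{i \in A} x_i, \qquad
\sigma^2 = \frac{1}{N} \sum_{i \in A} x_i^2 - \mu^2 .
\]
% Worked example: x = (1, 2, 3), reducing over all three elements (N = 3).
\[
\mu = 2, \qquad
\sigma^2 = \tfrac{14}{3} - 4 = \tfrac{2}{3}, \qquad
y \approx \frac{(-1,\ 0,\ 1)}{\sqrt{2/3}} \approx (-1.2247,\ 0,\ 1.2247)
\quad (\epsilon \to 0).
\]
```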
Parameters:
x (Any)
axis (Union[int, Tuple[int, ...], None])
mean (Optional[Any])
variance (Optional[Any])
epsilon (Any)
where (Optional[Any])
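The parameters above can be illustrated with a minimal sketch of the computation, written directly with `jax.numpy` instead of calling the library function. The name `normalize_sketch` is hypothetical, and the defaults `axis=-1` and `epsilon=1e-5`, as well as computing the variance as E[x²] − μ², are assumptions made for this sketch. It shows how `mean` and `variance` are only computed when not supplied, and how `where` restricts which elements contribute to those statistics.

```python
import jax.numpy as jnp


def normalize_sketch(x, axis=-1, mean=None, variance=None, epsilon=1e-5, where=None):
    """Hypothetical helper: returns (x - mean) / sqrt(variance + epsilon).

    Statistics are computed along `axis` (restricted to elements where
    `where` is True) only when the caller does not supply them.
    """
    x = jnp.asarray(x)
    if mean is None:
        # keepdims=True so the reduced statistics broadcast back against x.
        mean = jnp.mean(x, axis=axis, keepdims=True, where=where)
    if variance is None:
        # Variance computed as E[x^2] - mean^2 over the same (masked) elements.
        variance = (
            jnp.mean(jnp.square(x), axis=axis, keepdims=True, where=where)
            - jnp.square(mean)
        )
    return (x - mean) / jnp.sqrt(variance + epsilon)


# Per-row normalization of a 2-D array (axis=-1 reduces over columns).
x = jnp.array([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]])
y = normalize_sketch(x, axis=-1)

# Caller-supplied statistics: mean=0, variance=1, epsilon=0 leaves x unchanged.
y_identity = normalize_sketch(x, mean=0.0, variance=1.0, epsilon=0.0)

# Masked statistics: the third element is excluded from the mean and variance,
# but it is still shifted and scaled using the statistics of the other two.
v = jnp.array([1.0, 3.0, 100.0])
y_masked = normalize_sketch(v, where=jnp.array([True, True, False]))

print(y)
print(y_masked)
```

After normalization each row of `y` has mean close to 0 and variance slightly below 1 (because of `epsilon`). In the masked case the outlier 100.0 does not inflate the statistics, so the first two elements map to roughly -1 and 1.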