jax.numpy.nonzero
Warning
This page was created from a pull request (#9655).
- jax.numpy.nonzero(a, *, size=None, fill_value=None)
Return the indices of the elements that are non-zero.
LAX-backend implementation of numpy.nonzero().
Because the size of the output of nonzero is data-dependent, the function is not typically compatible with JIT, which needs output shapes to be known at trace time. The JAX version adds the optional size argument, which must be specified statically for jnp.nonzero to be used within some of JAX's transformations.
Original docstring below.
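A minimal sketch of using size under jax.jit follows. The helper name first_nonzero_indices is hypothetical, and size is declared through static_argnames so that it is a concrete Python int when the function is traced. Calling jnp.nonzero without size inside a jitted function would instead fail during tracing, because the number of non-zero entries is only known at runtime.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: `size` is static, so the output shape is fixed at trace time.
@partial(jax.jit, static_argnames="size")
def first_nonzero_indices(x, size):
    # For 1-D input, nonzero returns a 1-tuple of index arrays.
    # fill_value=-1 makes padding distinguishable from a real index 0.
    (idx,) = jnp.nonzero(x, size=size, fill_value=-1)
    return idx


x = jnp.array([0, 3, 0, 5, 7])
print(first_nonzero_indices(x, size=2).tolist())  # [1, 3]  (truncated to 2 entries)
print(first_nonzero_indices(x, size=5).tolist())  # [1, 3, 4, -1, -1]  (padded)
```

With the default fill_value, the padded slots would contain 0, which cannot be told apart from a genuine index 0; choosing an out-of-range value such as -1 is one way to mark them.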
Returns a tuple of arrays, one for each dimension of a, containing the indices of the non-zero elements in that dimension. The values in a are always tested and returned in row-major, C-style order.
To group the indices by element, rather than dimension, use argwhere, which returns a row for each non-zero element.
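The difference between the per-dimension tuple returned by nonzero and the per-element rows returned by argwhere can be sketched on a small 2-D array (the array values here are illustrative only):

```python
import jax.numpy as jnp

a = jnp.array([[3, 0, 0],
               [0, 4, 0],
               [5, 6, 0]])

# One index array per dimension, visited in row-major (C) order.
rows, cols = jnp.nonzero(a)
print(rows.tolist())  # [0, 1, 2, 2]
print(cols.tolist())  # [0, 1, 0, 1]

# argwhere groups the same indices by element: one row per non-zero entry.
print(jnp.argwhere(a).tolist())  # [[0, 0], [1, 1], [2, 0], [2, 1]]

# The returned tuple can be used directly as an index to gather the non-zero values.
print(a[jnp.nonzero(a)].tolist())  # [3, 4, 5, 6]
```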
Note
When called on a zero-d array or scalar, nonzero(a) is treated as nonzero(atleast_1d(a)).
Deprecated since NumPy version 1.17.0: Use atleast_1d explicitly if this behavior is deliberate.
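Following the deprecation advice, a minimal sketch that calls atleast_1d explicitly on a zero-d value is shown below; it does not rely on implicit promotion, since whether jnp.nonzero accepts 0-d input directly may depend on the JAX version.

```python
import jax.numpy as jnp

s = jnp.float32(3.0)  # a zero-d array
(idx,) = jnp.nonzero(jnp.atleast_1d(s))  # promote to shape (1,) explicitly
print(idx.tolist())  # [0]

z = jnp.float32(0.0)
print(jnp.nonzero(jnp.atleast_1d(z))[0].tolist())  # []
```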
- Parameters
a (array_like) – Input array.
size (int, optional) – If specified, the indices of the first size non-zero elements will be returned. If there are fewer non-zero elements than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
- Returns
tuple_of_arrays – Indices of elements that are non-zero.
- Return type