jax.numpy.argwhere
Warning
This page was created from a pull request (#9655).
- jax.numpy.argwhere(a, *, size=None, fill_value=None)
Find the indices of array elements that are non-zero, grouped by element.
LAX-backend implementation of numpy.argwhere().
Because the size of the output of argwhere is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which must be specified statically for jnp.argwhere to be used within some of JAX's transformations.
Original docstring below.
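A minimal sketch of passing the static size argument under jax.jit, assuming a JAX release in which jnp.argwhere accepts size as in the signature above. The wrapper name first_nonzero_indices is hypothetical. Listing size in static_argnames keeps it a plain Python int at trace time, so the output shape (size, x.ndim) is known when the function is compiled.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical wrapper: `size` is marked static so jit can fix the output shape.
@partial(jax.jit, static_argnames="size")
def first_nonzero_indices(x, size):
    # size is keyword-only in jnp.argwhere, so it is passed by name.
    return jnp.argwhere(x, size=size)


x = jnp.array([[0, 1, 0],
               [2, 0, 3]])
idx = first_nonzero_indices(x, size=3)
# idx has shape (3, 2); its rows are the (row, col) indices (0, 1), (1, 0), (1, 2).
```

Because size is static, calling the wrapper with a different size triggers a new compilation rather than a data-dependent shape.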
- Parameters
a (array_like) – Input data.
size (int, optional) – If specified, the indices of the first size non-zero elements will be returned. If there are fewer results than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
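The padding and truncation rules for size and fill_value can be illustrated with a short sketch on a 1-D input, assuming a JAX version whose jnp.argwhere accepts both keywords. The sentinel value -1 is an arbitrary illustrative choice, not a library default.

```python
import jax.numpy as jnp

x = jnp.array([0, 5, 0, 7])  # two non-zero elements, at positions 1 and 3

# Fewer non-zeros (2) than size (4): the remaining rows are filled with fill_value.
padded = jnp.argwhere(x, size=4, fill_value=-1)

# More non-zeros (2) than size (1): only the first `size` indices are kept.
truncated = jnp.argwhere(x, size=1)

# fill_value omitted: padding rows default to zero.
default_pad = jnp.argwhere(x, size=3)
```

Note that the default padding value of zero is indistinguishable from a genuine index 0, so an out-of-range sentinel such as -1 can make padded rows easier to detect.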
- Returns
index_array – Indices of elements that are non-zero. Indices are grouped by element. This array will have shape (N, a.ndim), where N is the number of non-zero items (or size, when size is specified).
- Return type
(N, a.ndim) ndarray
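To make "grouped by element" concrete, the sketch below shows that each row of the result holds the full index of one non-zero element, in row-major order. It also compares the result against stacking the per-axis index arrays from jnp.nonzero, which is one equivalent way to build the same array; this comparison is an illustration, not a statement about how jnp.argwhere is implemented.

```python
import jax.numpy as jnp

a = jnp.array([[1, 0, 2],
               [0, 3, 0]])

# One row per non-zero element (N = 3), one column per dimension of `a` (a.ndim = 2).
idx = jnp.argwhere(a)

# jnp.nonzero groups indices by dimension instead; stacking along the last axis
# regroups them by element.
alt = jnp.stack(jnp.nonzero(a), axis=-1)
```

Here idx is [[0, 0], [0, 2], [1, 1]], the coordinates of the values 1, 2, and 3.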