jax.numpy.flatnonzero
Warning
This page was created from a pull request (#9655).
- jax.numpy.flatnonzero(a, *, size=None, fill_value=None)
Return indices that are non-zero in the flattened version of a.
LAX-backend implementation of flatnonzero().
Because the size of the output of nonzero is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which must be specified statically for jnp.nonzero to be used within some of JAX’s transformations.
Original docstring below.
This is equivalent to np.nonzero(np.ravel(a))[0].
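The equivalence stated above can be checked with a short sketch. The array a and the variable names are illustrative assumptions, not part of the original docstring; the example uses jnp.nonzero and jnp.ravel as the JAX counterparts of the NumPy calls named above.

```python
import jax.numpy as jnp

# Illustrative input (an assumption for this sketch): a 2x3 array with three non-zero entries.
a = jnp.array([[0, 3, 0],
               [5, 0, 7]])

# Indices of the non-zero entries in the flattened array [0, 3, 0, 5, 0, 7].
flat_idx = jnp.flatnonzero(a)

# The same result, spelled out as nonzero(ravel(a))[0].
ref_idx = jnp.nonzero(jnp.ravel(a))[0]

print(flat_idx)  # [1 3 5]
print(ref_idx)   # [1 3 5]
```

Outside of jit, no size argument is needed because the output shape can be computed from the concrete data.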
- Parameters
a (array_like) – Input data.
size (int, optional) – If specified, the indices of the first size non-zero elements will be returned. If there are fewer non-zero elements than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
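As a minimal sketch of how size and fill_value interact under jit: the function names, the input x, and the choice of -1 as a padding value are assumptions made for illustration. A negative fill_value is used here only so that padding cannot be confused with a real index, since the default of zero is itself a valid index.

```python
import jax
import jax.numpy as jnp

@jax.jit
def padded_nonzero(x):
    # size is a static Python int, so the output shape (4,) is known at trace time.
    # Unused slots are filled with -1 instead of the default 0.
    return jnp.flatnonzero(x, size=4, fill_value=-1)

@jax.jit
def truncated_nonzero(x):
    # With size smaller than the number of non-zero entries,
    # only the first size indices are returned.
    return jnp.flatnonzero(x, size=2)

# Illustrative input with non-zero entries at positions 1, 4 and 5.
x = jnp.array([0, 2, 0, 0, 9, 1])

padded = padded_nonzero(x)
truncated = truncated_nonzero(x)

print(padded)     # [ 1  4  5 -1]
print(truncated)  # [1 4]
```

Because the output length is fixed by size, the caller is responsible for distinguishing real indices from padding, for example by comparing against the chosen fill_value.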
- Returns
res – Output array, containing the indices of the elements of a.ravel() that are non-zero.
- Return type