jax.numpy.where
Warning
This page was created from a pull request (#9655).
- jax.numpy.where(condition, x=None, y=None, *, size=None, fill_value=None)
Return elements chosen from x or y depending on condition.
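A short sketch of the three-argument form, assuming a recent JAX release. It selects from `x` where `condition` is True and from `y` elsewhere. The second call shows broadcasting: a column-shaped condition, a scalar `x` and a row-shaped `y` combine into a 2-D result. The arrays are illustrative only.

```python
import jax.numpy as jnp

# Element-wise selection with inputs of the same shape.
condition = jnp.array([True, False, True, False])
x = jnp.arange(4)       # [0, 1, 2, 3]
y = jnp.full(4, -1)     # [-1, -1, -1, -1]
out = jnp.where(condition, x, y)
print(out.tolist())     # [0, -1, 2, -1]

# Broadcasting: condition (3, 1), x scalar, y (4,) -> result (3, 4).
# Rows whose condition is True are all 0.0; the other rows copy y.
cond2 = jnp.array([[True], [False], [True]])
out2 = jnp.where(cond2, 0.0, jnp.arange(4.0))
print(out2.shape)       # (3, 4)
```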
LAX-backend implementation of numpy.where().
At present, JAX does not support JIT-compilation of the single-argument form of jax.numpy.where() because its output shape is data-dependent. The three-argument form does not have a data-dependent shape and can be JIT-compiled successfully. Alternatively, the single-argument form can be JIT-compiled if you pass the optional size keyword to statically specify the expected size of the output.
Original docstring below.
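The sketch below contrasts the two forms under `jax.jit`, assuming a recent JAX release. The helper names `clip_negative` and `first_true_indices` are hypothetical. The three-argument form compiles as-is. The single-argument form needs a static `size`, with unused slots padded by `fill_value`. Without `size`, tracing fails; recent JAX versions raise a `ConcretizationTypeError`, but the code catches a generic `Exception` rather than relying on that exact type.

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip_negative(x):
    # Three-argument form: the output shape is the broadcast shape of the
    # inputs, so it is known at trace time and the function compiles.
    return jnp.where(x < 0, 0, x)

@jax.jit
def first_true_indices(mask):
    # Single-argument form: the number of True entries depends on the data,
    # so a static `size` is required under jit. Unused slots are padded with
    # `fill_value` (here -1) instead of the default 0.
    (idx,) = jnp.where(mask, size=3, fill_value=-1)
    return idx

x = jnp.array([-2, 5, -1, 7])
mask = jnp.array([False, True, False, True])

print(clip_negative(x).tolist())          # [0, 5, 0, 7]
print(first_true_indices(mask).tolist())  # [1, 3, -1]

# Without `size`, the output length cannot be determined from the abstract
# (traced) input, so jit-compiling the single-argument form fails.
try:
    jax.jit(lambda m: jnp.where(m))(mask)
    single_arg_jit_failed = False
except Exception as err:  # assumed: ConcretizationTypeError in recent JAX
    single_arg_jit_failed = True
    print(type(err).__name__)
```

The padded result matches what `jnp.nonzero(mask, size=3, fill_value=-1)` returns, in line with the shorthand described in the note that follows.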
Note
When only condition is provided, this function is a shorthand for np.asarray(condition).nonzero(). Using nonzero directly should be preferred, as it behaves correctly for subclasses. The rest of this documentation covers only the case where all three arguments are provided.
- Parameters
condition (array_like, bool) – Where True, yield x, otherwise yield y.
x (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.
y (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.
size (int, optional) – Only referenced when x and y are None. If specified, the indices of the first size elements of the result will be returned. If there are fewer elements than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
- Returns
out – An array with elements from x where condition is True, and elements from y elsewhere.
- Return type