jax.ops.index_update
Warning: This page was created from a pull request (#9655).
jax.ops.index_update(x, idx, y, indices_are_sorted=False, unique_indices=False)
Pure equivalent of x[idx] = y.

Deprecated since version 0.2.22: Prefer the use of jax.numpy.ndarray.at.

Returns the value of x that would result from the NumPy-style indexed assignment:

    x[idx] = y
Note that the index_update operator is pure: x itself is not modified. Instead, the new value that x would have taken is returned.
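A minimal sketch of the purity property and of the migration the deprecation note suggests. It deliberately uses x.at[idx].set(y) instead of calling index_update, on the assumption that the reader's JAX version may no longer ship the deprecated function; the array shape and values are illustrative only.

```python
import jax.numpy as jnp

x = jnp.ones((5, 6))
idx = jnp.index_exp[::2, 3:]  # same index expression as the index_update example on this page

# Functional counterpart of NumPy's in-place x[idx] = 6.
# x.at[idx].set(y) is the jax.numpy.ndarray.at form the deprecation note points to.
updated = x.at[idx].set(6.0)

# Purity: x is not modified; the new values exist only in the returned array.
print(float(x[0, 3]))        # 1.0
print(float(updated[0, 3]))  # 6.0
print(float(updated[1, 3]))  # 1.0 (odd rows are skipped by ::2)
```

Because the result is a new array, code that relied on index_update's return value translates directly: replace jax.ops.index_update(x, idx, y) with x.at[idx].set(y).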
Unlike NumPy’s x[idx] = y, if multiple indices refer to the same location, it is undefined which update is chosen; JAX may choose the order of updates arbitrarily and nondeterministically (e.g., due to concurrent updates on some hardware platforms).

- Parameters
    x (Any) – an array with the values to be updated.
    idx (Union[None, int, slice, Sequence[int], Any, Tuple[Union[None, int, slice, Sequence[int], Any], …]]) – a NumPy-style index, consisting of None, integers, slice objects, ellipses, ndarrays with integer dtypes, or a tuple of the above. A convenient syntactic sugar for forming indices is via the jax.ops.index object.
    y (Union[Any, complex, float, int, number]) – the array of updates. y must be broadcastable to the shape of the array that would be returned by x[idx].
    indices_are_sorted (bool) – whether idx is known to be sorted.
    unique_indices (bool) – whether idx is known to be free of duplicates.
- Return type
- Returns
An array.
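The parameter descriptions above map onto the x.at[idx].set(y) form recommended by the deprecation note. The sketch below is illustrative, not taken from this page: the shapes and values are made up, and it assumes a JAX version where .at[...].set accepts the indices_are_sorted and unique_indices keyword hints.

```python
import jax.numpy as jnp

x = jnp.zeros((3, 4))

# idx as an integer: the scalar y broadcasts to x[1].shape == (4,).
a = x.at[1].set(7.0)

# idx as a tuple of an ellipsis and a slice: y of shape (2,) broadcasts
# against x[..., :2].shape == (3, 2), so every row gets [1., 2.] in columns 0-1.
b = x.at[..., :2].set(jnp.array([1.0, 2.0]))

# idx as an integer ndarray. These indices really are sorted and distinct,
# so passing both hints is truthful; JAX does not verify them.
rows = jnp.array([0, 2])
c = x.at[rows].set(5.0, indices_are_sorted=True, unique_indices=True)

print(b)
```

The two hints only describe idx to the compiler and may enable faster code on some backends; leaving them at their default of False is always safe.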
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.ones((5, 6))
>>> jax.ops.index_update(x, jnp.index_exp[::2, 3:], 6.)
DeviceArray([[1., 1., 1., 6., 6., 6.],
             [1., 1., 1., 1., 1., 1.],
             [1., 1., 1., 6., 6., 6.],
             [1., 1., 1., 1., 1., 1.],
             [1., 1., 1., 6., 6., 6.]], dtype=float32)
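As stated earlier on this page, when several entries of idx point at the same location, which update wins is undefined. A minimal sketch of that case, again written with x.at[idx].set(y) under the assumption that index_update may be unavailable; the values are invented. It only checks what is guaranteed, and keeps unique_indices at its default because passing unique_indices=True with duplicate indices leaves the result undefined (JAX does not check the promise).

```python
import jax.numpy as jnp

x = jnp.zeros(4)

# Position 1 appears twice, so two different updates target the same slot.
idx = jnp.array([1, 1, 3])
y = jnp.array([10.0, 20.0, 30.0])

# unique_indices is left at False, which is the honest setting for this idx.
out = x.at[idx].set(y)

# Either 10.0 or 20.0 may land at position 1; do not rely on which one.
assert float(out[1]) in (10.0, 20.0)
assert float(out[3]) == 30.0
assert float(out[0]) == 0.0 and float(out[2]) == 0.0
```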