jax.ops.index_add
Warning
This page was created from a pull request (#9655).
jax.ops.index_add(x, idx, y, indices_are_sorted=False, unique_indices=False)
Pure equivalent of x[idx] += y.

Deprecated since version 0.2.22: Prefer the use of jax.numpy.ndarray.at, i.e. x.at[idx].add(y).

Returns the value of x that would result from the NumPy-style indexed assignment:

    x[idx] += y
Note that the index_add operator is pure: x itself is not modified; instead, the new value that x would have taken is returned.
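A minimal sketch of this purity, written with the x.at[idx].add(y) form that the deprecation note recommends (an assumption of the sketch: a JAX release in which jax.ops.index_add itself may no longer be available). It reproduces the docstring example further down and also shows a scalar y being broadcast to the shape of x[idx].

```python
import jax.numpy as jnp

x = jnp.ones((5, 6))

# Equivalent of jax.ops.index_add(x, jnp.index_exp[2:4, 3:], 6.):
# the scalar 6.0 is broadcast to the (2, 3) shape of x[2:4, 3:].
new_x = x.at[2:4, 3:].add(6.0)

# x is not modified; the updated values exist only in new_x.
print(float(x[2, 3]))      # 1.0
print(float(new_x[2, 3]))  # 7.0
print(float(new_x.sum()))  # 66.0  (30 ones plus 6 updated entries * 6.0)
```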
Unlike the NumPy code x[idx] += y, if multiple indices refer to the same location, the updates are summed. (NumPy would apply only the last update, rather than summing the updates.) The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).

- Parameters
  - x (Any) – an array with the values to be updated.
  - idx (Union[None, int, slice, Sequence[int], Any, Tuple[Union[None, int, slice, Sequence[int], Any], ...]]) – a NumPy-style index, consisting of None, integers, slice objects, ellipses, ndarrays with integer dtypes, or a tuple of the above. A convenient syntactic sugar for forming indices is via the jax.ops.index object.
  - y (Union[Any, complex, float, int, number]) – the array of updates. y must be broadcastable to the shape of the array that would be returned by x[idx].
  - indices_are_sorted (bool) – whether idx is known to be sorted.
  - unique_indices (bool) – whether idx is known to be free of duplicates.
- Returns
  An array.
>>> import jax
>>> import jax.numpy as jnp
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_add(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1., 1.],
             [1., 1., 1., 7., 7., 7.],
             [1., 1., 1., 7., 7., 7.],
             [1., 1., 1., 1., 1., 1.]], dtype=float32)
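The difference from NumPy when indices repeat, described above, can be checked directly. This sketch again uses x.at[idx].add(y) rather than index_add (assumption: a JAX version whose .at[...].add accepts the indices_are_sorted and unique_indices keyword arguments). It compares NumPy's buffered x[idx] += y, NumPy's unbuffered np.add.at, and JAX, then shows the two hints being passed for an index array that really is sorted and duplicate-free.

```python
import numpy as np
import jax.numpy as jnp

idx = np.array([0, 0, 1])        # index 0 appears twice
y = np.array([1.0, 2.0, 3.0])

# NumPy: x[idx] += y writes each gathered sum back; for the repeated
# index 0 only the last write (0.0 + 2.0) survives.
x_np = np.zeros(3)
x_np[idx] += y
print(x_np.tolist())      # [2.0, 3.0, 0.0]

# NumPy's unbuffered np.add.at accumulates duplicates instead.
x_np_at = np.zeros(3)
np.add.at(x_np_at, idx, y)
print(x_np_at.tolist())   # [3.0, 3.0, 0.0]

# JAX: duplicate indices are summed, matching np.add.at (index 0 gets 1.0 + 2.0).
x_jax = jnp.zeros(3).at[idx].add(y)
print(x_jax.tolist())     # [3.0, 3.0, 0.0]

# When idx is known to be sorted and duplicate-free, the hints may be passed on.
# They are promises to the compiler: setting unique_indices=True while idx
# actually contains repeats is not checked and gives undefined results.
sorted_idx = jnp.array([0, 2, 5])
out = jnp.zeros(6).at[sorted_idx].add(jnp.array([1.0, 2.0, 3.0]),
                                      indices_are_sorted=True,
                                      unique_indices=True)
print(out.tolist())       # [1.0, 0.0, 2.0, 0.0, 0.0, 3.0]
```

In other words, index_add (and its .at replacement) behaves like np.add.at rather than like NumPy's in-place += on an advanced index.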