Warning

This page was created from a pull request (#9655).

jax.ops package

Indexed update operators

JAX is intended to be used with a functional style of programming, and does not support NumPy-style indexed assignment directly. Instead, JAX provides alternative pure functional operators for indexed updates to arrays.

JAX array types have a property at, which can be used as follows (where idx is a NumPy index expression).

Alternate syntax          Equivalent in-place expression
------------------------  ------------------------------
x.at[idx].get()           x[idx]
x.at[idx].set(y)          x[idx] = y
x.at[idx].add(y)          x[idx] += y
x.at[idx].multiply(y)     x[idx] *= y
x.at[idx].divide(y)       x[idx] /= y
x.at[idx].power(y)        x[idx] **= y
x.at[idx].min(y)          x[idx] = np.minimum(x[idx], y)
x.at[idx].max(y)          x[idx] = np.maximum(x[idx], y)

None of these expressions modify the original x; instead, each returns a modified copy of x. However, inside a jit()-compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.
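A minimal sketch of these operators in use. The array values and the function name bump are illustrative only and are not taken from the JAX documentation.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)            # [0., 1., 2., 3., 4.]

y = x.at[1].set(10.0)          # [0., 10., 2., 3., 4.]
z = x.at[1:3].add(5.0)         # [0., 6., 7., 3., 4.]
w = x.at[::2].multiply(2.0)    # [0., 1., 4., 3., 8.]
m = x.at[3].min(1.0)           # [0., 1., 2., 1., 4.]

# x itself is unchanged: each call above returned a new array.
print(x)                       # [0. 1. 2. 3. 4.]

# Hypothetical helper: inside jit, rebinding the name to the result of
# .at[...].set(...) lets the compiler perform the update in place.
@jax.jit
def bump(a):
    a = a.at[0].set(-1.0)
    return a

print(bump(x))                 # [-1.  1.  2.  3.  4.]
```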

Unlike NumPy in-place operations such as x[idx] += y, if multiple indices refer to the same location, all updates will be applied (NumPy would apply only the last update). The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).
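The difference with repeated indices can be seen in a small comparison. This is a sketch with made-up values; the index array deliberately lists position 0 twice.

```python
import numpy as np
import jax.numpy as jnp

idx = np.array([0, 0, 2])      # position 0 appears twice

# NumPy: the repeated index is only applied once.
x_np = np.zeros(3)
x_np[idx] += 1.0
print(x_np)                    # [1. 0. 1.]

# JAX: every update is applied, so position 0 is incremented twice.
x_jax = jnp.zeros(3).at[idx].add(1.0)
print(x_jax)                   # [2. 0. 1.]
```

Because addition is commutative, the result here does not depend on the order in which the two updates to position 0 are applied. With set, the value that ends up at a repeated position is not specified.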

By default, JAX assumes that all indices are in-bounds. There is experimental support for giving more precise semantics to out-of-bounds indexed accesses, via the mode parameter to functions such as get and set. Valid values for mode include "clip", which means that out-of-bounds indices will be clamped into range, and "fill"/"drop", which are aliases and mean that out-of-bounds reads will be filled with a scalar fill_value, and out-of-bounds writes will be discarded.
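A short sketch of the mode parameter, assuming a three-element array and an index array whose second entry (5) is out of bounds. The values are illustrative, and since this support is described as experimental, behavior may differ between JAX versions.

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 5])        # 5 is out of bounds for an array of size 3

# "clip": the out-of-bounds index is clamped to the last valid index (2).
clipped = x.at[idx].get(mode="clip")
print(clipped)                 # [10. 30.]

# "fill": the out-of-bounds read is replaced by fill_value.
filled = x.at[idx].get(mode="fill", fill_value=-1.0)
print(filled)                  # [10. -1.]

# "drop": the out-of-bounds write is discarded; only index 0 is updated.
dropped = x.at[idx].set(99.0, mode="drop")
print(dropped)                 # [99. 20. 30.]
```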

Indexed update functions (deprecated)

The following functions are aliases for the x.at[idx].set(y) style operators. Use the x.at[idx] operators instead.

index                                             Helper object for building indexes for indexed update functions.
index_update(x, idx, y[, ...])                    Pure equivalent of x[idx] = y.
index_add(x, idx, y[, indices_are_sorted, ...])   Pure equivalent of x[idx] += y.
index_mul(x, idx, y[, indices_are_sorted, ...])   Pure equivalent of x[idx] *= y.
index_min(x, idx, y[, indices_are_sorted, ...])   Pure equivalent of x[idx] = minimum(x[idx], y).
index_max(x, idx, y[, indices_are_sorted, ...])   Pure equivalent of x[idx] = maximum(x[idx], y).
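A sketch of how calls to the deprecated functions map onto the x.at[idx] operators. The deprecated calls appear only in comments, because they may not be available in the installed JAX version; the array values are illustrative.

```python
import jax.numpy as jnp

x = jnp.zeros(4)

# Deprecated: jax.ops.index_update(x, jax.ops.index[1], 5.0)
a = x.at[1].set(5.0)           # [0., 5., 0., 0.]

# Deprecated: jax.ops.index_add(x, jax.ops.index[1:3], 2.0)
b = x.at[1:3].add(2.0)         # [0., 2., 2., 0.]

# Deprecated: jax.ops.index_max(x, jax.ops.index[0], 3.0)
c = x.at[0].max(3.0)           # [3., 0., 0., 0.]
```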

Other operators

segment_max(data, segment_ids[, ...])    Computes the maximum within segments of an array.
segment_min(data, segment_ids[, ...])    Computes the minimum within segments of an array.
segment_prod(data, segment_ids[, ...])   Computes the product within segments of an array.
segment_sum(data, segment_ids[, ...])    Computes the sum within segments of an array.
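A minimal sketch of the segment operators, assuming five data values grouped into three segments. The data and segment_ids values are made up for illustration. Passing num_segments explicitly gives the output a static size, which is needed when the call is used inside jit.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([0, 0, 1, 2, 2])   # segment of each data element

sums = jax.ops.segment_sum(data, segment_ids, num_segments=3)
maxs = jax.ops.segment_max(data, segment_ids, num_segments=3)
mins = jax.ops.segment_min(data, segment_ids, num_segments=3)
prods = jax.ops.segment_prod(data, segment_ids, num_segments=3)

print(sums)    # [3. 3. 9.]
print(maxs)    # [2. 3. 5.]
print(mins)    # [1. 3. 4.]
print(prods)   # [ 2.  3. 20.]
```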