Warning
This page was created from a pull request (#9655).
jax.ops package
Indexed update operators
JAX is intended to be used with a functional style of programming, and does not support NumPy-style indexed assignment directly. Instead, JAX provides alternative pure functional operators for indexed updates to arrays.
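Here is a minimal sketch of the difference. It assumes a recent JAX release, in which attempting item assignment on a jax.Array raises TypeError. The functional form used here is the at property described next.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

try:
    # NumPy-style indexed assignment is rejected: JAX arrays are immutable.
    x[0] = 1.0
    assigned = True
except TypeError as err:
    assigned = False
    print("TypeError:", err)

# The pure functional alternative returns a new array and leaves x untouched.
y = x.at[0].set(1.0)
print(x)  # still all zeros
print(y)  # first element is now 1.0
```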
JAX array types have an at property, which can be used as follows (where idx is a NumPy index expression).
Alternate syntax |
Equivalent in-place expression |
---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
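You can check directly how the at methods correspond to NumPy in-place expressions. The sketch below exercises set, add, multiply, min, max and get, and compares each result with the matching NumPy in-place update applied to a copy. It assumes float32 inputs and an integer-array index with no repeated entries. The helper numpy_inplace and the np_* functions are hypothetical and exist only for this comparison.

```python
import numpy as np
import jax.numpy as jnp

# idx is a NumPy index expression; here an integer array with no repeated entries.
idx = np.array([0, 2])
y = np.array([10.0, 20.0], dtype=np.float32)
x_np = np.arange(1.0, 5.0, dtype=np.float32)  # values 1, 2, 3, 4
x = jnp.asarray(x_np)


def numpy_inplace(update):
    """Hypothetical helper: apply an in-place NumPy update to a copy of x_np."""
    out = x_np.copy()
    update(out)
    return out


def np_set(a):
    a[idx] = y


def np_add(a):
    a[idx] += y


def np_multiply(a):
    a[idx] *= y


def np_min(a):
    a[idx] = np.minimum(a[idx], y)


def np_max(a):
    a[idx] = np.maximum(a[idx], y)


# Each entry pairs the functional JAX form with its NumPy in-place equivalent.
pairs = {
    "set": (x.at[idx].set(y), numpy_inplace(np_set)),
    "add": (x.at[idx].add(y), numpy_inplace(np_add)),
    "multiply": (x.at[idx].multiply(y), numpy_inplace(np_multiply)),
    "min": (x.at[idx].min(y), numpy_inplace(np_min)),
    "max": (x.at[idx].max(y), numpy_inplace(np_max)),
    "get": (x.at[idx].get(), x_np[idx]),
}

for name, (jax_result, numpy_result) in pairs.items():
    assert np.allclose(jax_result, numpy_result), name
    print(f"x.at[idx].{name}(...) -> {np.asarray(jax_result)}")
```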
None of these expressions modify the original x; instead they return a modified copy of x. However, inside a jit()-compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.
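The sketch below illustrates both points, assuming a recent JAX release. The function name bump is hypothetical. Outside jit, x keeps its values after the update. Inside the jitted function, rebinding x to the result of x.at[...] is the pattern that XLA can apply in place. The caller's input array is still not mutated, because it is not donated to the computation.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)
y = x.at[2].set(100.0)
print(x)  # unchanged
print(y)  # element 2 replaced by 100.0


@jax.jit
def bump(x, i):
    # Rebinding x to the updated value lets XLA perform the update in place
    # within the compiled computation; the caller's array is not modified.
    x = x.at[i].add(1.0)
    return x * 2.0


z = bump(x, 3)
print(z)
print(x)  # still unchanged after the jitted call
```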
Unlike NumPy in-place operations such as x[idx] += y, if multiple indices refer to the same location, all updates will be applied (NumPy would only apply the last update, rather than applying all updates). The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).
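The sketch below shows the duplicate-index behavior with add. It contrasts NumPy's buffered x[idx] += y with JAX's x.at[idx].add(y). It also includes NumPy's unbuffered np.add.at, which likewise applies every update. Because addition is commutative, the result here does not depend on the order in which the conflicting updates are applied.

```python
import numpy as np
import jax.numpy as jnp

idx = np.array([0, 0, 2])  # index 0 appears twice

# NumPy in-place: buffered, so index 0 is incremented only once.
x_np = np.zeros(3, dtype=np.float32)
x_np[idx] += 1.0

# NumPy's unbuffered ufunc.at applies both increments to index 0.
x_unbuffered = np.zeros(3, dtype=np.float32)
np.add.at(x_unbuffered, idx, 1.0)

# JAX: every update is applied, matching np.add.at rather than +=.
x = jnp.zeros(3, dtype=jnp.float32).at[idx].add(1.0)

print(x_np, x_unbuffered, x)
```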
By default, JAX assumes that all indices are in-bounds. There is experimental support for giving more precise semantics to out-of-bounds indexed accesses via the mode parameter to functions such as get and set. Valid values for mode include "clip", which clamps out-of-bounds indices into range, and "fill" or "drop", which are aliases: out-of-bounds reads are filled with a scalar fill_value, and out-of-bounds writes are discarded.
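Here is a minimal sketch of the mode parameter. It assumes a JAX release in which get accepts the fill_value keyword, which may not be available in every version. Index 10 lies outside the valid range 0..4 of the example array.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(5.0)        # valid indices are 0..4
idx = jnp.array([1, 10])   # index 10 is out of bounds

# "clip": the out-of-bounds index 10 is clamped to the last valid index, 4.
clipped = x.at[idx].get(mode="clip")

# "fill": the out-of-bounds read returns the scalar fill_value instead.
filled = x.at[idx].get(mode="fill", fill_value=-1.0)

# "drop" (an alias of "fill"): the out-of-bounds write is discarded.
dropped = x.at[idx].set(99.0, mode="drop")

print(clipped, filled, dropped)
```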
Indexed update functions (deprecated)
The following functions are aliases for the x.at[idx].set(y)-style operators. Use the x.at[idx] operators instead.
Helper object for building indexes for indexed update functions. |
|
|
Pure equivalent of |
|
Pure equivalent of |
|
Pure equivalent of |
|
Pure equivalent of |
|
Pure equivalent of |
Other operators
|
Computes the maximum within segments of an array. |
|
Computes the minimum within segments of an array. |
|
Computes the product within segments of an array. |
|
Computes the sum within segments of an array. |
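Each segment reduction takes a data array and an integer segment_ids array of the same leading length. It reduces all entries that share a segment id. The sketch below assumes the jax.ops functions segment_sum, segment_prod, segment_max and segment_min, called with their num_segments keyword. Passing num_segments explicitly fixes the output shape, which is what allows these calls to be used inside jit.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
# Element i of data belongs to segment segment_ids[i]: segment 0 = {1, 2}, segment 1 = {3, 4, 5}.
segment_ids = jnp.array([0, 0, 1, 1, 1])

sums = ops.segment_sum(data, segment_ids, num_segments=2)    # 1+2, 3+4+5
prods = ops.segment_prod(data, segment_ids, num_segments=2)  # 1*2, 3*4*5
maxes = ops.segment_max(data, segment_ids, num_segments=2)   # max of each segment
mins = ops.segment_min(data, segment_ids, num_segments=2)    # min of each segment

print(sums, prods, maxes, mins)
```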