jax.lax.gather

Warning

This page was created from a pull request (#9655).


jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False, mode=None, fill_value=None)

Gather operator.

Wraps XLA’s Gather operator.

The semantics of gather are complicated, and its API might change in the future. For most use cases, you should prefer NumPy-style indexing (e.g., x[:, (1,4,7), …]) over calling gather directly.
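For comparison, here is a minimal sketch of how a NumPy-style column selection maps onto lax.gather. It assumes a 2-D operand and picks columns 1, 4 and 7 of every row. The array contents, shapes and variable names are illustrative choices, not part of the API. The index vector lives in the last axis of start_indices, each slice spans all rows and one column, and the size-1 column axis is collapsed away.

```python
import jax.numpy as jnp
from jax import lax

# A small 3x8 operand; row r holds the values 8*r .. 8*r + 7.
x = jnp.arange(24).reshape(3, 8)
cols = jnp.array([1, 4, 7])

# NumPy-style indexing: pick columns 1, 4 and 7 from every row.
expected = x[:, cols]

# The same selection with lax.gather. JAX treats the last axis of
# start_indices as the index vector, so each column index gets its own row.
start_indices = cols[:, None]  # shape (3, 1)
dnums = lax.GatherDimensionNumbers(
    offset_dims=(0,),           # output axis 0 holds the slice's row axis
    collapsed_slice_dims=(1,),  # drop the size-1 column axis of each slice
    start_index_map=(1,),       # each index addresses operand axis 1 (columns)
)
# Each slice covers all 3 rows and a single column.
result = lax.gather(x, start_indices, dnums, slice_sizes=(3, 1))

print(result.shape)                      # (3, 3)
print(bool((result == expected).all()))  # True
```

The indexing form is shorter and easier to read, which is why it is the recommended interface; gather is mainly useful when you need direct control over the dimension numbers.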

Parameters
  • operand (Any) – an array from which slices should be taken

  • start_indices (Any) – the indices at which slices should be taken

  • dimension_numbers (GatherDimensionNumbers) – a lax.GatherDimensionNumbers object that describes how dimensions of operand, start_indices and the output relate.

  • slice_sizes (Sequence[Union[int, Any]]) – the size of each slice. Must be a sequence of non-negative integers with length equal to ndim(operand).

  • indices_are_sorted (bool) – whether start_indices are known to be sorted. If True, this may improve performance on some backends.

  • unique_indices (bool) – whether the elements gathered from operand are guaranteed not to overlap with each other. If True, this may improve performance on some backends.

  • mode (Union[str, GatherScatterMode, None]) – how to handle indices that are out of bounds. With 'clip', indices are clamped so that each slice lies within bounds. With 'fill' or 'drop', gather returns a slice filled with fill_value for each out-of-bounds index. With 'promise_in_bounds', the behavior for out-of-bounds indices is implementation-defined.

  • fill_value – the value returned for out-of-bounds slices when mode is 'fill' (or 'drop'). Ignored otherwise. Defaults to NaN for inexact types, the most negative value for signed integer types, the largest value for unsigned integer types, and True for booleans.
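The effect of mode and fill_value can be seen on a small example. This is a sketch under illustrative assumptions: a length-4 float operand, single-element slices, and one deliberately out-of-bounds index (5). With 'clip' the index is clamped to the last valid start position; with 'fill' the affected entry takes fill_value, which defaults to NaN for floating-point operands.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.array([10.0, 20.0, 30.0, 40.0])
# One index per row; index 5 is past the end of the length-4 operand.
start_indices = jnp.array([[0], [2], [5]])

# Gather single elements: each slice has size 1 and that axis is collapsed,
# so the output has one entry per row of start_indices.
dnums = lax.GatherDimensionNumbers(
    offset_dims=(),
    collapsed_slice_dims=(0,),
    start_index_map=(0,),
)

clipped = lax.gather(operand, start_indices, dnums, slice_sizes=(1,), mode="clip")
filled = lax.gather(operand, start_indices, dnums, slice_sizes=(1,), mode="fill")
custom = lax.gather(operand, start_indices, dnums, slice_sizes=(1,),
                    mode="fill", fill_value=-1.0)

print(clipped)  # [10. 30. 40.]  (index 5 clamped to the last valid start, 3)
print(filled)   # [10. 30. nan]  (default fill value for a float operand)
print(custom)   # [10. 30. -1.]
```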

Return type

Any

Returns

An array containing the gather output.
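The shape of the gather output is determined by start_indices, dimension_numbers and slice_sizes. Every axis of start_indices except the last (the index vector) becomes a batch axis of the output. The axes listed in offset_dims take the sizes of the non-collapsed operand axes, in increasing operand-axis order. The helper below, gather_output_shape, is hypothetical and not part of JAX. It assumes no operand or start_indices batching dims are in use and checks its prediction against lax.gather on an arbitrary example.

```python
import jax.numpy as jnp
from jax import lax


def gather_output_shape(start_indices_shape, dnums, slice_sizes):
    """Hypothetical helper (not part of JAX): predict lax.gather's output shape.

    Assumes the dimension numbers use no operand/start_indices batching dims.
    """
    # Batch axes: every axis of start_indices except the last, which JAX
    # treats as the index vector.
    batch_sizes = iter(start_indices_shape[:-1])
    # Offset axes: slice sizes of the operand axes that are not collapsed,
    # taken in increasing operand-axis order.
    offset_sizes = iter(
        size for axis, size in enumerate(slice_sizes)
        if axis not in dnums.collapsed_slice_dims
    )
    rank = len(dnums.offset_dims) + len(start_indices_shape) - 1
    return tuple(
        next(offset_sizes) if axis in dnums.offset_dims else next(batch_sizes)
        for axis in range(rank)
    )


operand = jnp.zeros((5, 6, 7))
# A 4x3 grid of index vectors, each holding 2 start indices.
start_indices = jnp.zeros((4, 3, 2), dtype=jnp.int32)
dnums = lax.GatherDimensionNumbers(
    offset_dims=(1, 3),         # output axes 1 and 3 come from each slice
    collapsed_slice_dims=(0,),  # operand axis 0 has slice size 1 and is dropped
    start_index_map=(0, 2),     # the 2 indices address operand axes 0 and 2
)
slice_sizes = (1, 6, 2)

out = lax.gather(operand, start_indices, dnums, slice_sizes)
print(out.shape)                                                     # (4, 6, 3, 2)
print(gather_output_shape(start_indices.shape, dnums, slice_sizes))  # (4, 6, 3, 2)
```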