jax.numpy.roll
Warning
This page was created from a pull request (#9655).
- jax.numpy.roll(a, shift, axis=None)
Roll array elements along a given axis.
LAX-backend implementation of numpy.roll(). Original docstring below.
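Because jnp.roll is built on LAX primitives, it can be used inside jax.jit. The sketch below assumes a JAX version in which the shift amount may be a traced (runtime) value while axis stays a Python constant. The function name roll_by is hypothetical and exists only for this illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def roll_by(x, shift):
    # Hypothetical wrapper: `shift` is traced by jit, while `axis` is a
    # Python constant fixed at trace time.
    return jnp.roll(x, shift, axis=0)

x = jnp.arange(5)
print(roll_by(x, 2))   # [3 4 0 1 2]
print(roll_by(x, -1))  # [1 2 3 4 0]
```

Passing different shift values reuses the same compiled function, since only the value of shift changes, not its shape or dtype.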
Elements that roll beyond the last position are re-introduced at the first.
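The wrap-around behaviour can be seen on a small 1-D array. A positive shift moves elements toward higher indices, and a negative shift moves them toward lower indices. The helper roll_1d is a hypothetical illustration of the same semantics using slicing and concatenation (it assumes a non-empty array). It is not how JAX implements roll internally.

```python
import jax.numpy as jnp

x = jnp.arange(5)  # [0 1 2 3 4]

# The last two elements wrap around to the front.
print(jnp.roll(x, 2))   # [3 4 0 1 2]

# The first two elements wrap around to the back.
print(jnp.roll(x, -2))  # [2 3 4 0 1]

def roll_1d(a, shift):
    """Hypothetical helper: roll a non-empty 1-D array via slicing."""
    n = a.shape[0]
    k = shift % n  # normalize any integer shift into [0, n)
    # The last k elements move to the front; the rest follow.
    return jnp.concatenate([a[n - k:], a[:n - k]])
```

Normalizing the shift modulo the length means a shift of 7 on a length-5 array behaves like a shift of 2.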
- Parameters
a (array_like) – Input array.
shift (int or tuple of ints) – The number of places by which elements are shifted. If a tuple, then axis must be a tuple of the same size, and each of the given axes is shifted by the corresponding number. If an int while axis is a tuple of ints, then the same value is used for all given axes.
axis (int or tuple of ints, optional) – Axis or axes along which elements are shifted. By default, the array is flattened before shifting, after which the original shape is restored.
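The interaction between shift and axis is easiest to see on a small 2-D array. The sketch below contrasts the default flattening behaviour (axis=None), a roll along a single axis, a tuple of shifts paired with a tuple of axes, and a single int shift broadcast over several axes.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
# [[0 1 2]
#  [3 4 5]]

# axis=None: flatten to [0 1 2 3 4 5], roll by 1, restore shape (2, 3).
flat = jnp.roll(x, 1)
# [[5 0 1]
#  [2 3 4]]

# Roll each row independently along axis 1.
along_cols = jnp.roll(x, 1, axis=1)
# [[2 0 1]
#  [5 3 4]]

# Tuple shift with a tuple axis: shift axis 0 by 1, then axis 1 by 1.
both = jnp.roll(x, (1, 1), axis=(0, 1))
# [[5 3 4]
#  [2 0 1]]

# An int shift with a tuple axis applies the same shift to every listed axis.
same_shift = jnp.roll(x, 1, axis=(0, 1))  # equal to `both`
```

In every case the result keeps the shape of the input; only the positions of the elements change.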
- Returns
res – Output array, with the same shape as a.
- Return type