jax.numpy.array_split
Warning
This page was created from a pull request (#9655).
jax.numpy.array_split(ary, indices_or_sections, axis=0)
Split an array into multiple sub-arrays.
LAX-backend implementation of numpy.array_split(). Original docstring below.
Please refer to the split documentation. The only difference between these functions is that array_split allows indices_or_sections to be an integer that does not equally divide the axis. For an array of length l that should be split into n sections, it returns l % n sub-arrays of size l//n + 1 and the rest of size l//n.
Parameters
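The sizing rule above can be checked with a small pure-Python sketch. The helpers section_sizes and split_points are hypothetical and not part of the JAX or NumPy API. The sketch assumes NumPy's ordering, in which the l % n larger sections come first.

```python
from typing import List


def section_sizes(length: int, n: int) -> List[int]:
    """Hypothetical helper: sizes of the n sections array_split would produce
    for an axis of the given length (not a JAX or NumPy function)."""
    if n <= 0:
        raise ValueError("number of sections must be positive")
    q, r = divmod(length, n)  # q = l // n, r = l % n
    # The first r sections get one extra element; the remaining n - r get q.
    return [q + 1] * r + [q] * (n - r)


def split_points(length: int, n: int) -> List[int]:
    """Hypothetical helper: cumulative boundaries between those sections."""
    points, total = [], 0
    for size in section_sizes(length, n)[:-1]:
        total += size
        points.append(total)
    return points


print(section_sizes(10, 3))  # [4, 3, 3]
print(section_sizes(7, 4))   # [2, 2, 2, 1]
print(split_points(10, 3))   # [4, 7]
```

When l < n the rule still holds: section_sizes(2, 4) gives [1, 1, 0, 0], so the trailing sub-arrays are empty.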
axis (int) – the axis along which to split; defaults to 0.
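A short usage sketch, assuming a recent JAX installation, shows three cases: an integer section count that does not divide the axis, an explicit list of split indices, and a non-default axis. The variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(10)

# 10 is not divisible by 3, yet array_split still succeeds: sizes 4, 3, 3.
parts = jnp.array_split(x, 3)
print([p.shape[0] for p in parts])  # [4, 3, 3]

# A list of indices gives explicit split points: x[:2], x[2:5], x[5:].
pieces = jnp.array_split(x, [2, 5])
print([p.tolist() for p in pieces])  # [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]

# axis selects the dimension to split; this 2 x 5 matrix is split by columns.
m = jnp.arange(10).reshape(2, 5)
cols = jnp.array_split(m, 2, axis=1)
print([c.shape for c in cols])  # [(2, 3), (2, 2)]
```

With an integer that does not evenly divide the axis, split would reject the input, whereas array_split distributes the remainder across the leading sections.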