jax.numpy.dsplit
Warning
This page was created from a pull request (#9655).
- jax.numpy.dsplit(ary, indices_or_sections)
Split array into multiple sub-arrays along the 3rd axis (depth).
LAX-backend implementation of numpy.dsplit(). Original docstring below.
Please refer to the split documentation. dsplit is equivalent to split with axis=2: the array is always split along the third axis, provided the array has at least 3 dimensions.
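The equivalence with split can be illustrated with a short sketch. The array `x` and its shape (2, 2, 4) are assumptions chosen only for illustration; they do not come from the original docstring.

```python
import jax.numpy as jnp

# Illustrative 3-D input (shape chosen arbitrarily for this sketch).
x = jnp.arange(16.0).reshape(2, 2, 4)

# An integer asks for that many equal sections along the third axis.
halves = jnp.dsplit(x, 2)

# The same result expressed through split with axis=2.
same = jnp.split(x, 2, axis=2)

# A list of indices gives the positions at which to cut along the third axis.
parts = jnp.dsplit(x, [1, 3])

print([h.shape for h in halves])
print([p.shape for p in parts])
```

With an integer argument the size of the third axis must be divisible by the number of sections; with a list of indices the resulting pieces may have different depths.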