jax.lax.reshape
Warning
This page was created from a pull request (#9655).
- jax.lax.reshape(operand, new_sizes, dimensions=None)
Wraps XLA’s Reshape operator.
For inserting/removing dimensions of size 1, prefer using lax.squeeze / lax.expand_dims. These preserve information about axis identity that may be useful for advanced transformation rules.
- Parameters
operand (Any) – array to be reshaped.
new_sizes (Sequence[Union[int, Any]]) – sequence of integers specifying the resulting shape. The total number of elements in the resulting shape must match the number of elements in the input.
dimensions (Optional[Sequence[int]]) – optional sequence of integers specifying the permutation applied to the input's dimensions before reshaping. If specified, its length must equal the number of dimensions of operand (that is, len(operand.shape)).
- Returns
reshaped array.
- Return type
out
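A minimal sketch of the dimensions semantics, assuming they amount to transposing the operand by dimensions and then reshaping in row-major order (this is what the permutation example on this page shows for the 2-D case). The helper reshape_via_transpose is hypothetical, written only for comparison, and the 3-D input is an arbitrary choice:

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def reshape_via_transpose(x, new_sizes, dimensions):
    """Hypothetical reference helper: permute the axes of x, then reshape."""
    return jnp.reshape(jnp.transpose(x, dimensions), new_sizes)


x = jnp.arange(24).reshape(2, 3, 4)

# With `dimensions`, the operand's axes are permuted before the reshape,
# so both paths should produce identical values.
a = lax.reshape(x, (4, 6), dimensions=(2, 0, 1))
b = reshape_via_transpose(x, (4, 6), (2, 0, 1))
print(a.shape)               # (4, 6)
print(np.array_equal(a, b))  # True

# The element count must be preserved: 24 elements cannot fill a
# (5, 5) shape, so the call is rejected with an exception.
try:
    lax.reshape(x, (5, 5))
    size_error_raised = False
except (TypeError, ValueError):
    size_error_raised = True
print(size_error_raised)     # True
```

Catching both TypeError and ValueError is a defensive choice, since this page does not state which exception a size mismatch raises.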
Examples
Simple reshaping from one to two dimensions:
>>> import jax.numpy as jnp
>>> from jax.lax import reshape
>>> x = jnp.arange(6)
>>> y = reshape(x, (2, 3))
>>> y
DeviceArray([[0, 1, 2],
             [3, 4, 5]], dtype=int32)
Reshaping back to one dimension:
>>> reshape(y, (6,))
DeviceArray([0, 1, 2, 3, 4, 5], dtype=int32)
Reshaping to one dimension with permutation of dimensions:
>>> reshape(y, (6,), (1, 0))
DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
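Following the advice to prefer lax.squeeze / lax.expand_dims for size-1 dimensions, the sketch below (an illustrative comparison, not part of the original examples) adds and removes a size-1 axis both ways. The values agree; the difference is that expand_dims and squeeze name the affected axis explicitly, which is the axis-identity information the description refers to:

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)

# Insert a size-1 axis at position 1: shape (2, 3) -> (2, 1, 3).
via_reshape = lax.reshape(x, (2, 1, 3))
via_expand = lax.expand_dims(x, (1,))

# Remove that size-1 axis again: shape (2, 1, 3) -> (2, 3).
back_via_reshape = lax.reshape(via_expand, (2, 3))
back_via_squeeze = lax.squeeze(via_expand, (1,))

print(via_expand.shape)                                    # (2, 1, 3)
print(back_via_squeeze.shape)                              # (2, 3)
print(np.array_equal(via_reshape, via_expand))             # True
print(np.array_equal(back_via_reshape, back_via_squeeze))  # True
```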