jax.tree_util.tree_unflatten
Warning
This page was created from a pull request (#9655).
jax.tree_util.tree_unflatten(treedef, leaves)
Reconstructs a pytree from the treedef and the leaves.
The inverse of tree_flatten().

Parameters
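- treedef – the tree structure (a PyTreeDef, as returned by tree_flatten()) describing the pytree to reconstruct.
- leaves – the list of leaves to use for reconstruction. It must contain exactly as many leaves as treedef describes (treedef.num_leaves), in the same order that tree_flatten() returns them.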
Returns
The reconstructed pytree, containing the leaves placed in the structure described by treedef.
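A short sketch of typical use, assuming a recent JAX release. It flattens a small pytree, rebuilds it from the unchanged leaves, and then substitutes transformed leaves into the same structure. The function checked_unflatten is a hypothetical helper written only for illustration (it is not part of jax.tree_util); it compares the leaf count against treedef.num_leaves before delegating to tree_unflatten.

```python
import jax

# A small nested pytree. JAX treats None as an empty subtree, not as a leaf.
tree = {"w": [1.0, 2.0], "b": None, "n": (3,)}

# tree_flatten returns the leaves plus a treedef describing the structure.
# Dict entries are visited in sorted-key order: "b" (no leaves), "n", "w".
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)              # [3, 1.0, 2.0]
print(treedef.num_leaves)  # 3

# Round trip: unflattening the unchanged leaves rebuilds an equal pytree.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
assert rebuilt == tree

# Substituting new leaves (same count, same order) reuses the structure.
doubled = jax.tree_util.tree_unflatten(treedef, [2 * x for x in leaves])
assert doubled == {"w": [2.0, 4.0], "b": None, "n": (6,)}


def checked_unflatten(treedef, leaves):
    """Hypothetical helper: fail early if the leaf count does not match treedef."""
    leaves = list(leaves)
    if len(leaves) != treedef.num_leaves:
        raise ValueError(
            f"expected {treedef.num_leaves} leaves, got {len(leaves)}"
        )
    return jax.tree_util.tree_unflatten(treedef, leaves)
```

Rebuilding a tree from transformed leaves is the pattern that tree_map automates for a single tree; doing it by hand is mainly useful when the leaves need to be processed as one flat list.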