jax.tree_util.tree_structure

Warning

This page was created from a pull request (#9655).


jax.tree_util.tree_structure(tree, is_leaf=None)

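Returns the treedef (a PyTreeDef) of a pytree. The treedef records how the containers are nested and where the leaves sit, but not the leaf values themselves.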
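A minimal sketch of typical usage follows. It assumes a reasonably recent JAX install in which PyTreeDef exposes num_leaves and supports == comparison. The params tree and its values are hypothetical. The sketch shows three things: the treedef ignores leaf values, it can be paired with new leaves to rebuild a tree, and is_leaf can stop traversal early.

```python
from jax.tree_util import tree_leaves, tree_structure, tree_unflatten

# Hypothetical nested container: dict -> list / tuple, with None as an empty subtree.
params = {"w": [1.0, 2.0], "b": (3.0, None)}

treedef = tree_structure(params)
print(treedef.num_leaves)  # 3: None contributes no leaves

# The treedef ignores leaf values, so a tree with the same containers
# but different leaves has an equal structure.
same_shape = {"w": [0.0, 0.0], "b": (0.0, None)}
assert tree_structure(same_shape) == treedef

# Dict keys are flattened in sorted order, so "b" comes before "w".
print(tree_leaves(params))  # [3.0, 1.0, 2.0]

# Pair the treedef with a flat list of new leaves (same order) to rebuild a tree.
rebuilt = tree_unflatten(treedef, [30.0, 10.0, 20.0])
print(rebuilt)  # {'b': (30.0, None), 'w': [10.0, 20.0]}

# is_leaf stops traversal: here every list is kept whole as a single leaf.
list_as_leaf = tree_structure(params, is_leaf=lambda x: isinstance(x, list))
print(list_as_leaf.num_leaves)  # 2: the tuple's 3.0 and the whole list
```

Because the treedef is separate from the leaves, one common pattern is to compute it once and reuse it to reassemble trees from flat lists of transformed leaves.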

Parameters

is_leaf (Optional[Callable[[Any], bool]]) –