jax.tree_util.tree_leaves

Warning

This page was created from a pull request (#9655).


jax.tree_util.tree_leaves(tree, is_leaf=None)

Gets the leaves of a pytree as a flat list, in the same order that jax.tree_util.tree_flatten produces them.
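The following is a minimal sketch of the basic call. It assumes only that JAX is installed. The dictionary keys and values are illustrative. Dict entries are visited in sorted key order, nested containers are traversed recursively, and None is treated as an empty subtree, so it contributes no leaves.

```python
import jax

# A nested pytree built from dicts, tuples, and lists (illustrative values).
tree = {"a": 1, "b": (2, [3, 4]), "c": None}

# Flatten the structure and keep only the leaves, in traversal order.
leaves = jax.tree_util.tree_leaves(tree)
print(leaves)  # [1, 2, 3, 4]
```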

Parameters

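tree – the pytree whose leaves are returned.
is_leaf (Optional[Callable[[Any], bool]]) – an optional function called at each step of the traversal. If it returns True for a node, traversal stops there and that whole subtree is returned as a single leaf. If it returns False, the node is flattened by the default rules.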
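The sketch below shows how is_leaf changes the result. It assumes a JAX version whose tree_leaves accepts is_leaf, as in the signature above. The predicate, which treats every tuple as a leaf, is a hypothetical choice made for illustration.

```python
import jax

tree = {"x": (1, 2), "y": [3, (4, 5)]}

# Default behaviour: tuples are containers, so traversal reaches the integers.
print(jax.tree_util.tree_leaves(tree))  # [1, 2, 3, 4, 5]

# Hypothetical predicate: stop at any tuple and return it whole as one leaf.
leaves = jax.tree_util.tree_leaves(tree, is_leaf=lambda node: isinstance(node, tuple))
print(leaves)  # [(1, 2), 3, (4, 5)]
```

The list under "y" is still traversed because the predicate returns False for it. Only the tuples are kept intact.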