jax.tree_util.register_pytree_node
Warning
This page was created from a pull request (#9655).
- jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func)
Extends the set of types that are considered internal nodes in pytrees.
See the example usage in the JAX pytrees documentation.
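A minimal sketch of registering a custom container, assuming a hypothetical `Point` class with two array fields. The names `Point`, `point_flatten`, and `point_unflatten` are illustrative only. Once registered, `Point` values can be flattened, mapped over with `jax.tree_util.tree_map`, and passed straight into `jax.jit`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten, tree_map


class Point:
    """Hypothetical container type, used only for illustration."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def point_flatten(p):
    # Children: the values JAX should traverse. Aux data: None (no static fields).
    return (p.x, p.y), None


def point_unflatten(aux_data, children):
    # Rebuild a Point from the (possibly transformed) children.
    return Point(*children)


register_pytree_node(Point, point_flatten, point_unflatten)

p = Point(jnp.array(1.0), jnp.array(2.0))
leaves, treedef = tree_flatten(p)           # two array leaves plus a treedef
doubled = tree_map(lambda v: v * 2, p)      # returns a new Point
total = jax.jit(lambda q: q.x + q.y)(p)     # Point is accepted as a jit argument
```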
- Parameters
  - nodetype (Type[~T]): a Python type to treat as an internal pytree node.
  - flatten_func (Callable[[~T], Tuple[~_Children, ~_AuxData]]): a function used during flattening. It takes a value of type nodetype and returns a pair: (1) an iterable of the children to be flattened recursively, and (2) some hashable auxiliary data to be stored in the treedef and passed to unflatten_func.
  - unflatten_func (Callable[[~_AuxData, ~_Children], ~T]): a function taking two arguments: the auxiliary data that was returned by flatten_func and stored in the treedef, and the unflattened children. The function should return an instance of nodetype.
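The sketch below illustrates the flatten_func/unflatten_func contract, assuming a hypothetical `Layer` class whose `name` is static metadata rather than data. `Layer`, `layer_flatten`, and `layer_unflatten` are illustrative names. The `weight` child is a list, so it is itself flattened recursively, while `name` travels as hashable auxiliary data inside the treedef and is handed back to `layer_unflatten`.

```python
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten


class Layer:
    """Hypothetical container: pytree-valued params plus a static, hashable name."""

    def __init__(self, name, weight, bias):
        self.name = name      # static metadata, stored as aux data (not a leaf)
        self.weight = weight  # may itself be a pytree, e.g. a list
        self.bias = bias


def layer_flatten(layer):
    children = (layer.weight, layer.bias)  # flattened recursively by JAX
    aux_data = layer.name                  # must be hashable; kept in the treedef
    return children, aux_data


def layer_unflatten(aux_data, children):
    # children arrive already rebuilt (weight is a list again here)
    weight, bias = children
    return Layer(aux_data, weight, bias)


register_pytree_node(Layer, layer_flatten, layer_unflatten)

layer = Layer("dense", [1.0, 2.0], 0.5)
leaves, treedef = tree_flatten(layer)                        # leaves: [1.0, 2.0, 0.5]
restored = tree_unflatten(treedef, [x * 10 for x in leaves])
```

Because the aux data is part of the treedef, two `Layer` values that differ only in `name` produce different treedefs. This is also why static fields like `name` belong in the aux data rather than the children: `jax.jit` treats the treedef as part of the input structure, so changing such a field triggers retracing instead of being traced as an array.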