jax.tree_util.register_pytree_node_class
Warning
This page was created from a pull request (#9655).
jax.tree_util.register_pytree_node_class(cls)
Extends the set of types that are considered internal nodes in pytrees.
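Here is a minimal sketch of what being an internal node means in practice. Pytree utilities treat an instance of an unregistered class as a single opaque leaf, but they descend into the children of a registered class. The class names `Opaque` and `Pair` are hypothetical and exist only for this illustration.

```python
from jax.tree_util import register_pytree_node_class, tree_leaves


class Opaque:
    # Not registered: pytree utilities treat each instance as one leaf.
    def __init__(self, a, b):
        self.a = a
        self.b = b


@register_pytree_node_class
class Pair:
    # Registered: pytree utilities descend into a and b.
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        # (children, aux_data); this class carries no auxiliary data
        return ((self.a, self.b), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


opaque_leaves = tree_leaves(Opaque(1.0, 2.0))
pair_leaves = tree_leaves(Pair(1.0, 2.0))

print(len(opaque_leaves))  # 1: the Opaque object itself is the leaf
print(pair_leaves)         # [1.0, 2.0]: the children of Pair
```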
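This function is a thin wrapper around register_pytree_node that provides a class-oriented interface. The decorated class must define two methods. The first is a tree_flatten method that returns a (children, aux_data) pair. The second is a tree_unflatten classmethod that rebuilds an instance from aux_data and children. The decorator returns the class itself, so you can apply it with @ syntax:

    from jax.tree_util import register_pytree_node_class

    @register_pytree_node_class
    class Special:
        def __init__(self, x, y):
            self.x = x
            self.y = y

        def tree_flatten(self):
            # children are (x, y); there is no auxiliary data
            return ((self.x, self.y), None)

        @classmethod
        def tree_unflatten(cls, aux_data, children):
            # rebuild an instance from the children returned by tree_flatten
            return cls(*children)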
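The following usage sketch repeats the `Special` class so that the snippet runs on its own. It shows a flatten/unflatten round trip, `tree_map` producing a new `Special`, and a `Special` instance passed directly into `jax.jit`. To illustrate the "thin wrapper" relationship, it also registers a hypothetical look-alike class, `ExplicitSpecial`, by calling `register_pytree_node` directly with explicit flatten and unflatten functions. This pairing is an illustration of the equivalent registration, not a claim about the decorator's exact source.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import (
    register_pytree_node,
    register_pytree_node_class,
    tree_flatten,
    tree_map,
    tree_unflatten,
)


@register_pytree_node_class
class Special:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        return ((self.x, self.y), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Flatten into leaves plus a treedef, then rebuild an equivalent object.
s = Special(jnp.arange(3.0), jnp.ones(2))
leaves, treedef = tree_flatten(s)
rebuilt = tree_unflatten(treedef, leaves)

# tree_map applies a function to every leaf and returns a new Special.
doubled = tree_map(lambda leaf: 2 * leaf, s)


# Registered instances can be passed straight into a jitted function.
@jax.jit
def total(obj):
    return jnp.sum(obj.x) + jnp.sum(obj.y)


result = total(s)  # 0 + 1 + 2 + 1 + 1 = 5.0


# Hypothetical look-alike class registered by hand with register_pytree_node.
class ExplicitSpecial:
    def __init__(self, x, y):
        self.x = x
        self.y = y


register_pytree_node(
    ExplicitSpecial,
    lambda obj: ((obj.x, obj.y), None),                      # flatten
    lambda aux_data, children: ExplicitSpecial(*children),  # unflatten
)
```

The decorator form keeps the flatten and unflatten logic next to the data it describes. The explicit form suits classes you cannot edit, such as types from a third-party library.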