jax.tree_util.tree_map
Warning
This page was created from a pull request (#9655).
- jax.tree_util.tree_map(f, tree, *rest, is_leaf=None)
Maps a multi-input function over pytree args to produce a new pytree.
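A minimal sketch of the two common call patterns follows. The dictionaries are made-up illustrative data, and the leaves are plain Python numbers rather than arrays so the results are exact. With one tree, `f` takes one argument per leaf. With several trees of matching structure, `f` receives one leaf from each tree at the same position.

```python
from jax.tree_util import tree_map

# Hypothetical example pytree: a dict holding a list of numbers and a number.
params = {"w": [1, 2], "b": 3}

# Single input: f is called once per leaf of `params`.
doubled = tree_map(lambda x: 2 * x, params)
# doubled == {"w": [2, 4], "b": 6}

# Multi-input: `offsets` has the same structure as `params`,
# so f receives the matching leaf from each tree.
offsets = {"w": [10, 20], "b": 30}
summed = tree_map(lambda x, y: x + y, params, offsets)
# summed == {"w": [11, 22], "b": 33}
```

In practice the leaves are usually JAX or NumPy arrays (for example model parameters and their gradients), but `tree_map` itself does not require that.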
- Parameters
f (Callable[..., Any]) – function that takes 1 + len(rest) arguments, to be applied at the corresponding leaves of the pytrees.
tree (Any) – a pytree to be mapped over, with each leaf providing the first positional argument to f.
*rest (Any) – a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
is_leaf (Optional[Callable[[Any], bool]]) – an optionally specified function that will be called at each flattening step. It should return a boolean indicating whether flattening should stop at the current object: if it returns True, the whole subtree is treated as a single leaf; if it returns False, flattening continues into the object.
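The sketch below illustrates the two less obvious parameters, using hypothetical example data. First, a tree in `rest` may be deeper than `tree` as long as `tree` is a prefix of it; in that case `f` receives the whole subtree of the `rest` tree that sits at each leaf position of `tree`. Second, `is_leaf` can stop flattening early, here by treating every tuple as a single leaf.

```python
from jax.tree_util import tree_map

# `scales` is a prefix of `values`: each leaf of `scales` lines up with a
# whole list in `values`, and f receives that list as its second argument.
scales = {"a": 2, "b": 3}
values = {"a": [1, 2], "b": [10]}
scaled = tree_map(lambda s, sub: [s * v for v in sub], scales, values)
# scaled == {"a": [2, 4], "b": [30]}

# is_leaf returns True for tuples, so flattening stops there and f receives
# each (x, y) pair as one leaf instead of its two elements separately.
pairs = {"p": (1, 2), "q": (3, 4)}
sums = tree_map(
    lambda t: t[0] + t[1],
    pairs,
    is_leaf=lambda node: isinstance(node, tuple),
)
# sums == {"p": 3, "q": 7}
```

Note that `is_leaf` governs how `tree` is flattened; the trees in `rest` are then matched against the resulting structure.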
- Returns
A new pytree with the same structure as tree but with the value at each leaf given by f(x, *xs), where x is the value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rest.
- Return type
Any
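As a small check of the return contract, the sketch below (hypothetical data) maps over three trees so that `f` is called as `f(x, *xs)` with `x` from `tree` and `xs` holding the two matching leaves from `rest`. Because `f` returns a plain value at each leaf here, the output's structure can be compared directly with the input's using `tree_structure`.

```python
from jax.tree_util import tree_leaves, tree_map, tree_structure

tree = {"a": 1, "b": (2, 3)}
ys = {"a": 10, "b": (20, 30)}
zs = {"a": 100, "b": (200, 300)}

# At each leaf: x comes from `tree`, xs == (y, z) from the trees in `rest`.
out = tree_map(lambda x, *xs: x + sum(xs), tree, ys, zs)
# out == {"a": 111, "b": (222, 333)}

# f returned a single value per leaf, so the structure is unchanged.
same_structure = tree_structure(out) == tree_structure(tree)
```

If `f` itself returned a pytree (for example a tuple), that value would be inserted as a subtree, and the full structure of the result would then be deeper than that of `tree`.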