jax.tree_util.tree_map
Warning
This page was created from a pull request (#9655).
- jax.tree_util.tree_map(f, tree, *rest, is_leaf=None)
Maps a multi-input function over pytree args to produce a new pytree.
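A minimal usage sketch follows, assuming a recent JAX install. The `params` and `grads` dictionaries and the 0.5 step size are made-up illustration values. With a single tree, `f` is applied to every leaf. With extra trees passed in `rest`, `f` receives the matching leaves side by side.

```python
import jax

# Hypothetical example pytrees: dicts whose values are floats or lists of floats.
params = {"w": [1.0, 2.0], "b": 3.0}
grads = {"w": [1.0, 2.0], "b": 4.0}

# Single input: f takes one argument and is applied to every leaf of params.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)

# Multiple inputs: f takes 1 + len(rest) = 2 arguments, one leaf from each tree.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.5 * g, params, grads)

print(doubled)
print(updated)
```

Leaves are passed to `f` unchanged, so Python floats stay Python floats here. Nothing is converted to arrays.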
- Parameters
f (Callable[..., Any]) – function that takes 1 + len(rest) arguments, to be applied at the corresponding leaves of the pytrees.
tree (Any) – a pytree to be mapped over, with each leaf providing the first positional argument to f.
*rest (Any) – a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
is_leaf (Optional[Callable[[Any], bool]]) – an optionally specified function that will be called at each flattening step. It should return a boolean: True stops flattening at the current object, so the whole subtree is treated as a single leaf; False lets flattening traverse into the object.
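The sketch below illustrates two parameters, again assuming a recent JAX install and made-up inputs. First, it shows the prefix rule for `rest`: each leaf of `tree` is paired with the whole corresponding subtree of the other pytree. Second, it shows `is_leaf`: returning True for tuples makes each tuple reach `f` intact.

```python
import jax

# `tree` is a prefix of `other`: leaf 1 is paired with the tuple (10, 20, 30),
# and leaf 2 is paired with the dict {"a": 0}, each passed to f as a whole.
tree = [1, 2]
other = [(10, 20, 30), {"a": 0}]
counts = jax.tree_util.tree_map(lambda x, sub: x * len(sub), tree, other)

# is_leaf returns True for tuples, so flattening stops there and
# f receives the tuple (1, 2) itself rather than its elements.
data = {"a": (1, 2), "b": 3}
lengths = jax.tree_util.tree_map(
    lambda x: len(x) if isinstance(x, tuple) else x,
    data,
    is_leaf=lambda x: isinstance(x, tuple),
)

print(counts)
print(lengths)
```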
- Returns
A new pytree with the same structure as tree but with the value at each leaf given by f(x, *xs), where x is the value at the corresponding leaf in tree and xs is the tuple of values at the corresponding nodes in rest.
- Return type
Any
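The return-value rule can be expressed with the public flatten/unflatten primitives. The following is a minimal sketch of the semantics only and is not necessarily how JAX implements `tree_map` internally. The helper name `tree_map_sketch` is hypothetical. It assumes `tree_util.tree_flatten`, `PyTreeDef.flatten_up_to` and `PyTreeDef.unflatten` as available in recent JAX releases.

```python
import jax
from jax import tree_util


def tree_map_sketch(f, tree, *rest, is_leaf=None):
    """Hypothetical re-statement of tree_map's documented semantics."""
    # Flatten `tree` into its leaves plus a treedef describing its structure.
    leaves, treedef = tree_util.tree_flatten(tree, is_leaf)
    # Flatten each extra pytree only down to `tree`'s structure, so it may
    # carry extra nesting below `tree`'s leaves (`tree` is a prefix of it).
    rest_leaves = [treedef.flatten_up_to(r) for r in rest]
    # Compute f(x, *xs) leaf by leaf, then rebuild a pytree shaped like `tree`.
    return treedef.unflatten([f(*xs) for xs in zip(leaves, *rest_leaves)])


# None is an empty pytree node rather than a leaf, so it is preserved as-is.
tree = {"x": [1, 2], "y": None}
other = {"x": [10, 20], "y": None}

expected = tree_util.tree_map(lambda a, b: a + b, tree, other)
result = tree_map_sketch(lambda a, b: a + b, tree, other)
print(result)
# The result has the same structure as the input tree.
print(tree_util.tree_structure(result) == tree_util.tree_structure(tree))
```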