jax.profiler.annotate_function
Warning
This page was created from a pull request (#9655).
jax.profiler.annotate_function(func, name=None, **decorator_kwargs)
Decorator that generates a trace event for the execution of a function.
For example:
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.profiler.annotate_function
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>>
>>> result = f(jnp.ones((1000, 1000)))
This will cause an “f” event to show up on the trace timeline if the function execution occurs while the process is being traced by TensorBoard.
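A minimal sketch of actually capturing such a trace, assuming a JAX version that provides the `jax.profiler.trace` context manager and `jax.profiler.TraceAnnotation`. The temporary log directory and the `"f_manual"` event name are illustrative choices. The inner `TraceAnnotation` block is an assumed rough equivalent of what the decorator does around each call, shown for comparison rather than as its exact source.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


@jax.profiler.annotate_function
def f(x):
    return jnp.dot(x, x.T).block_until_ready()


x = jnp.ones((1000, 1000))
# Illustrative choice: write the trace to a fresh temporary directory.
log_dir = tempfile.mkdtemp()

# Trace events are only recorded while a trace is active.
with jax.profiler.trace(log_dir):
    result = f(x)  # recorded on the timeline as an "f" event

    # Assumed rough equivalent of the decorator: wrap the body in a
    # TraceAnnotation (here given the hypothetical name "f_manual").
    with jax.profiler.TraceAnnotation("f_manual"):
        manual = jnp.dot(x, x.T).block_until_ready()

# List the files the profiler wrote; these can be loaded in TensorBoard.
for root, _, files in os.walk(log_dir):
    for filename in files:
        print(os.path.join(root, filename))
```

Pointing TensorBoard's profile plugin at `log_dir` should then show both events on the timeline. Outside a `jax.profiler.trace` block, the decorated `f` simply runs and returns its result.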
Arguments such as name can be passed to the decorator via functools.partial(). In this example the trace event is named “event_name” instead of “f”:
>>> from functools import partial
>>> @partial(jax.profiler.annotate_function, name="event_name")
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>>
>>> result = f(jnp.ones((1000, 1000)))
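Because the signature takes the function as its first argument, the decorator can also be applied by calling it directly, which is convenient for wrapping a function you did not define yourself. A minimal sketch; `g_impl` is a hypothetical function used only for illustration.

```python
import jax
import jax.numpy as jnp


def g_impl(x):
    # Hypothetical undecorated function, used only for illustration.
    return jnp.dot(x, x.T).block_until_ready()


# Same effect as decorating g_impl with
# partial(jax.profiler.annotate_function, name="event_name").
g = jax.profiler.annotate_function(g_impl, name="event_name")

result = g(jnp.ones((1000, 1000)))
```

The `partial` form fits the `@` decorator syntax, while the direct call leaves the original `g_impl` unannotated and returns a separate annotated wrapper.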