jax.profiler.start_trace
Warning: This page was created from a pull request (#9655).
jax.profiler.start_trace(log_dir)
Starts a profiler trace.
The trace will capture CPU, GPU, and/or TPU activity, including Python functions and JAX on-device operations. Use stop_trace() to end the trace and save the results to log_dir.
The resulting trace can be viewed with TensorBoard. Note that TensorBoard doesn't need to be running when collecting the trace.
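A minimal usage sketch follows. It assumes JAX is installed and uses a fresh temporary directory as a stand-in for a real TensorBoard log directory; the matrix multiply is an arbitrary example workload. Because JAX dispatches work asynchronously, the sketch calls block_until_ready() so the computation finishes while the trace is still active, and it calls stop_trace() in a finally clause so the trace is ended even if the workload raises.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Assumption: a throwaway temp directory stands in for a TensorBoard log dir.
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

jax.profiler.start_trace(log_dir)
try:
    x = jnp.ones((1000, 1000))
    y = jnp.dot(x, x)
    # Wait for the asynchronous computation so it is captured in the trace.
    y.block_until_ready()
finally:
    # End the trace and write the results under log_dir.
    jax.profiler.stop_trace()

# List whatever files the profiler wrote (exact layout may vary by JAX version).
written = [f for _, _, files in os.walk(log_dir) for f in files]
print(written)
```

Pointing TensorBoard at the same directory afterwards should let you inspect the captured trace.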
Only one trace may be collected at a time. A RuntimeError will be raised if start_trace() is called while another trace is running.
Parameters
log_dir – The directory to save the profiler trace to (usually the TensorBoard log directory).
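The one-trace-at-a-time rule can be observed directly. This sketch, again assuming JAX is installed and using a throwaway temporary directory as log_dir, starts a trace, attempts a second start_trace() call while the first is still running, and catches the resulting RuntimeError. The error message text is not part of the documented behavior, so the sketch only prints it.

```python
import tempfile

import jax

# Assumption: a temporary directory is an acceptable log_dir for this demo.
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

raised = False
jax.profiler.start_trace(log_dir)
try:
    try:
        # A second trace cannot start while the first one is active.
        jax.profiler.start_trace(log_dir)
    except RuntimeError as err:
        raised = True
        print("second start_trace() rejected:", err)
finally:
    # Always end the first trace so later traces can be started.
    jax.profiler.stop_trace()
```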