jax.experimental.host_callback.id_tap
Warning
This page was created from a pull request (#9655).
- jax.experimental.host_callback.id_tap(tap_func, arg, *, result=None, tap_with_device=False, **kwargs)
Host-callback tap primitive, like the identity function with a call to tap_func.
Experimental: please give feedback, and expect changes!
id_tap behaves semantically like the identity function, but has the side effect that a user-defined Python function is called with the runtime value of the argument.
- Parameters
tap_func – tap function to call like tap_func(arg, transforms), with arg as described below and where transforms is the sequence of applied JAX transformations in the form (name, params). If the tap_with_device optional argument is True, then the invocation also includes the device from which the value is tapped as a keyword argument: tap_func(arg, transforms, device=dev).
arg – the argument passed to the tap function; can be a pytree of JAX types.
result – if given, specifies the return value of id_tap. This value is not passed to the tap function, and in fact is not sent from the device to the host. If the result parameter is not specified, then the return value of id_tap is arg.
tap_with_device – if True, then the tap function is invoked with the device from which the tap originates as a keyword argument.
- Returns
arg, or result if given.
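The calling contract described above can be modelled in plain Python. This is only a host-side sketch under stated assumptions, not the JAX implementation: id_tap_model, log_tap, and the placeholder device string "cpu:0" are hypothetical names introduced for illustration, and transforms is always empty because no JAX transformations are involved.

```python
def id_tap_model(tap_func, arg, *, result=None, tap_with_device=False,
                 device="cpu:0"):
    """Hypothetical host-only model of the documented contract (not JAX code)."""
    transforms = ()  # assumption: no JAX transformations applied in this model
    if tap_with_device:
        # The device is passed as a keyword argument, as the docs describe.
        tap_func(arg, transforms, device=device)
    else:
        tap_func(arg, transforms)
    # `result` is never handed to tap_func; it only selects the return value.
    return arg if result is None else result


calls = []  # records every invocation of the tap function


def log_tap(arg, transforms, device=None):
    calls.append((arg, transforms, device))


a = id_tap_model(log_tap, 3.0)                                      # returns arg
b = id_tap_model(log_tap, 3.0, result=10.0, tap_with_device=True)   # returns result
```

In both calls the tap function receives arg (here 3.0), never result; only the second call receives a device keyword.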
The order of execution is by data dependency: after all the arguments and the value of result if present, are computed and before the returned value is used. At least one of the returned values of id_tap must be used in the rest of the computation, or else this operation has no effect.
Tapping works even for code executed on accelerators and even for code under JAX transformations.
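A minimal usage sketch of the data-dependency rule, assuming an installed JAX release that still provides jax.experimental.host_callback (the module is experimental, so the import is guarded). The names tapped, record, and run_demo are illustrative and not part of the API; barrier_wait comes from the same module and is used here to wait for outstanding taps before inspecting the host-side list.

```python
try:
    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.host_callback import id_tap, barrier_wait
except ImportError:
    # The experimental module may be absent from the installed JAX.
    id_tap = None

tapped = []  # values received on the host


def record(arg, transforms):
    # Called on the host with the runtime value of `arg`.
    tapped.append(np.asarray(arg))


def run_demo():
    if id_tap is None:
        return None  # nothing to demonstrate without host_callback

    @jax.jit
    def f(x):
        y = id_tap(record, x * 2.0)  # y has the same value as x * 2.0
        return y + 1.0               # using y keeps the tap in the computation

    out = f(jnp.arange(3.0))
    barrier_wait()  # wait until outstanding taps have run on the host
    return np.asarray(out)


result = run_demo()
```

If f returned x + 1.0 instead and ignored y, nothing would depend on the tap's return value, and per the note above the tap would have no effect.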
For more details see the module documentation.