jax.experimental.host_callback.call
Warning
This page was created from a pull request (#9655).
- jax.experimental.host_callback.call(callback_func, arg, *, result_shape=None, call_with_device=False)
Make a call to the host, and expect a result.
Experimental: please give feedback, and expect changes!
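The sketch below illustrates the basic pattern: a NumPy function runs on the host from inside a jitted function. It assumes a JAX release that still ships jax.experimental.host_callback. The names host_sin and device_fn are hypothetical and chosen only for the example. Passing the input x itself as result_shape works because x has .shape and .dtype attributes, so the host result is expected to have the same shape and dtype as x.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import host_callback as hcb


def host_sin(x):
    # Runs on the host: x arrives as a NumPy ndarray (float32 here),
    # and np.sin returns a float32 ndarray of the same shape.
    return np.sin(x)


@jax.jit
def device_fn(x):
    # result_shape=x: the compiled code expects a result with x's shape and dtype.
    y = hcb.call(host_sin, x, result_shape=x)
    # The host result can be used in further device computation.
    return y * 2.0


x = jnp.arange(4, dtype=jnp.float32)
print(device_fn(x))
```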
- Parameters
callback_func (Callable) – The Python function to invoke on the host as callback_func(arg). If the call_with_device optional argument is True, then the invocation also includes the device kwarg with the device from which the call originates: callback_func(arg, device=dev). This function must return a pytree of numpy ndarrays.
arg – the argument passed to the callback function; it can be a pytree of JAX types.
result_shape – a value that describes the expected shape and dtype of the result. This can be a numeric scalar, from which a shape and dtype are obtained, or an object that has .shape and .dtype attributes. If the result of the callback is a pytree, then result_shape should also be a pytree with the same structure. In particular, result_shape can be () or None if the function does not have any results. The device code containing call is compiled with the expected result shape and dtype, and an error will be raised at runtime if the actual callback_func invocation returns a different kind of result.
call_with_device – if True, then the callback function is invoked with the device from which the call originates as a keyword argument.
- Returns
the result of the callback_func invocation.
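A second sketch, under the same assumption about the installed JAX version, shows a pytree result together with call_with_device=True. The names host_stats and device_fn are hypothetical. Here result_shape is a dict of jax.ShapeDtypeStruct objects (each has .shape and .dtype) with the same structure as the dict the callback returns. The callback builds 0-d NumPy arrays with np.asarray rather than returning NumPy scalars, so that it returns ndarrays as the parameter description requires.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import host_callback as hcb


def host_stats(x, *, device):
    # call_with_device=True below, so the originating device arrives as a kwarg.
    del device  # Unused; accepted only to illustrate call_with_device.
    # Return a pytree (a dict) of 0-d float32 ndarrays.
    return dict(
        mean=np.asarray(np.mean(x), dtype=np.float32),
        maximum=np.asarray(np.max(x), dtype=np.float32),
    )


@jax.jit
def device_fn(x):
    # result_shape mirrors the pytree structure, shapes, and dtypes of the result.
    result_shape = dict(
        mean=jax.ShapeDtypeStruct((), jnp.float32),
        maximum=jax.ShapeDtypeStruct((), jnp.float32),
    )
    return hcb.call(host_stats, x, result_shape=result_shape,
                    call_with_device=True)


x = jnp.array([1.0, 2.0, 3.0, 6.0], dtype=jnp.float32)
print(device_fn(x))
```

If host_stats returned a dict with different keys, or an array of a different shape or dtype, the mismatch would be reported at runtime, as described under result_shape.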
For more details, see the module documentation.