jax.experimental.checkify.checkify
Warning
This page was created from a pull request (#9655).
jax.experimental.checkify.checkify(fun, errors=frozenset({ErrorCategory.USER_CHECK}))
Functionalize check calls in fun, and optionally add run-time error checks.
Run-time errors are either user-added checkify.check assertions or automatically added checks such as NaN checks, depending on the errors argument.
The returned function returns an Error object, err, along with the output of the original function.
err.get() returns either None (if no error occurred) or a string containing the error message of the first error that occurred. err.throw() raises a ValueError with that message if an error occurred.
By default only user-added checkify.check assertions are enabled. You can enable automatic checks through the errors argument.
- The check sets which can be enabled, and when each one generates an error:
  - user_checks: a checkify.check evaluated to False.
  - nan_checks: a floating-point operation generated a NaN value as output.
  - div_checks: a division by zero occurred.
  - index_checks: an index was out of bounds.
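A minimal sketch of the default behaviour described above, assuming a JAX release whose jax.experimental.checkify module exposes check and checkify as documented on this page. The function safe_log and its message are hypothetical. With no errors argument only user checks run, so err.get() is None for valid input and a message string when the assertion fails.

```python
import jax.numpy as jnp
from jax.experimental import checkify

def safe_log(x):
    # User-added assertion; the default errors set enables only these checks.
    checkify.check(jnp.all(x > 0), "x must be positive")
    return jnp.log(x)

# No errors argument: only user checks are enabled.
checked_log = checkify.checkify(safe_log)

# Valid input: the Error value carries no message.
ok_err, ok_out = checked_log(jnp.array([1.0, 2.0]))
print(ok_err.get())  # None

# Invalid input: get() returns the message of the first failed check.
bad_err, bad_out = checked_log(jnp.array([-1.0, 2.0]))
print(bad_err.get())
```

Note that log(-1.0) produces a NaN here, but it is not reported because nan_checks is not enabled by default; only the user assertion fires.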
Multiple categories can be enabled together by creating a set (e.g. errors={ErrorCategory.NAN, ErrorCategory.OOB}). Multiple sets can be combined with set union (e.g. errors=float_checks | user_checks).
- Parameters
  - fun (Callable[…, ~Out]) – Callable which can contain user checks (see check).
  - errors (FrozenSet[ErrorCategory]) – A set of ErrorCategory values which defines the set of enabled checks. By default only explicit checks are enabled (user_checks). You can also, for example, enable NAN and DIV errors by passing the float_checks set, or combine multiple sets through set operations (float_checks | user_checks).
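The errors parameter accepts a union of check sets. The sketch below assumes the float_checks and user_checks sets named above are importable from jax.experimental.checkify, and enables both at once; the function f and its threshold of 10 are hypothetical.

```python
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
    # Hypothetical user check alongside an operation that can produce NaN.
    checkify.check(x < 10.0, "x must be below 10")
    return jnp.log(x)  # log of a negative number produces NaN

# Union of two check sets: NaN/division checks plus user checks.
errors = checkify.float_checks | checkify.user_checks
checked_f = checkify.checkify(f, errors=errors)

err_ok, _ = checked_f(jnp.array(1.0))     # no check fails
err_nan, _ = checked_f(jnp.array(-1.0))   # NaN generated by log
err_user, _ = checked_f(jnp.array(20.0))  # the user check fails

for err in (err_ok, err_nan, err_user):
    print(err.get())
```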
- Returns
  A function which accepts the same arguments as fun and returns a pair: the first element is an Error value representing the first failed check, and the second element is the original output of fun.
- Return type
  Callable[…, Tuple[Error, ~Out]]
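To illustrate the returned pair, here is a sketch using index_checks, assuming that set is available as described above; take is a hypothetical helper. The second element is the ordinary output of the wrapped function, while an out-of-bounds index is reported through the Error value instead of raising. The exact message wording may differ between JAX versions, so the sketch only prints it.

```python
import jax.numpy as jnp
from jax.experimental import checkify

def take(x, i):
    # Plain indexing; under index_checks an out-of-bounds i is recorded.
    return x[i]

checked_take = checkify.checkify(take, errors=checkify.index_checks)
x = jnp.arange(4.0)

# In bounds: err.get() is None and out is what take(x, 2) returns.
err, out = checked_take(x, 2)
print(err.get(), out)

# Out of bounds: the problem is reported through the Error value.
oob_err, _ = checked_take(x, 10)
print(oob_err.get())
```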
For example:
>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>>
>>> @jax.jit
... def f(x):
...   y = jnp.sin(x)
...   return x+y
>>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
>>> err.throw()
Traceback (most recent call last):
  ...
ValueError: nan generated by primitive sin
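The doctest above exercises NaN detection through float_checks. A comparable sketch for div_checks on its own follows; int_divide is hypothetical, and the sketch uses integer operands on the assumption that the division check targets integer division (floating-point division by zero yields inf or NaN under IEEE 754 rather than an undefined result).

```python
import jax.numpy as jnp
from jax.experimental import checkify

def int_divide(x, y):
    # Integer floor division; a zero divisor has no defined result.
    return x // y

checked_divide = checkify.checkify(int_divide, errors=checkify.div_checks)

x = jnp.array([4, 6])
err_ok, q = checked_divide(x, jnp.array([2, 3]))     # no zero divisor
err_zero, _ = checked_divide(x, jnp.array([2, 0]))   # division by zero

print(err_ok.get(), q)
print(err_zero.get())
```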