jax.experimental.checkify.check
Warning: This page was created from a pull request (#9655).
jax.experimental.checkify.check(pred, msg)
Check a predicate and add an error with msg if the predicate is False.
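A minimal sketch of what "add an error" means in practice. It assumes a JAX version in which `checkify.Error.get()` returns the error message as a string, or `None` when no check failed. `safe_log` is a hypothetical function used only for illustration. Under `checkify.checkify` the failed check is recorded in the returned `Error` value instead of raising, and the function still returns its output. The sketch runs eagerly, without `jit`.

```python
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    """Hypothetical example function: log of x, checked for positivity."""
    # jnp.all reduces the elementwise test to a single scalar boolean predicate.
    checkify.check(jnp.all(x > 0), "log of non-positive value")
    return jnp.log(x)


# checkify turns the check into an explicit Error output next to the result.
checked_log = checkify.checkify(safe_log)

err_ok, out_ok = checked_log(jnp.array([1.0, 2.0]))
err_bad, out_bad = checked_log(jnp.array([1.0, -2.0]))

print(err_ok.get())   # None: the predicate held
print(err_bad.get())  # a message string containing "log of non-positive value"
```

Using `get()` lets the caller inspect the failure without raising; `err_bad.throw()` would raise it instead.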
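This is an effectful operation and cannot be staged (jitted/scanned/…) on its own. Before staging a function with checks, checkify it, i.e. wrap it with checkify.checkify, which returns the error alongside the function's output.

Parameters
  pred – the predicate to check; if it is False, an error is added.
  msg – the error message added when pred is False.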
For example:
>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>> def f(x):
...   checkify.check(x!=0, "cannot be zero!")
...   return 1/x
>>> checked_f = checkify.checkify(f)
>>> err, out = jax.jit(checked_f)(0)
>>> err.throw()
Traceback (most recent call last):
  ...
ValueError: cannot be zero! (check failed at ...)
Return type
  None
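The "checkify it before staging" rule also covers the other staging transformations, such as `jax.lax.scan`. The sketch below places a check inside a scan body, functionalizes it with `checkify.checkify`, and only then stages the result with `jax.jit`. It assumes that checkify threads a failure from inside the loop out to the returned `Error` value, and that `Error.get()` returns the message string or `None`. `cumulative_reciprocal` is a hypothetical helper, not part of the library.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def cumulative_reciprocal(xs):
    """Hypothetical helper: running sum of 1/x over xs, checked for zeros."""

    def body(carry, x):
        # The check sits inside the scan body; checkify carries any failure
        # out of the loop in the returned Error value.
        checkify.check(x != 0, "zero element in input")
        new_carry = carry + 1.0 / x
        return new_carry, new_carry

    # Scalar initial carry with the same dtype as the inputs.
    init = jnp.zeros((), dtype=xs.dtype)
    _, partial_sums = jax.lax.scan(body, init, xs)
    return partial_sums


# Functionalize the checks first, then stage the result with jit.
checked = jax.jit(checkify.checkify(cumulative_reciprocal))

err_ok, sums = checked(jnp.array([1.0, 2.0, 4.0]))
err_bad, _ = checked(jnp.array([1.0, 0.0, 4.0]))

print(err_ok.get())   # None: every element was non-zero
print(sums)           # running sums of 1/x: 1.0, 1.5, 1.75
print(err_bad.get())  # a message string containing "zero element in input"
```

The order matters: `jax.jit(checkify.checkify(f))` stages a function whose checks have already been turned into an explicit error output.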