jax.experimental.checkify.check

Warning

This page was created from a pull request (#9655).


jax.experimental.checkify.check(pred, msg)

Check a predicate and add an error with message msg if the predicate is False.

This is an effectful operation and cannot be staged (jitted/scanned/…) on its own. Before staging a function that contains checks, wrap it with checkify.checkify.
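The ordering described above can be sketched as follows: the function is checkified first, and only the checkified version is passed to jax.jit. This is a minimal illustration, not a canonical pattern; safe_log and its message are hypothetical names chosen for the example, and the default set of enabled errors in checkify.checkify is assumed to include user checks.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def safe_log(x):
    # Hypothetical helper: the predicate must reduce to a scalar boolean.
    checkify.check(jnp.all(x > 0), "input must be positive!")
    return jnp.log(x)

# Checkify first, then stage. The checkified function returns (error, output).
checked_log = jax.jit(checkify.checkify(safe_log))

err, out = checked_log(jnp.array([1.0, 2.0]))
err.throw()  # No check failed, so this does nothing.

# A failing input does not raise inside the jitted call; the failure
# is carried in the returned error value instead.
err_bad, _ = checked_log(jnp.array([1.0, -2.0]))
```

Because the error is returned as an ordinary value, the caller decides when (or whether) to raise it, for example by calling err_bad.throw() outside the staged computation.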

Parameters
  • pred (Union[bool, Tracer]) – if False, an error is added.

  • msg (str) – error message if error is added.

For example:

>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>> def f(x):
...   checkify.check(x != 0, "cannot be zero!")
...   return 1/x
>>> checked_f = checkify.checkify(f)  # functionalize the check before staging
>>> err, out = jax.jit(checked_f)(0)  # the error comes back as a value
>>> err.throw()
Traceback (most recent call last):
  ...
ValueError: cannot be zero! (check failed at ...)

Return type

None
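Since check itself returns None, the outcome of a check is only visible through the error value that the checkified function returns. As a hedged sketch, the error can also be inspected without raising; this assumes the returned error object exposes a get() method that yields None when no check failed and the message string otherwise.

```python
import jax
from jax.experimental import checkify

def f(x):
    checkify.check(x != 0, "cannot be zero!")
    return 1 / x

checked_f = jax.jit(checkify.checkify(f))

err_ok, out_ok = checked_f(2.0)    # check passes
err_bad, out_bad = checked_f(0.0)  # check fails, but nothing is raised here

msg_ok = err_ok.get()    # no error recorded
msg_bad = err_bad.get()  # message of the failed check

if msg_bad is not None:
    print("check failed:", msg_bad)
```

This style suits code paths where a failed check should be logged or handled rather than turned into an exception.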