jax.experimental.checkify.check_error

Warning

This page was created from a pull request (#9655).


jax.experimental.checkify.check_error(error)

Raise an Exception if error represents a failure. This function can be functionalized by checkify.

The semantics of this function are equivalent to:

>>> def check_error(err: Error) -> None:
...   err.throw()  # can raise ValueError

But unlike that implementation, check_error can be functionalized using the checkify transformation.
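
A minimal sketch of the stated equivalence when called eagerly, assuming the exception raised on failure is a ValueError (or a subclass of it), as the "# can raise ValueError" comment above suggests. The helper raised_message is hypothetical and exists only to capture the exception text; f mirrors the function used in the example later on this page.

```python
from jax.experimental import checkify

def f(x):
  checkify.check(x > 0, "must be positive!")
  return x

# Functionalize f so that failed checks come back as Error values.
bad_err, _ = checkify.checkify(f)(-1.0)
good_err, _ = checkify.checkify(f)(1.0)

def raised_message(thunk):
  """Hypothetical helper: run thunk() and return the exception text, or None."""
  try:
    thunk()
  except ValueError as e:  # assumption: the raised type subclasses ValueError
    return str(e)
  return None

# Eagerly, check_error(err) behaves like err.throw(): both raise with the message.
msg_check_error = raised_message(lambda: checkify.check_error(bad_err))
msg_throw = raised_message(bad_err.throw)

# An Error that records no failure raises nothing either way.
assert raised_message(lambda: checkify.check_error(good_err)) is None
assert raised_message(good_err.throw) is None
```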

This function is similar to check but has a different signature: whereas check takes a boolean predicate and an error message string as arguments, this function takes an Error value. Both check and this function raise a Python Exception on failure (a side effect), and thus cannot be staged out by jit, pmap, scan, etc. Both can, however, be functionalized using checkify.
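
The sketch below contrasts the two paths for check, under the same ValueError assumption: called eagerly, a failing check(predicate, message) raises immediately, while functionalizing it with checkify first lets jit stage the function out and returns the failure as an Error value. The function g is illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def g(x):
  # check takes a boolean predicate and an error message string.
  checkify.check(x > 0, "must be positive!")
  return x + 1

# Called eagerly (no jit, no checkify), a failed check raises right away.
try:
  g(jnp.asarray(-1.0))
  eager_raised = False
except ValueError:  # assumption: the raised type subclasses ValueError
  eager_raised = True

# To stage g out, functionalize it first; the failure then travels as an Error value.
err, out = jax.jit(checkify.checkify(g))(-1.0)
```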

Unlike check, however, this function acts as a direct inverse of checkify. checkify takes as input a function that can raise a Python Exception and produces a new function without that effect, which instead returns an Error value as output; check_error accepts an Error value as input and can produce the side effect of raising an Exception. That is, checkify goes from a functionalizable Exception effect to an error value, while check_error goes from an error value to a functionalizable Exception effect.
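
To make the inverse relationship concrete, here is a sketch built around a hypothetical helper, uncheckify, which is not part of JAX: it takes a checkified function returning (Error, out) and uses check_error to turn the Error back into an exception. Composing it with checkify round-trips between the two representations (again assuming the raised exception is a ValueError).

```python
from jax.experimental import checkify

def f(x):
  checkify.check(x > 0, "must be positive!")
  return 2 * x

def uncheckify(checked_fn):
  """Hypothetical helper (not a JAX API): the reverse direction of checkify.

  Wraps a function returning (Error, out) so that it raises on a failed
  check and otherwise returns out.
  """
  def fn(*args):
    err, out = checked_fn(*args)
    checkify.check_error(err)  # Error value -> functionalizable Exception effect
    return out
  return fn

# checkify: Exception effect -> Error value.
err1, _ = checkify.checkify(f)(-1.0)

# uncheckify(checkify(f)) returns normally on valid input, like f itself.
out_ok = uncheckify(checkify.checkify(f))(3.0)

# Because check_error is functionalizable, checkify-ing again recovers an Error value.
err2, _ = checkify.checkify(uncheckify(checkify.checkify(f)))(-1.0)
```

In this sketch, uncheckify(checkify.checkify(f)) behaves like f itself, which is the sense in which check_error undoes checkify.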

check_error is useful when you want to turn checks represented by an Error value (produced by functionalizing checks via checkify) back into Python Exceptions.

Parameters

error (Error) – Error to check.

For example, you might want to functionalize part of your program through checkify, stage out your functionalized code through jit, then re-inject your error value outside of the jit:

>>> import jax
>>> from jax.experimental import checkify
>>> def f(x):
...   checkify.check(x > 0, "must be positive!")
...   return x
>>> def with_inner_jit(x):
...   checked_f = checkify.checkify(f)
...   # a checkified function can be jitted
...   error, out = jax.jit(checked_f)(x)
...   checkify.check_error(error)
...   return out
>>> _ = with_inner_jit(1)  # no failed check
>>> with_inner_jit(-1)  
Traceback (most recent call last):
  ...
ValueError: must be positive!
>>> # can re-checkify
>>> error, _ = checkify.checkify(with_inner_jit)(-1)
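
Because the re-checkified function no longer has the Exception side effect, it should itself be stageable with jit, and the resulting Error can be inspected with its get method (which returns the message string, or None when no check failed) instead of being raised. A minimal sketch, repeating the definitions above so the block is self-contained:

```python
import jax
from jax.experimental import checkify

def f(x):
  checkify.check(x > 0, "must be positive!")
  return x

def with_inner_jit(x):
  error, out = jax.jit(checkify.checkify(f))(x)
  checkify.check_error(error)
  return out

# Re-checkify, then stage the whole outer function out with jit.
checked_outer = jax.jit(checkify.checkify(with_inner_jit))
error, _ = checked_outer(-1)        # failed check, reported as a value
ok_error, ok_out = checked_outer(1)  # no failed check
```
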

Return type

None