jax.disable_jit
Warning
This page was created from a pull request (#9655).
- jax.disable_jit()
Context manager that disables jit() behavior under its dynamic context.

For debugging it is useful to have a mechanism that disables jit() everywhere in a dynamic context.

Values that have a data dependence on the arguments to a jitted function are traced and abstracted. For example, an abstract value may be a ShapedArray instance, representing the set of all possible arrays with a given shape and dtype, but not representing one concrete array with specific values. You might notice those if you use a benign side-effecting operation in a jitted function, like a print:

>>> import jax
>>>
>>> @jax.jit
... def f(x):
...   y = x * 2
...   print("Value of y is", y)
...   return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=0/1)>
[5 7 9]
Here y has been abstracted by jit() to a ShapedArray, which represents an array with a fixed shape and type but an arbitrary value. The value of y is also traced. If we want to see a concrete value while debugging, and avoid the tracer too, we can use the disable_jit() context manager:

>>> import jax
>>>
>>> with jax.disable_jit():
...   print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]
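
The difference between the two calls can also be checked programmatically rather than by reading printed output. The sketch below is illustrative only: the list tracer_flags and the variable names are hypothetical, and it assumes jax.core.Tracer is available as the base class of traced values in the installed JAX version.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper list: records, for each execution of the Python body,
# whether the intermediate value y was a tracer.
tracer_flags = []


@jax.jit
def f(x):
    y = x * 2
    # Side effect that runs whenever the Python body of f is executed.
    # Assumption: jax.core.Tracer is the base class of traced values.
    tracer_flags.append(isinstance(y, jax.core.Tracer))
    return y + 3


x = jnp.array([1, 2, 3])

# Normal call: the body is traced, so y is an abstract tracer.
jitted_out = f(x)

# Under disable_jit the body runs eagerly, so y is a concrete array.
with jax.disable_jit():
    eager_out = f(x)

print(tracer_flags)
print(jitted_out, eager_out)
```

Both calls should return the same result; only the kind of value seen inside the function body changes, which is what makes disable_jit() useful for inspecting intermediates with ordinary Python tools.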