{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "O-SkdlPxvETZ" }, "source": [ "# Just In Time Compilation with JAX\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/02-jitting.ipynb)\n", "\n", "*Authors: Rosalia Schneider & Vladimir Mikulik*\n", "\n", "In this section, we will further explore how JAX works, and how we can make it performant.\n", "We will discuss the `jax.jit()` transform, which will perform *Just In Time* (JIT) compilation\n", "of a JAX Python function so it can be executed efficiently in XLA.\n", "\n", "## How JAX transforms work\n", "\n", "In the previous section, we discussed that JAX allows us to transform Python functions. This is done by first converting the Python function into a simple intermediate language called jaxpr. The transformations then work on the jaxpr representation. \n", "\n", "We can show a representation of the jaxpr of a function by using `jax.make_jaxpr`:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "P9Xj77Wx3Z2P", "outputId": "5a0597eb-86c9-4762-ce10-2811debbc732" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{ lambda ; a:f32[]. 
let\n", " b:f32[] = log a\n", " c:f32[] = log 2.0\n", " d:f32[] = div b c\n", " in (d,) }\n" ] } ], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "global_list = []\n", "\n", "def log2(x):\n", "  global_list.append(x)\n", "  ln_x = jnp.log(x)\n", "  ln_2 = jnp.log(2.0)\n", "  return ln_x / ln_2\n", "\n", "print(jax.make_jaxpr(log2)(3.0))" ] }, { "cell_type": "markdown", "metadata": { "id": "jiDsT7y0RwIp" }, "source": [ "The [Understanding Jaxprs](https://jax.readthedocs.io/en/latest/jaxpr.html) section of the documentation provides more information on the meaning of the above output.\n", "\n", "Importantly, note how the jaxpr does not capture the side-effect of the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX is designed to understand side-effect-free (a.k.a. functionally pure) code. If *pure function* and *side-effect* are unfamiliar terms, they are explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).\n", "\n", "Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour once converted to jaxpr. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JAX-transformed function to run once (during the first call), and never again. This is because of the way that JAX generates jaxpr, using a process called 'tracing'.\n", "\n", "When tracing, JAX wraps each argument in a *tracer* object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record Python side-effects, those side-effects do not appear in the jaxpr. 
However, the side-effects still happen during the trace itself.\n", "\n", "Note: the Python `print()` function is not pure: the text output is a side-effect of the function. Therefore, any `print()` calls will only happen during tracing, and will not appear in the jaxpr:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "JxV2p7e2RawC", "outputId": "9dfe8a56-e553-4640-a04e-5405aea7832d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "printed x: Tracedwith\n", "{ lambda ; a:f32[]. let\n", " b:f32[] = log a\n", " c:f32[] = log 2.0\n", " d:f32[] = div b c\n", " in (d,) }\n" ] } ], "source": [ "def log2_with_print(x):\n", "  print(\"printed x:\", x)\n", "  ln_x = jnp.log(x)\n", "  ln_2 = jnp.log(2.0)\n", "  return ln_x / ln_2\n", "\n", "print(jax.make_jaxpr(log2_with_print)(3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "f6W_YYwRRwGp" }, "source": [ "See how the printed `x` is a `Traced` object? That's the JAX internals at work.\n", "\n", "The fact that the Python code runs at least once is strictly an implementation detail, and so shouldn't be relied upon. However, it's useful to understand as you can use it when debugging to print out intermediate values of a computation." ] }, { "cell_type": "markdown", "metadata": { "id": "PgVqi6NlRdWZ" }, "source": [ "A key thing to understand is that jaxpr captures the function as executed on the parameters given to it. For example, if we have a conditional, jaxpr will only know about the branch we take:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "hn0CuphEZKZm", "outputId": "99dae727-d2be-4577-831c-e1e14af5890a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{ lambda ; a:i32[3]. 
let in (a,) }\n" ] } ], "source": [ "def log2_if_rank_2(x):\n", "  if x.ndim == 2:\n", "    ln_x = jnp.log(x)\n", "    ln_2 = jnp.log(2.0)\n", "    return ln_x / ln_2\n", "  else:\n", "    return x\n", "\n", "print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))" ] }, { "cell_type": "markdown", "metadata": { "id": "Qp3WhqaqvHyD" }, "source": [ "## JIT compiling a function\n", "\n", "As explained before, JAX enables operations to execute on CPU/GPU/TPU using the same code.\n", "Let's look at an example of computing a *Scaled Exponential Linear Unit*\n", "([SELU](https://proceedings.neurips.cc/paper/6698-self-normalizing-neural-networks.pdf)), an\n", "operation commonly used in deep learning:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "JAXFYtlRvD6p", "outputId": "e94d7dc2-a9a1-4ac2-fd3f-152e3f6d141b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100 loops, best of 5: 2.05 ms per loop\n" ] } ], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "def selu(x, alpha=1.67, lambda_=1.05):\n", "  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n", "\n", "x = jnp.arange(1000000)\n", "%timeit selu(x).block_until_ready()" ] }, { "cell_type": "markdown", "metadata": { "id": "ecN5lEXe6ncy" }, "source": [ "The code above is sending one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our functions.\n", "\n", "Naturally, what we want to do is give the XLA compiler as much code as possible, so it can fully optimize it. For this purpose, JAX provides the `jax.jit` transformation, which will JIT compile a JAX-compatible function. The example below shows how to use JIT to speed up the previous function."
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "nJVEwPcH6bQX", "outputId": "289eb2f7-a5ce-4cec-f652-5c4e5b0b86cf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10000 loops, best of 5: 150 µs per loop\n" ] } ], "source": [ "selu_jit = jax.jit(selu)\n", "\n", "# Warm up\n", "selu_jit(x).block_until_ready()\n", "\n", "%timeit selu_jit(x).block_until_ready()" ] }, { "cell_type": "markdown", "metadata": { "id": "hMNKi1mYXQg5" }, "source": [ "Here's what just happened:\n", "\n", "1) We defined `selu_jit` as the JIT-wrapped version of `selu`. Nothing is compiled at this point.\n", "\n", "2) We ran `selu_jit` once on `x`. This is where JAX does its tracing -- it needs to have some inputs to wrap in tracers, after all. The jaxpr is then compiled using XLA into very efficient code optimized for your CPU, GPU, or TPU. Subsequent calls to `selu_jit` will now use that code, skipping our old Python implementation entirely.\n", "\n", "(If we didn't include the warm-up call separately, everything would still work, but then the compilation time would be included in the benchmark. It would still be faster, because we run many loops in the benchmark, but it wouldn't be a fair comparison.)\n", "\n", "3) We timed the execution speed of the compiled version. (Note the use of `block_until_ready()`, which is required due to JAX's [Asynchronous execution](https://jax.readthedocs.io/en/latest/async_dispatch.html) model)." ] }, { "cell_type": "markdown", "metadata": { "id": "DRJ6R6-d9Q_U" }, "source": [ "## Why can't we just JIT everything?\n", "\n", "After going through the example above, you might be wondering whether we should simply apply `jax.jit` to every function. To understand why this is not the case, and when we should/shouldn't apply `jit`, let's first check some cases where JIT doesn't work."
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "GO1Mwd_3_W6g", "outputId": "a6fcf6d1-7bd6-4bb7-99c3-2a5a827183e2", "tags": [ "raises-exception" ] }, "outputs": [ { "ename": "ConcretizationTypeError", "evalue": "ignored", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mf_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mf_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Should raise an error.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mcache_miss\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 418\u001b[0m 
\u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mflat_fun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 419\u001b[0;31m donated_invars=donated_invars, inline=inline)\n\u001b[0m\u001b[1;32m 420\u001b[0m \u001b[0mout_pytree_def\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1631\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1632\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcall_bind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1633\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1622\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1623\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1624\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_todos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_trace_todo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess\u001b[0;34m(self, trace, fun, tracers, params)\u001b[0m\n\u001b[1;32m 1634\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1635\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1636\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess_call\u001b[0;34m(self, primitive, f, tracers, params)\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 627\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 628\u001b[0m \u001b[0mprocess_map\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_call_impl\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 687\u001b[0m compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0;32m--> 688\u001b[0;31m *unsafe_map(arg_spec, args))\n\u001b[0m\u001b[1;32m 689\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mmemoized_fun\u001b[0;34m(fun, *args)\u001b[0m\n\u001b[1;32m 262\u001b[0m 
\u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_callable_uncached\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 759\u001b[0m return lower_xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0;32m--> 760\u001b[0;31m *arg_specs).compile().unsafe_call\n\u001b[0m\u001b[1;32m 761\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36mlower_xla_callable\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 771\u001b[0m jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(\n\u001b[0;32m--> 772\u001b[0;31m fun, abstract_args, pe.debug_info_final(fun, \"jit\"))\n\u001b[0m\u001b[1;32m 773\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTracer\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mc\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mconsts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_final\u001b[0;34m(fun, in_avals, debug_info)\u001b[0m\n\u001b[1;32m 1541\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1542\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_dynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1543\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_subjaxpr_dynamic\u001b[0;34m(fun, main, in_avals)\u001b[0m\n\u001b[1;32m 1519\u001b[0m \u001b[0min_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_arg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0min_tracers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m 
\u001b[0mout_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 167\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mf\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36m__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 548\u001b[0m 
\u001b[0;32mdef\u001b[0m \u001b[0m__nonzero__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nonzero\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 549\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0m__bool__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_bool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 550\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__int__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_int\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36merror\u001b[0;34m(self, arg)\u001b[0m\n\u001b[1;32m 999\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1000\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mConcretizationTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfname_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1001\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0merror\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mUnfilteredStackTrace\u001b[0m: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the bool function. \nWhile tracing the function f at :3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------", "\nThe above exception was the direct cause of the following exception:\n", "\u001b[0;31mConcretizationTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mf_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mf_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Should raise an error.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36mf\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m 
\u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mConcretizationTypeError\u001b[0m: Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the bool function. \nWhile tracing the function f at :3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError" ] } ], "source": [ "# Condition on value of x.\n", "\n", "def f(x):\n", "  if x > 0:\n", "    return x\n", "  else:\n", "    return 2 * x\n", "\n", "f_jit = jax.jit(f)\n", "f_jit(10)  # Should raise an error." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "LHlipkIMFUhi", "outputId": "54935882-a180-45c0-ad03-9dfb5e3baa97", "tags": [ "raises-exception" ] }, "outputs": [ { "ename": "ConcretizationTypeError", "evalue": "ignored", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mg_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mg_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Should raise an error.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py\u001b[0m in 
\u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mcache_miss\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mflat_fun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 419\u001b[0;31m donated_invars=donated_invars, inline=inline)\n\u001b[0m\u001b[1;32m 420\u001b[0m \u001b[0mout_pytree_def\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1631\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1632\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcall_bind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1633\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1622\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1623\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1624\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_todos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_trace_todo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in 
\u001b[0;36mprocess\u001b[0;34m(self, trace, fun, tracers, params)\u001b[0m\n\u001b[1;32m 1634\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1635\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1636\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess_call\u001b[0;34m(self, primitive, f, tracers, params)\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 627\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 628\u001b[0m \u001b[0mprocess_map\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mprocess_call\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_call_impl\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 687\u001b[0m compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0;32m--> 688\u001b[0;31m *unsafe_map(arg_spec, args))\n\u001b[0m\u001b[1;32m 689\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mmemoized_fun\u001b[0;34m(fun, *args)\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_callable_uncached\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 759\u001b[0m return lower_xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0;32m--> 760\u001b[0;31m *arg_specs).compile().unsafe_call\n\u001b[0m\u001b[1;32m 761\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in 
\u001b[0;36mlower_xla_callable\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 771\u001b[0m jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(\n\u001b[0;32m--> 772\u001b[0;31m fun, abstract_args, pe.debug_info_final(fun, \"jit\"))\n\u001b[0m\u001b[1;32m 773\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTracer\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_final\u001b[0;34m(fun, in_avals, debug_info)\u001b[0m\n\u001b[1;32m 1541\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1542\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_dynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1543\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_subjaxpr_dynamic\u001b[0;34m(fun, main, in_avals)\u001b[0m\n\u001b[1;32m 1519\u001b[0m 
\u001b[0min_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_arg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0min_tracers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0mout_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 167\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mg\u001b[0;34m(x, n)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36m__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 548\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__nonzero__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nonzero\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 549\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0m__bool__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_bool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 550\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__int__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_int\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36merror\u001b[0;34m(self, arg)\u001b[0m\n\u001b[1;32m 999\u001b[0m 
\u001b[0;32mdef\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1000\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mConcretizationTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfname_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1001\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mUnfilteredStackTrace\u001b[0m: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the bool function. \nWhile tracing the function g at :3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'n'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------", "\nThe above exception was the direct cause of the following exception:\n", "\u001b[0;31mConcretizationTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mg_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mg_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Should raise an 
error.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36mg\u001b[0;34m(x, n)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mConcretizationTypeError\u001b[0m: Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the bool function. \nWhile tracing the function g at :3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'n'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError" ] } ], "source": [ "# While loop conditioned on x and n.\n", "\n", "def g(x, n):\n", " i = 0\n", " while i < n:\n", " i += 1\n", " return x + i\n", "\n", "g_jit = jax.jit(g)\n", "g_jit(10, 20) # Should raise an error. " ] }, { "cell_type": "markdown", "metadata": { "id": "isz2U_XX_wH2" }, "source": [ "The problem is that we tried to condition on the *value* of an input to the function being jitted. The reason we can't do this is related to the fact mentioned above that jaxpr depends on the actual values used to trace it. 
\n", "\n", "The more specific the information we have about the values used in the trace, the more we can use standard Python control flow to express ourselves. However, being too specific means we can't reuse the same traced function for other values. JAX solves this by tracing at different levels of abstraction for different purposes.\n", "\n", "For `jax.jit`, the default level is `ShapedArray` -- that is, each tracer has a concrete shape (which we're allowed to condition on), but no concrete value. This allows the compiled function to work on all possible inputs with the same shape -- the standard use case in machine learning. However, because the tracers have no concrete value, if we attempt to condition on one, we get the error above.\n", "\n", "In `jax.grad`, the constraints are more relaxed, so you can do more, including using ordinary Python control flow on values. If you compose several transformations, however, you must satisfy the constraints of the strictest one. So, if you `jit(grad(f))`, `f` mustn't condition on the value of its inputs. For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).\n", "\n", "One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special [control flow operators](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) like `jax.lax.cond` or `jax.lax.while_loop`. However, sometimes that is impossible. In that case, you can consider jitting only part of the function. 
For example, if the most computationally expensive part of the function is inside the loop, we can JIT just that inner part (but be sure to read the section on caching below to avoid shooting yourself in the foot):" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "OeR8hF-NHAML", "outputId": "d47fd6b2-8bbd-4939-a794-0b80183d3179" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(30, dtype=int32, weak_type=True)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# While loop conditioned on n, with a jitted body.\n", "\n", "@jax.jit\n", "def loop_body(prev_i):\n", "  return prev_i + 1\n", "\n", "def g_inner_jitted(x, n):\n", "  i = 0\n", "  while i < n:\n", "    i = loop_body(i)\n", "  return x + i\n", "\n", "g_inner_jitted(10, 20)" ] }, { "cell_type": "markdown", "metadata": { "id": "5XUT2acoHBz-" }, "source": [ "If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to use a less abstract tracer for a particular input by specifying `static_argnums`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified input. This is only a good strategy if the function is guaranteed to see a limited number of distinct values for that input." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "2yQmQTDNAenY", "outputId": "c48f07b8-c3f9-4d2a-9dfd-663838a52511" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10\n" ] } ], "source": [ "f_jit_correct = jax.jit(f, static_argnums=0)\n", "print(f_jit_correct(10))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "R4SXUEu-M-u1", "outputId": "9e712e14-4e81-4744-dcf2-a10f470d9121" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "30\n" ] } ], "source": [ "g_jit_correct = jax.jit(g, static_argnums=1)\n", "print(g_jit_correct(10, 20))" ] }, { "cell_type": "markdown", "metadata": { "id": "LczjIBt2X2Ms" }, "source": [ "## When to use JIT\n", "\n", "In many of the examples above, jitting is not worth it:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "uMOqsNnqYApD", "outputId": "2d6c5122-43ad-4257-e56b-e77c889131c2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "g jitted:\n", "The slowest run took 13.54 times longer than the fastest. This could mean that an intermediate result is being cached.\n", "1000 loops, best of 5: 229 µs per loop\n", "g:\n", "The slowest run took 11.72 times longer than the fastest. This could mean that an intermediate result is being cached.\n", "1000000 loops, best of 5: 1.2 µs per loop\n" ] } ], "source": [ "print(\"g jitted:\")\n", "%timeit g_jit_correct(10, 20).block_until_ready()\n", "\n", "print(\"g:\")\n", "%timeit g(10, 20)" ] }, { "cell_type": "markdown", "metadata": { "id": "cZmGYq80YP0j" }, "source": [ "This is because `jax.jit` introduces some overhead itself. Therefore, it usually only saves time if the compiled function is complex and you will run it numerous times. 
Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations.\n", "\n", "Generally, you want to jit the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise." ] }, { "cell_type": "markdown", "metadata": { "id": "hJMjUlRcIzVS" }, "source": [ "## Caching\n", "\n", "It's important to understand the caching behaviour of `jax.jit`.\n", "\n", "Suppose I define `f = jax.jit(g)`. When I first invoke `f`, it will get compiled, and the resulting XLA code will get cached. Subsequent calls to `f` with inputs of the same shape and dtype will reuse the cached code. This is how `jax.jit` makes up for the up-front cost of compilation.\n", "\n", "If I specify `static_argnums`, then the cached code will be used only for the same values of arguments labelled as static. If any of them change, recompilation occurs. If there are many values, then your program might spend more time compiling than it would have executing ops one-by-one.\n", "\n", "Avoid calling `jax.jit` inside loops. In most cases, JAX will be able to use the compiled, cached function in subsequent calls to `jax.jit`. However, because the cache relies on the hash of the function, problems arise when an equivalent function is redefined, since the new function object has a different hash. 
This causes unnecessary recompilation on every iteration of the loop:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "6MDSXCfmSZVZ", "outputId": "a035d0b7-6a4d-4a9e-c6b4-7521970829fc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "jit called in a loop with partials:\n", "1 loop, best of 5: 192 ms per loop\n", "jit called in a loop with lambdas:\n", "10 loops, best of 5: 199 ms per loop\n", "jit called in a loop with caching:\n", "10 loops, best of 5: 21.6 ms per loop\n" ] } ], "source": [ "from functools import partial\n", "\n", "def unjitted_loop_body(prev_i):\n", "  return prev_i + 1\n", "\n", "def g_inner_jitted_partial(x, n):\n", "  i = 0\n", "  while i < n:\n", "    # Don't do this! Each call to partial returns\n", "    # a new function with a different hash.\n", "    i = jax.jit(partial(unjitted_loop_body))(i)\n", "  return x + i\n", "\n", "def g_inner_jitted_lambda(x, n):\n", "  i = 0\n", "  while i < n:\n", "    # Don't do this! Each lambda is also a new\n", "    # function with a different hash.\n", "    i = jax.jit(lambda x: unjitted_loop_body(x))(i)\n", "  return x + i\n", "\n", "def g_inner_jitted_normal(x, n):\n", "  i = 0\n", "  while i < n:\n", "    # This is OK, since JAX can find the\n", "    # cached, compiled function.\n", "    i = jax.jit(unjitted_loop_body)(i)\n", "  return x + i\n", "\n", "print(\"jit called in a loop with partials:\")\n", "%timeit g_inner_jitted_partial(10, 20).block_until_ready()\n", "\n", "print(\"jit called in a loop with lambdas:\")\n", "%timeit g_inner_jitted_lambda(10, 20).block_until_ready()\n", "\n", "print(\"jit called in a loop with caching:\")\n", "%timeit g_inner_jitted_normal(10, 20).block_until_ready()" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "Jitting functions in JAX", "provenance": [] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }