{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "tCOWitsAS1EE" }, "source": [ "# Parallel Evaluation in JAX\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/06-parallelism.ipynb)\n", "\n", "*Authors: Vladimir Mikulik & Roman Ring*\n", "\n", "In this section we will discuss the facilities built into JAX for single-program, multiple-data (SPMD) code.\n", "\n", "SPMD refers to a parallelism technique where the same computation (e.g., the forward pass of a neural net) is run on different input data (e.g., different inputs in a batch) in parallel on different devices (e.g., several TPUs).\n", "\n", "Conceptually, this is not very different from vectorisation, where the same operations occur in parallel in different parts of memory on the same device. We have already seen that vectorisation is supported in JAX as a program transformation, `jax.vmap`. JAX supports device parallelism analogously, using `jax.pmap` to transform a function written for one device into a function that runs in parallel on multiple devices. This colab will teach you all about it." 
] }, { "cell_type": "markdown", "metadata": { "id": "7mCgBzix2fd3" }, "source": [ "## Colab TPU Setup\n", "\n", "If you're running this code in Google Colab, be sure to choose *Runtime*→*Change Runtime Type* and select **TPU** from the Hardware Accelerator menu.\n", "\n", "Once this is done, you can run the following to set up the Colab TPU for use with JAX:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "hn7HtC2QS92b" }, "outputs": [], "source": [ "import jax.tools.colab_tpu\n", "jax.tools.colab_tpu.setup_tpu()" ] }, { "cell_type": "markdown", "metadata": { "id": "gN6VbcdRTcdE" }, "source": [ "Next, run the following to see the TPU devices you have available:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "tqbpCcqY3Cn7", "outputId": "1fb88cf7-35f7-4565-f370-51586213b988" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, host_id=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, host_id=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, host_id=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, host_id=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, host_id=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, host_id=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, host_id=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, host_id=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 2, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "import jax\n", "jax.devices()" ] }, { "cell_type": "markdown", "metadata": { "id": "4_EDa0Dlgtf8" }, "source": [ "## The basics\n", "\n", "The most basic use of `jax.pmap` is completely analogous to `jax.vmap`, so let's return to the convolution example from the [Vectorisation notebook](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/03-vectorization.ipynb)." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "IIQKBr-CgtD2", "outputId": "6e7f8755-fdfd-4cf9-e2b5-a10c5a870dd4" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([11., 20., 29.], dtype=float32)" ] }, "execution_count": 5, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import jax.numpy as jnp\n", "\n", "x = np.arange(5)\n", "w = np.array([2., 3., 4.])\n", "\n", "def convolve(x, w):\n", " output = []\n", " for i in range(1, len(x)-1):\n", " output.append(jnp.dot(x[i-1:i+2], w))\n", " return jnp.array(output)\n", "\n", "convolve(x, w)" ] }, { "cell_type": "markdown", "metadata": { "id": "lqxz9NNJOQ9Z" }, "source": [ "Now, let's convert our `convolve` function into one that runs on entire batches of data. In anticipation of spreading the batch across several devices, we'll make the batch size equal to the number of devices:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "ll-hEa0jihzx", "outputId": "788be05a-10d4-4a05-8d9d-49d0083541ab" }, "outputs": [ { "data": { "text/plain": [ "array([[ 0, 1, 2, 3, 4],\n", " [ 5, 6, 7, 8, 9],\n", " [10, 11, 12, 13, 14],\n", " [15, 16, 17, 18, 19],\n", " [20, 21, 22, 23, 24],\n", " [25, 26, 27, 28, 29],\n", " [30, 31, 32, 33, 34],\n", " [35, 36, 37, 38, 39]])" ] }, "execution_count": 6, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "n_devices = jax.local_device_count() \n", "xs = np.arange(5 * n_devices).reshape(-1, 5)\n", "ws = np.stack([w] * n_devices)\n", "\n", "xs" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "mi-nysDWYbn4", "outputId": "2d115fc3-52f5-4a68-c3a7-115111a83657" }, "outputs": [ { "data": { "text/plain": [ "array([[2., 3., 4.],\n", " [2., 3., 4.],\n", " [2., 3., 4.],\n", " [2., 3., 4.],\n", " [2., 3., 4.],\n", " [2., 3., 4.],\n", " [2., 3., 4.],\n", " [2., 3., 4.]])" ] }, "execution_count": 7, "metadata": { "tags": [] }, "output_type": "execute_result" } 
], "source": [ "ws" ] }, { "cell_type": "markdown", "metadata": { "id": "8kseIB09YWJw" }, "source": [ "As before, we can vectorise using `jax.vmap`:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "TNb9HsFXYVOI", "outputId": "2e60e07a-6687-49ab-a455-60d2ec484363" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[ 11., 20., 29.],\n", " [ 56., 65., 74.],\n", " [101., 110., 119.],\n", " [146., 155., 164.],\n", " [191., 200., 209.],\n", " [236., 245., 254.],\n", " [281., 290., 299.],\n", " [326., 335., 344.]], dtype=float32)" ] }, "execution_count": 8, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "jax.vmap(convolve)(xs, ws)" ] }, { "cell_type": "markdown", "metadata": { "id": "TDF1vzt_5GMC" }, "source": [ "To spread out the computation across multiple devices, just replace `jax.vmap` with `jax.pmap`:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "KWoextrails4", "outputId": "bad1fbb7-226a-4538-e442-20ce0c1c8fad" }, "outputs": [ { "data": { "text/plain": [ "ShardedDeviceArray([[ 11., 20., 29.],\n", " [ 56., 65., 74.],\n", " [101., 110., 119.],\n", " [146., 155., 164.],\n", " [191., 200., 209.],\n", " [236., 245., 254.],\n", " [281., 290., 299.],\n", " [326., 335., 344.]], dtype=float32)" ] }, "execution_count": 9, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "jax.pmap(convolve)(xs, ws)" ] }, { "cell_type": "markdown", "metadata": { "id": "E69cVxQPksxe" }, "source": [ "Note that the parallelized `convolve` returns a `ShardedDeviceArray`. That is because the elements of this array are sharded across all of the devices used in the parallelism. If we were to run another parallel computation, the elements would stay on their respective devices, without incurring cross-device communication costs." 
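, "\n", "\n", "The following is a minimal sketch of how to see where the pieces of a pmapped result live. It assumes a recent JAX release, in which `pmap` returns a `jax.Array` rather than a separate `ShardedDeviceArray` type, and it uses that array's `addressable_shards` property. The function `double` is a hypothetical stand-in for `convolve`:

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()


def double(x):
    # Hypothetical stand-in for `convolve`: any per-device function works.
    return 2 * x


# One row per local device, because pmap maps over the leading axis.
xs = jnp.arange(n_devices * 3, dtype=jnp.float32).reshape(n_devices, 3)
out = jax.pmap(double)(xs)

# Each addressable shard is the piece of `out` stored on one device.
for shard in out.addressable_shards:
    print(shard.device, shard.data)
```

On the 8-core TPU runtime shown above, this should print 8 distinct devices, each holding one row of the result. A default CPU runtime typically exposes a single device."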
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "P9dUyk-ciquy", "outputId": "99ea4c6e-cff7-4611-e9e5-bf016fa9716c" }, "outputs": [ { "data": { "text/plain": [ "ShardedDeviceArray([[ 78., 138., 198.],\n", " [ 1188., 1383., 1578.],\n", " [ 3648., 3978., 4308.],\n", " [ 7458., 7923., 8388.],\n", " [12618., 13218., 13818.],\n", " [19128., 19863., 20598.],\n", " [26988., 27858., 28728.],\n", " [36198., 37203., 38208.]], dtype=float32)" ] }, "execution_count": 11, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "jax.pmap(convolve)(xs, jax.pmap(convolve)(xs, ws))" ] }, { "cell_type": "markdown", "metadata": { "id": "iuHqht-OYqca" }, "source": [ "The outputs of the inner `jax.pmap(convolve)` never left their devices when they were fed into the outer `jax.pmap(convolve)`." ] }, { "cell_type": "markdown", "metadata": { "id": "vEFAJXN2q3dV" }, "source": [ "## Specifying `in_axes`\n", "\n", "As with `vmap`, we can use `in_axes` to specify whether an argument to the parallelized function should be broadcast (`None`), or whether it should be split along a given axis. Note, however, that unlike `vmap`, only the leading axis (`0`) is supported by `pmap` at the time of writing this guide." 
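, "\n", "\n", "The following is a minimal sketch that checks whether a `None` entry in `in_axes` gives the same result as replicating an argument by hand. The function `scale_and_shift` and its `(scale, shift)` pair are hypothetical. A single `None` applies to the whole tuple, so both of its elements are broadcast:

```python
import jax
import numpy as np

n_devices = jax.local_device_count()


def scale_and_shift(x, params):
    # `params` is a hypothetical (scale, shift) pair shared by every device.
    scale, shift = params
    return x * scale + shift


xs = np.arange(n_devices * 4, dtype=np.float32).reshape(n_devices, 4)
params = (np.float32(2.0), np.float32(1.0))

# Broadcast: in_axes=None sends the same `params` to every device.
broadcast = jax.pmap(scale_and_shift, in_axes=(0, None))(xs, params)

# Manual replication: stack one copy per device along a new leading axis.
stacked = (np.full(n_devices, 2.0, np.float32), np.full(n_devices, 1.0, np.float32))
manual = jax.pmap(scale_and_shift)(xs, stacked)

print(np.allclose(broadcast, manual))
```

Broadcasting with `None` avoids building the stacked copies on the host, which is the same saving as passing `w` instead of `ws` in the cell below."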
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "6Es5WVuRlXnB", "outputId": "7e9612ae-d6e0-4d79-a228-f0403fcf8237" }, "outputs": [ { "data": { "text/plain": [ "ShardedDeviceArray([[ 11., 20., 29.],\n", " [ 56., 65., 74.],\n", " [101., 110., 119.],\n", " [146., 155., 164.],\n", " [191., 200., 209.],\n", " [236., 245., 254.],\n", " [281., 290., 299.],\n", " [326., 335., 344.]], dtype=float32)" ] }, "execution_count": 12, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "jax.pmap(convolve, in_axes=(0, None))(xs, w)" ] }, { "cell_type": "markdown", "metadata": { "id": "EoN6drHDOlk4" }, "source": [ "Notice that we get the same output as above with `jax.pmap(convolve)(xs, ws)`, where we manually replicated `w` when creating `ws`. Here, `w` is replicated via broadcasting, by specifying `None` for it in `in_axes`." ] }, { "cell_type": "markdown", "metadata": { "id": "rRE8STSU5cjx" }, "source": [ "Keep in mind that when calling the transformed function, the size of the mapped axis in each argument must not exceed the number of devices available to the host." ] }, { "cell_type": "markdown", "metadata": { "id": "0lZnqImd7G6U" }, "source": [ "## `pmap` and `jit`\n", "\n", "`jax.pmap` JIT-compiles the function given to it as part of its operation, so there is no need to additionally `jax.jit` it." ] }, { "cell_type": "markdown", "metadata": { "id": "1jZqk_2AwO4y" }, "source": [ "## Communication between devices\n", "\n", "The above is enough to perform simple parallel operations, e.g. batching a simple MLP forward pass across several devices. However, sometimes we need to pass information between the devices. For example, perhaps we are interested in normalizing the output of each device so that they sum to 1.\n", "For that, we can use special [collective ops](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) (such as the `jax.lax.p*` ops `psum`, `pmean`, `pmax`, ...). 
To use the collective ops, we must specify the name of the `pmap`-ed axis through the `axis_name` argument, and then refer to it when calling the op. Here's how to do that:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "0nCxGwqmtd3w", "outputId": "6f9c93b0-51ed-40c5-ca5a-eacbaf40e686" }, "outputs": [ { "data": { "text/plain": [ "ShardedDeviceArray([[0.00816024, 0.01408451, 0.019437 ],\n", " [0.04154303, 0.04577465, 0.04959785],\n", " [0.07492582, 0.07746479, 0.07975871],\n", " [0.10830861, 0.10915492, 0.10991956],\n", " [0.14169139, 0.14084506, 0.14008042],\n", " [0.17507419, 0.17253521, 0.17024128],\n", " [0.20845698, 0.20422535, 0.20040214],\n", " [0.24183977, 0.23591548, 0.23056298]], dtype=float32)" ] }, "execution_count": 13, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "def normalized_convolution(x, w):\n", " output = []\n", " for i in range(1, len(x)-1):\n", " output.append(jnp.dot(x[i-1:i+2], w))\n", " output = jnp.array(output)\n", " return output / jax.lax.psum(output, axis_name='p')\n", "\n", "jax.pmap(normalized_convolution, axis_name='p')(xs, ws)" ] }, { "cell_type": "markdown", "metadata": { "id": "9ENYsJS42YVK" }, "source": [ "The `axis_name` is just a string label that allows collective operations like `jax.lax.psum` to refer to the axis bound by `jax.pmap`. It can be named anything you want -- in this case, `p`. 
This name is essentially invisible to anything but those functions, and those functions use it to know which axis to communicate across.\n", "\n", "`jax.vmap` also supports `axis_name`, which allows `jax.lax.p*` operations to be used in the vectorisation context in the same way they would be used in a `jax.pmap`:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "nT61xAYJUqCW", "outputId": "e8831025-78a6-4a2b-a60a-3c77b35214ef" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[0.00816024, 0.01408451, 0.019437 ],\n", " [0.04154303, 0.04577465, 0.04959785],\n", " [0.07492582, 0.07746479, 0.07975871],\n", " [0.10830861, 0.10915492, 0.10991956],\n", " [0.14169139, 0.14084506, 0.14008042],\n", " [0.17507419, 0.17253521, 0.17024128],\n", " [0.20845698, 0.20422535, 0.20040214],\n", " [0.24183977, 0.23591548, 0.23056298]], dtype=float32)" ] }, "execution_count": 14, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "jax.vmap(normalized_convolution, axis_name='p')(xs, ws)" ] }, { "cell_type": "markdown", "metadata": { "id": "JSK-9dbWWV2O" }, "source": [ "Note that `normalized_convolution` will no longer work without being transformed by `jax.pmap` or `jax.vmap`, because `jax.lax.psum` expects there to be a named axis (`'p'`, in this case), and those two transformations are the only way to bind one.\n", "\n", "## Nesting `jax.pmap` and `jax.vmap`\n", "\n", "The reason we specify `axis_name` as a string is so we can use collective operations when nesting `jax.pmap` and `jax.vmap`. For example:\n", "\n", "```python\n", "jax.vmap(jax.pmap(f, axis_name='i'), axis_name='j')\n", "```\n", "\n", "A `jax.lax.psum(..., axis_name='i')` in `f` would refer only to the pmapped axis, since they share the `axis_name`. \n", "\n", "In general, `jax.pmap` and `jax.vmap` can be nested in any order, and with themselves (so you can have a `pmap` within another `pmap`, for instance)." 
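, "\n", "\n", "The following is a minimal sketch of nesting. The outer `pmap` binds the axis name `'i'` over devices, and an inner `vmap` binds `'j'` over a small per-device batch. The function `row_and_column_sums` is hypothetical. Each `psum` reduces only over the axis it names, and the reduction over both axes is written here as two nested `psum` calls:

```python
import jax
import numpy as np

n_devices = jax.local_device_count()


def row_and_column_sums(x):
    # Inside both maps, x is a scalar: one cell of the (devices, batch) grid.
    sum_over_j = jax.lax.psum(x, axis_name='j')  # vmapped axis only
    sum_over_i = jax.lax.psum(x, axis_name='i')  # pmapped axis only
    total = jax.lax.psum(sum_over_j, axis_name='i')  # both axes, one after the other
    return sum_over_j, sum_over_i, total


# Outer pmap over devices ('i'), inner vmap over 3 items per device ('j').
nested = jax.pmap(jax.vmap(row_and_column_sums, axis_name='j'), axis_name='i')

grid = np.arange(n_devices * 3, dtype=np.float32).reshape(n_devices, 3)
sum_over_j, sum_over_i, total = nested(grid)
print(sum_over_j, sum_over_i, total, sep='\n')
```

Each output has the shape of `grid`. `sum_over_j` repeats the per-device row sums, `sum_over_i` repeats the per-column sums across devices, and `total` is the grand total everywhere."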
] }, { "cell_type": "markdown", "metadata": { "id": "WzQHxnHkCxej" }, "source": [ "## Example\n", "\n", "Here's an example of a regression training loop with data parallelism, where each batch is split into sub-batches which are evaluated on separate devices.\n", "\n", "There are two places to pay attention to:\n", "* the `update()` function\n", "* the replication of parameters and splitting of data across devices.\n", "\n", "If this example is too confusing, you can find the same example, but without parallelism, in the next notebook, [State in JAX](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). Once that example makes sense, you can compare the differences to understand how parallelism changes the picture." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "cI8xQqzRrc-4" }, "outputs": [], "source": [ "from typing import NamedTuple, Tuple\n", "import functools\n", "\n", "class Params(NamedTuple):\n", " weight: jnp.ndarray\n", " bias: jnp.ndarray\n", "\n", "\n", "def init(rng) -> Params:\n", " \"\"\"Returns the initial model params.\"\"\"\n", " weights_key, bias_key = jax.random.split(rng)\n", " weight = jax.random.normal(weights_key, ())\n", " bias = jax.random.normal(bias_key, ())\n", " return Params(weight, bias)\n", "\n", "\n", "def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:\n", " \"\"\"Computes the least squares error of the model's predictions on x against y.\"\"\"\n", " pred = params.weight * xs + params.bias\n", " return jnp.mean((pred - ys) ** 2)\n", "\n", "LEARNING_RATE = 0.005\n", "\n", "# So far, the code is identical to the single-device case. Here's what's new:\n", "\n", "\n", "# Remember that the `axis_name` is just an arbitrary string label used\n", "# to later tell `jax.lax.pmean` which axis to reduce over. 
Here, we call it\n", "# 'num_devices', but we could have used anything, so long as `pmean` uses the same.\n", "@functools.partial(jax.pmap, axis_name='num_devices')\n", "def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:\n", " \"\"\"Performs one SGD update step on params using the given data.\"\"\"\n", "\n", " # Compute the gradients on the given minibatch (individually on each device).\n", " loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)\n", "\n", " # Combine the gradient across all devices (by taking their mean).\n", " grads = jax.lax.pmean(grads, axis_name='num_devices')\n", "\n", " # Also combine the loss. Unnecessary for the update, but useful for logging.\n", " loss = jax.lax.pmean(loss, axis_name='num_devices')\n", "\n", " # Each device performs its own update, but since we start with the same params\n", " # and synchronise gradients, the params stay in sync.\n", " new_params = jax.tree_util.tree_map(\n", " lambda param, g: param - g * LEARNING_RATE, params, grads)\n", "\n", " return new_params, loss" ] }, { "cell_type": "markdown", "metadata": { "id": "RWce8YZ4Pcmf" }, "source": [ "Here's how `update()` works:\n", "\n", "Undecorated and without the `pmean`s, `update()` takes data tensors of shape `[batch, ...]`, computes the loss function on that batch and evaluates its gradients.\n", "\n", "We want to spread the `batch` dimension across all available devices. To do that, we add a new axis using `pmap`. The arguments to the decorated `update()` thus need to have shape `[num_devices, batch_per_device, ...]`. So, to call the new `update()`, we'll need to reshape data batches so that what used to be `batch` is reshaped to `[num_devices, batch_per_device]`. That's what `split()` does below. Additionally, we'll need to replicate our model parameters, adding the `num_devices` axis. 
This reshaping is how a pmapped function knows which data to send to which device.\n", "\n", "At some point during the update step, we need to combine the gradients computed by each device -- otherwise, the updates performed by each device would be different. That's why we use `jax.lax.pmean` to compute the mean across the `num_devices` axis, giving us the average gradient of the batch. That average gradient is what we use to compute the update.\n", "\n", "Aside on naming: here, we use `num_devices` for the `axis_name` for didactic clarity while introducing `jax.pmap`. However, in some sense that is tautologous: any axis introduced by a pmap will represent a number of devices. Therefore, it's common to name the axis something semantically meaningful, like `batch`, `data` (signifying data parallelism) or `model` (signifying model parallelism)." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "_CTtLrsQ-0kK" }, "outputs": [], "source": [ "# Generate true data from y = w*x + b + noise\n", "true_w, true_b = 2, -1\n", "xs = np.random.normal(size=(128, 1))\n", "noise = 0.5 * np.random.normal(size=(128, 1))\n", "ys = xs * true_w + true_b + noise\n", "\n", "# Initialise parameters and replicate across devices.\n", "params = init(jax.random.PRNGKey(123))\n", "n_devices = jax.local_device_count()\n", "replicated_params = jax.tree_util.tree_map(lambda x: jnp.array([x] * n_devices), params)" ] }, { "cell_type": "markdown", "metadata": { "id": "dmCMyLP9SV99" }, "source": [ "So far, we've just constructed arrays with an additional leading dimension. The params are not yet spread across the devices: like any `jnp.array` result, each leaf sits on a single default device. `pmap` will communicate them to the devices when `update()` is first called, and each copy will stay on its own device afterwards. 
You can tell because they are a `DeviceArray`, not a `ShardedDeviceArray`:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "YSCgHguTSdGW", "outputId": "a8bf28df-3747-4d49-e340-b7696cf0c27d" }, "outputs": [ { "data": { "text/plain": [ "jax.interpreters.xla._DeviceArray" ] }, "execution_count": 19, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "type(replicated_params.weight)" ] }, { "cell_type": "markdown", "metadata": { "id": "90VtjPbeY-hD" }, "source": [ "The params will become a `ShardedDeviceArray` when they are returned by our pmapped `update()` (see further down)." ] }, { "cell_type": "markdown", "metadata": { "id": "eGVKxk1CV-m1" }, "source": [ "We do the same for the data:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "vY61QJoFWCII", "outputId": "f436a15f-db97-44cc-df33-bbb4ff222987" }, "outputs": [ { "data": { "text/plain": [ "numpy.ndarray" ] }, "execution_count": 20, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "def split(arr):\n", " \"\"\"Splits the first axis of `arr` evenly across the number of devices.\"\"\"\n", " return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])\n", "\n", "# Reshape xs and ys for the pmapped `update()`.\n", "x_split = split(xs)\n", "y_split = split(ys)\n", "\n", "type(x_split)" ] }, { "cell_type": "markdown", "metadata": { "id": "RzfJ-oK5WERq" }, "source": [ "The data is just a reshaped vanilla NumPy array. Hence, it cannot be anywhere but on the host, as NumPy runs on CPU only. Because it is never moved to the devices ahead of time, it is sent to them at each `update` call, like in a real pipeline where data is typically streamed from CPU to the device at each step." 
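, "\n", "\n", "The following is a minimal sketch of why averaging per-device gradients with `pmean` recovers the full-batch gradient. It assumes that every device receives a sub-batch of the same size, so the mean of the sub-batch means equals the full-batch mean. It re-declares a least-squares loss like the notebook's `loss_fn`, but with a plain `(weight, bias)` tuple instead of `Params`, so that it runs on its own. `sharded_grad` and the batch size are hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np

n_devices = jax.local_device_count()


def loss_fn(params, xs, ys):
    # Least-squares loss of a scalar linear model, as in the notebook.
    weight, bias = params
    pred = weight * xs + bias
    return jnp.mean((pred - ys) ** 2)


def sharded_grad(params, xs, ys):
    # Gradient on this device's sub-batch, then averaged across devices.
    grads = jax.grad(loss_fn)(params, xs, ys)
    return jax.lax.pmean(grads, axis_name='batch')


rng = np.random.default_rng(0)
xs = rng.normal(size=(8 * n_devices, 1)).astype(np.float32)
ys = (2 * xs - 1).astype(np.float32)
params = (jnp.float32(0.5), jnp.float32(0.0))

# Full-batch gradient computed in one ordinary (non-pmapped) call.
full = jax.grad(loss_fn)(params, xs, ys)

# Split the batch into [n_devices, batch_per_device, 1] and broadcast params.
x_split = xs.reshape(n_devices, -1, 1)
y_split = ys.reshape(n_devices, -1, 1)
per_device = jax.pmap(sharded_grad, axis_name='batch', in_axes=(None, 0, 0))(
    params, x_split, y_split)

print(full)
print(per_device)
```

The equal-size assumption matters. With padded or uneven sub-batches, a plain `pmean` of per-device means would weight every device equally and would no longer match the full-batch gradient."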
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "atOTi7EeSQw-", "outputId": "c8daf141-63c4-481f-afa5-684c5f7b698d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "after first `update()`, `replicated_params.weight` is a \n", "after first `update()`, `loss` is a \n", "after first `update()`, `x_split` is a \n", "Step 0, loss: 0.228\n", "Step 100, loss: 0.228\n", "Step 200, loss: 0.228\n", "Step 300, loss: 0.228\n", "Step 400, loss: 0.228\n", "Step 500, loss: 0.228\n", "Step 600, loss: 0.228\n", "Step 700, loss: 0.228\n", "Step 800, loss: 0.228\n", "Step 900, loss: 0.228\n" ] } ], "source": [ "def type_after_update(name, obj):\n", " print(f\"after first `update()`, `{name}` is a\", type(obj))\n", "\n", "# Actual training loop.\n", "for i in range(1000):\n", "\n", " # This is where the params and data get communicated to devices:\n", " replicated_params, loss = update(replicated_params, x_split, y_split)\n", "\n", " # The returned `replicated_params` and `loss` are now both ShardedDeviceArrays,\n", " # indicating that they're on the devices.\n", " # `x_split`, of course, remains a NumPy array on the host.\n", " if i == 0:\n", " type_after_update('replicated_params.weight', replicated_params.weight)\n", " type_after_update('loss', loss)\n", " type_after_update('x_split', x_split)\n", "\n", " if i % 100 == 0:\n", " # Note that loss is actually an array of shape [num_devices], with identical\n", " # entries, because each device returns its copy of the loss.\n", " # So, we take the first element to print it.\n", " print(f\"Step {i:3d}, loss: {loss[0]:.3f}\")\n", "\n", "\n", "# Plot results.\n", "\n", "# Like the loss, the leaves of params have an extra leading dimension,\n", "# so we take the params from the first device.\n", "params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], replicated_params))" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "rvVCACv9UZcF", "outputId": 
"5c472d0f-1236-401b-be55-86e3dc43875d" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "plt.scatter(xs, ys)\n", "plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "4wFJcqbhbn81" }, "source": [ "## Aside: hosts and devices in JAX\n", "\n", "When running on TPU, the idea of a 'host' becomes important. A host is the CPU that manages several devices. A single host can only manage so many devices (usually 8), so when running very large parallel programs, multiple hosts are needed, and some finesse is required to manage them." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "3DO8NwW5hurX", "outputId": "6df0bdd7-fee2-4805-9bfe-38e41bdaeb50" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, host_id=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, host_id=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, host_id=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, host_id=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, host_id=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, host_id=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, host_id=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, host_id=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 24, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "jax.devices()" ] }, { "cell_type": "markdown", "metadata": { "id": "sJwayfCoy15a" }, "source": [ "When running on CPU you can always emulate an arbitrary number of devices with a nifty `--xla_force_host_platform_device_count` XLA flag, e.g. 
by executing the following before importing JAX:\n", "```python\n", "import os\n", "os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'\n", "import jax\n", "jax.devices()\n", "```\n", "```\n", "[CpuDevice(id=0),\n", " CpuDevice(id=1),\n", " CpuDevice(id=2),\n", " CpuDevice(id=3),\n", " CpuDevice(id=4),\n", " CpuDevice(id=5),\n", " CpuDevice(id=6),\n", " CpuDevice(id=7)]\n", "```\n", "This is especially useful for debugging and testing locally, and even for prototyping in Colab, since a CPU runtime is faster to (re-)start." ] } ], "metadata": { "accelerator": "TPU", "colab": { "name": "JAX Parallelism", "provenance": [] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }