jax.distributed.initialize

Warning

This page was created from a pull request (#9655).


jax.distributed.initialize(coordinator_address, num_processes, process_id)

Initialize the distributed system for topology discovery.

Currently, calling initialize() sets up the multi-host GPU backend. It is not required for the CPU or TPU backends.

Parameters
  • coordinator_address (str) – IP address and port of the coordinator, written as 'ip:port'. Any port may be used, as long as it is available on the coordinator and every process passes the same port.

  • num_processes (int) – Total number of processes participating in the distributed computation.

  • process_id (int) – ID of the current process, an integer from 0 to num_processes - 1. Each process must use a distinct ID.
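The three parameters have to agree across all processes. The sketch below reads them from environment variables and checks them before anything is passed to JAX. It is illustrative only. The variable names COORDINATOR_ADDRESS, NUM_PROCESSES and PROCESS_ID and the helper distributed_args_from_env are hypothetical, not part of JAX. Your own launcher would have to export those variables.

```python
import os
from typing import Mapping, Tuple

# Hypothetical variable names, assumed to be exported by your own launcher;
# they are not read by JAX itself.
ADDR_VAR = "COORDINATOR_ADDRESS"
NPROC_VAR = "NUM_PROCESSES"
PID_VAR = "PROCESS_ID"


def distributed_args_from_env(
    env: Mapping[str, str] = os.environ,
) -> Tuple[str, int, int]:
    """Return (coordinator_address, num_processes, process_id) after basic checks."""
    coordinator_address = env[ADDR_VAR]
    # Expect 'ip:port'; rpartition splits on the last colon.
    host, sep, port = coordinator_address.rpartition(":")
    if not sep or not host or not port.isdigit():
        raise ValueError(f"expected 'ip:port', got {coordinator_address!r}")

    num_processes = int(env[NPROC_VAR])
    process_id = int(env[PID_VAR])
    if num_processes < 1:
        raise ValueError("num_processes must be at least 1")
    # Process IDs run from 0 to num_processes - 1.
    if not 0 <= process_id < num_processes:
        raise ValueError(f"process_id {process_id} not in [0, {num_processes})")
    return coordinator_address, num_processes, process_id
```

Each process would then call jax.distributed.initialize(*distributed_args_from_env()) before any other JAX operation. The helper does not import JAX, so it can be tested without a cluster.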

Example:

Suppose there are two GPU hosts, and host 0 is the designated coordinator with address 10.0.0.1:1234. To initialize the GPU cluster, run the following commands on each host before any other JAX operation.

On host 0:

>>> import jax
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 0)

On host 1:

>>> import jax
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 1)
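In practice you might run one script on both hosts and have each host work out its own arguments. The sketch below does this from a fixed host list, where a host's position in the list is its process_id and the first entry is the coordinator. This is an assumption for illustration only. The HOSTS list, host 1's address 10.0.0.2 and the helper cluster_args are hypothetical. The source gives only the coordinator's address.

```python
from typing import Sequence, Tuple

# Hypothetical cluster description: list order defines process IDs,
# and hosts[0] acts as the coordinator. 10.0.0.2 is an assumed address.
HOSTS = ["10.0.0.1", "10.0.0.2"]
PORT = 1234


def cluster_args(
    hosts: Sequence[str], port: int, my_address: str
) -> Tuple[str, int, int]:
    """Return (coordinator_address, num_processes, process_id) for this host."""
    if len(set(hosts)) != len(hosts):
        raise ValueError("host list contains duplicates")
    if my_address not in hosts:
        raise ValueError(f"{my_address!r} is not in the host list")
    coordinator_address = f"{hosts[0]}:{port}"
    return coordinator_address, len(hosts), hosts.index(my_address)
```

With these values, cluster_args(HOSTS, PORT, "10.0.0.2") returns ('10.0.0.1:1234', 2, 1), which matches the host 1 call above. How a host finds its own address depends on your environment and is not covered here.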