jax.lib.xla_bridge.get_compile_options
Warning
This page was created from a pull request (#9655).
- jax.lib.xla_bridge.get_compile_options(num_replicas, num_partitions, device_assignment=None, use_spmd_partitioning=True)
Returns the compile options to use, as derived from flag values.
- Parameters
num_replicas (int) – Number of replicas for which to compile.
num_partitions (int) – Number of partitions for which to compile.
device_assignment – Optional ndarray of JAX devices indicating the assignment of logical replicas to physical devices (default inherited from xla_client.CompileOptions). Must be consistent with num_replicas and num_partitions.
use_spmd_partitioning (bool) – Boolean indicating whether to enable SPMD or MPMD partitioning in XLA.
- Return type
CompileOptions
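
The sketch below illustrates the consistency requirement on device_assignment. It assumes that "consistent" means one row per replica and one column per partition. The helper check_device_assignment is hypothetical and is not part of JAX. The call to get_compile_options uses the import path shown on this page, which may differ or be missing in other JAX releases, so it is guarded.

```python
def check_device_assignment(device_assignment, num_replicas, num_partitions):
    """Hypothetical helper (not part of JAX).

    Returns True if a nested list of device ids has one row per replica
    and one column per partition, which is the layout assumed here.
    """
    if len(device_assignment) != num_replicas:
        return False
    return all(len(row) == num_partitions for row in device_assignment)


# Two replicas with one partition each: rows are replicas, columns are partitions.
assignment = [[0], [1]]
ok = check_device_assignment(assignment, num_replicas=2, num_partitions=1)

try:
    # Import path taken from this page; other JAX releases may not expose it.
    from jax.lib import xla_bridge

    # Single replica, single partition, default device assignment.
    options = xla_bridge.get_compile_options(num_replicas=1, num_partitions=1)
except (ImportError, AttributeError):
    # JAX is not installed, or this release no longer exposes the function here.
    options = None
```

When the guarded call succeeds, options is the CompileOptions object described under the return type above. Passing an explicit device_assignment is left out of the sketch because the valid device ids depend on the local backend.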