jax.device_count
Warning
This page was created from a pull request (#9655).
jax.device_count(backend=None)
Returns the total number of devices.
On most platforms, this is the same as jax.local_device_count(). However, on multi-process platforms where different devices are associated with different processes, this returns the total number of devices across all processes.
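The snippet below is a minimal sketch, not taken from the original page, that contrasts the global count with the per-process count. It assumes a standard jaxlib install, where the CPU backend is always available, so passing backend="cpu" is safe.

```python
import jax

# Every device visible to JAX on the default backend, across all processes.
n_global = jax.device_count()

# Only the devices attached to the calling process.
n_local = jax.local_device_count()

# The optional backend argument restricts the count to one platform.
# Assumption: the CPU backend is present (true for standard jaxlib builds).
n_cpu = jax.device_count(backend="cpu")

print(f"process {jax.process_index()} of {jax.process_count()}: "
      f"{n_local} local / {n_global} global devices, {n_cpu} CPU device(s)")

# With a single process, the local and global counts agree.
if jax.process_count() == 1:
    assert n_global == n_local
```

In a multi-process job, n_global is what you would use to check that a global batch divides evenly across all devices, while n_local governs what each process feeds to its own devices.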