jax.lax.convert_element_type

Warning

This page was created from a pull request (#9655).


jax.lax.convert_element_type(operand, new_dtype)

Elementwise cast.

Wraps XLA’s ConvertElementType operator, which converts each element of the operand from one type to another. The behavior is similar to a C++ static_cast.
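The static_cast-like behavior can be illustrated with a minimal sketch. It assumes the usual jax.numpy as jnp and jax.lax as lax imports and JAX's default 32-bit mode. It only uses values that fit in the target type; out-of-range float-to-int conversions are deliberately not shown.

```python
import jax.numpy as jnp
from jax import lax

# Float -> int: as with a C++ static_cast, the fractional part is dropped
# (truncation toward zero) for values that fit in the target type.
x = jnp.array([1.7, -1.7, 2.0], dtype=jnp.float32)
as_int = lax.convert_element_type(x, jnp.int32)
print(as_int)    # [ 1 -1  2]

# Int -> float: small integers are represented exactly in float32.
n = jnp.array([1, 2, 3], dtype=jnp.int32)
as_float = lax.convert_element_type(n, jnp.float32)
print(as_float)  # [1. 2. 3.]
```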

Parameters
  • operand (Any) – an array or scalar value to be cast.

  • new_dtype (Any) – a NumPy dtype representing the target type.
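As a rough illustration of the two parameters, the sketch below passes a Python scalar as operand and specifies new_dtype in three dtype-like forms (a NumPy scalar type, a jax.numpy alias, and an np.dtype instance). The choice of float16 as the target is arbitrary and only for illustration.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# operand may be a Python scalar; the result is a zero-dimensional array.
s = lax.convert_element_type(3, np.float32)
print(s.shape, s.dtype)  # () float32

# operand may also be an array; here new_dtype is given in three equivalent forms.
a = jnp.array([0.5, 1.5])
h1 = lax.convert_element_type(a, np.float16)           # NumPy scalar type
h2 = lax.convert_element_type(a, jnp.float16)          # jax.numpy alias
h3 = lax.convert_element_type(a, np.dtype("float16"))  # np.dtype instance
print(h1.dtype, h2.dtype, h3.dtype)  # float16 float16 float16
```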

Return type

Any

Returns

An array with the same shape as operand, cast elementwise to new_dtype.
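The sketch below checks that only the dtype changes while the shape is preserved. It also applies the same call under jax.jit, which is an added usage assumption rather than something stated on this page.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)
y = lax.convert_element_type(x, jnp.float32)
print(y.shape, y.dtype)  # (2, 3) float32

# The same conversion traced under jit gives the same shape and values.
to_f32 = jax.jit(lambda a: lax.convert_element_type(a, jnp.float32))
z = to_f32(x)
```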