jax.lax.top_k

Warning

This page was created from a pull request (#9655).


jax.lax.top_k(operand, k)

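Returns the k largest values along the last axis of operand, together with their indices. Both outputs have shape operand.shape[:-1] + (k,), and the values are sorted in descending order.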
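A minimal sketch of basic usage on a 1-D array. The input values are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 5.0, 3.0, 4.0, 2.0])

# Take the 3 largest entries along the last (only) axis.
values, indices = lax.top_k(x, 3)
# values  -> [5., 4., 3.]   (descending order)
# indices -> [1, 3, 2]      (positions of those values in x)
```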

Parameters
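  • operand (Any) – array to search along its last axis; it must have at least one dimension.

  • k (int) – number of largest elements to return; must satisfy 0 <= k <= operand.shape[-1].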

Return type

Tuple[Any, Any]
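Because top_k always works on the last axis, a batched operand is handled row by row. To select along a different axis, one option is to move that axis to the end first and move it back afterwards. The sketch below does this with jnp.moveaxis; the helper top_k_along_axis is hypothetical and not part of jax.lax.

```python
import jax.numpy as jnp
from jax import lax

def top_k_along_axis(x, k, axis):
    """Hypothetical helper: top-k along `axis` by moving it to the last position."""
    x_last = jnp.moveaxis(x, axis, -1)           # put the target axis last
    values, indices = lax.top_k(x_last, k)       # top_k acts on the last axis
    # Restore the original axis order for both outputs.
    return jnp.moveaxis(values, -1, axis), jnp.moveaxis(indices, -1, axis)

x = jnp.array([[3.0, 1.0, 2.0],
               [0.0, 9.0, 7.0]])

# Row-wise (last axis): both outputs have shape (2, 2).
v, i = lax.top_k(x, 2)
# v -> [[3., 2.], [9., 7.]],  i -> [[0, 2], [1, 2]]

# Column-wise (axis 0) via the hypothetical helper: shape (1, 3).
cv, ci = top_k_along_axis(x, 1, axis=0)
# cv -> [[3., 9., 7.]],  ci -> [[0, 1, 1]]
```

Moving the axis rather than transposing the whole array by hand keeps the helper valid for operands of any rank.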