jax.experimental.ann.approx_max_k

Warning

This page was created from a pull request (#9655).

jax.experimental.ann.approx_max_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)

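Returns the k largest values of the operand along the reduction dimension, together with their indices. The result is approximate and is governed by recall_target.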

Parameters
  • operand (Any) – Array to search for the k largest values.

  • k (int) – Number of largest values to return.

  • reduction_dimension (int) – Integer dimension along which to search. Default: -1.

  • recall_target (float) – Recall target for the approximation. Default: 0.95.

  • reduction_input_size_override (int) – When set to a positive value, overrides the size given by operand.shape[reduction_dimension] when evaluating the recall. This is useful when the given operand is only a subset of the overall computation in SPMD or distributed pipelines, where the true input size cannot be inferred from the operand shape. Default: -1.

  • aggregate_to_topk (bool) – When true, aggregates the approximate results to top-k. When false, returns the approximate results without this final aggregation. Default: True.
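
To illustrate what recall_target refers to, the following sketch compares the indices returned by approx_max_k with those of the exact jax.lax.top_k and computes the fraction of true top-k entries that were recovered. The array shape, the random seed, and the variable names are illustrative assumptions. The import fallback assumes that newer JAX releases expose the same function as jax.lax.approx_max_k. The measured recall depends on the backend, so no particular value is claimed here.

```python
import jax
import jax.numpy as jnp

# Assumption: older releases expose the function under jax.experimental.ann,
# newer releases under jax.lax. Try the documented path first.
try:
    from jax.experimental.ann import approx_max_k
except ImportError:
    from jax.lax import approx_max_k

k = 10
# Illustrative data: one row of 2048 random scores.
x = jax.random.normal(jax.random.PRNGKey(0), (1, 2048))

# Approximate search along the last dimension with a 90% recall target.
_, approx_idx = approx_max_k(x, k, reduction_dimension=-1, recall_target=0.9)

# Exact reference result for comparison.
_, exact_idx = jax.lax.top_k(x, k)

# Recall: fraction of the exact top-k indices that the approximate search found.
hits = jnp.isin(approx_idx[0], exact_idx[0]).sum()
recall = float(hits) / k
print("measured recall:", recall)
```

A lower recall_target allows a coarser approximation; the measured recall for a single input may be above or below the target.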

Returns

A tuple of two arrays: the max k values and their indices along the reduction dimension of the operand.

Return type

Tuple[Array, Array]
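
Example

A minimal usage sketch, assuming a batch of 4 rows with 1024 scores each (the shapes, seed, and variable names are illustrative, not part of the API). It unpacks the returned tuple and checks that the indices point at the returned values. The import fallback assumes that newer JAX releases provide the same function as jax.lax.approx_max_k.

```python
import jax
import jax.numpy as jnp

# Assumption: fall back to jax.lax if jax.experimental.ann is unavailable.
try:
    from jax.experimental.ann import approx_max_k
except ImportError:
    from jax.lax import approx_max_k

# Illustrative input: 4 rows, 1024 candidate scores per row.
scores = jax.random.normal(jax.random.PRNGKey(0), (4, 1024))

# Search each row (last dimension) for its 8 largest scores.
values, indices = approx_max_k(scores, 8, reduction_dimension=-1)

# With aggregate_to_topk=True (the default), both outputs have k entries
# along the reduction dimension.
print(values.shape, indices.shape)

# The indices address the operand along the reduction dimension, so gathering
# with them reproduces the returned values.
gathered = jnp.take_along_axis(scores, indices, axis=-1)
print(bool(jnp.allclose(gathered, values)))
```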