jax.experimental.ann.approx_max_k
Warning
This page was created from a pull request (#9655).
- jax.experimental.ann.approx_max_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)
Returns the approximate max k values of the operand along reduction_dimension, together with their indices.
- Parameters
operand (Any) – Array to search for max-k.
k (int) – Number of largest values to return.
reduction_dimension (int) – Integer dimension along which to search. Default: -1.
recall_target (float) – Recall target for the approximation, i.e. the desired fraction of the true top-k elements that appear in the result. Default: 0.95.
reduction_input_size_override (int) – When set to a positive value, overrides the size given by operand.shape[reduction_dimension] when evaluating the recall. This is useful when the given operand is only a subset of the overall computation in SPMD or distributed pipelines, where the true input size cannot be inferred from the operand shape. Default: -1 (no override).
aggregate_to_topk (bool) – When True, aggregates the approximate results to the top k. When False, returns the unaggregated approximate results. Default: True.
- Returns
A tuple (values, indices): the max k values along reduction_dimension and their indices in the operand.
- Return type
Tuple[Array, Array]
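A minimal usage sketch follows. It searches a random float32 array along the default reduction_dimension and measures the recall actually achieved against the exact result from jax.lax.top_k. The recall helper, the array shape, and k = 10 are illustrative choices, not part of this API. The import falls back to jax.lax.approx_max_k, which is assumed here to be where newer JAX releases expose the same function. On backends without an approximate implementation the measured recall may simply be 1.0.

```python
import jax
import jax.numpy as jnp

try:
    # Path documented on this page (older JAX releases).
    from jax.experimental.ann import approx_max_k
except ImportError:
    # Assumption: newer JAX releases expose the same function here.
    from jax.lax import approx_max_k


def recall(approx_idx, exact_idx):
    """Hypothetical helper: fraction of approximate indices that are in the exact top-k."""
    # Compare every approximate index with every exact index in the same row.
    hits = (approx_idx[..., :, None] == exact_idx[..., None, :]).any(axis=-1)
    return float(hits.mean())


key = jax.random.PRNGKey(0)
scores = jax.random.normal(key, (4, 1024), dtype=jnp.float32)  # float32 operand
k = 10

# Search along the last axis (the default reduction_dimension=-1).
values, indices = approx_max_k(scores, k, recall_target=0.95)

# Exact reference indices for measuring the recall of the approximation.
_, exact_indices = jax.lax.top_k(scores, k)

print(values.shape, indices.shape)  # (4, 10) (4, 10)
print("recall:", recall(indices, exact_indices))
```

Because the returned indices point into the operand, gathering with jnp.take_along_axis(scores, indices, axis=-1) reproduces values.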
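The reduction_input_size_override case can be sketched as follows, assuming a pipeline where each device sees only one shard of the reduction axis. Sharding is only simulated here by slicing. The names full_size and shard and the sizes 4096 and 1024 are hypothetical. The override tells approx_max_k to evaluate the recall against the full axis rather than the shard width visible in the operand shape.

```python
import jax
import jax.numpy as jnp

try:
    # Path documented on this page (older JAX releases).
    from jax.experimental.ann import approx_max_k
except ImportError:
    # Assumption: newer JAX releases expose the same function here.
    from jax.lax import approx_max_k

# Hypothetical setup: the full reduction axis has 4096 entries, but this
# device only holds a 1024-entry shard of it (simulated by slicing).
full_size = 4096
full_scores = jax.random.normal(jax.random.PRNGKey(1), (2, full_size), dtype=jnp.float32)
shard = full_scores[:, :1024]

# Evaluate the recall target as if the reduction axis had full_size entries,
# not the 1024 entries implied by shard.shape[-1].
shard_values, shard_indices = approx_max_k(
    shard, k=8, reduction_input_size_override=full_size)

# The indices are relative to the shard; a real pipeline would still need to
# offset them and merge the per-shard candidates.
print(shard_values.shape)  # (2, 8)
```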