Warning

This page was created from a pull request (#9655).

jax.experimental.ann module

ANN (Approximate Nearest Neighbor) computes top-k with a configurable recall rate.

This package is optimized only for the TPU backend. On other device types it falls back to sort and slice.
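To make "sort and slice" concrete, here is a minimal sketch of an exact top-k computed by sorting the last axis and keeping the first k entries. The name `exact_max_k` is hypothetical and this is not the module's fallback code; it only illustrates the idea, using `jnp.argsort` and `jnp.take_along_axis`.

```python
import jax
import jax.numpy as jnp

def exact_max_k(operand, k):
  """Exact top-k along the last axis: full sort, then slice (hypothetical helper)."""
  # argsort is ascending, so negate to rank the largest values first.
  order = jnp.argsort(-operand, axis=-1)
  # Slice: keep only the first k positions of the sorted order.
  indices = order[..., :k]
  values = jnp.take_along_axis(operand, indices, axis=-1)
  return values, indices

scores = jnp.array([[0.1, 0.9, 0.4, 0.7],
                    [0.8, 0.2, 0.6, 0.3]])
values, indices = exact_max_k(scores, k=2)
```

A full sort always returns the exact result, at the cost of ordering every element; an approximate algorithm trades some of that exactness for speed.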

Usage:

import functools
import jax
from jax.experimental import ann

# MIPS := maximal inner product search
# Inputs:
#   qy: f32[qy_size, feature_dim]
#   db: f32[db_size, feature_dim]
#
# Returns:
#   (f32[qy_size, k], i32[qy_size, k])
@functools.partial(jax.jit, static_argnames=["k", "recall_target"])
def mips(qy, db, k=10, recall_target=0.95):
  dists = jax.lax.dot(qy, db.transpose())
  # Computes max_k along the last dimension
  # returns (f32[qy_size, k], i32[qy_size, k])
  return ann.approx_max_k(dists, k=k, recall_target=recall_target)

# Obtains the top-10 dot products and their offsets in db.
dot_products, neighbors = mips(qy, db, k=10)
# Computes the recall against the true neighbors.
recall = ann.ann_recall(neighbors, true_neighbors)

# Multi-core example
# Inputs:
#   qy: f32[num_devices, qy_size, feature_dim]
#   db: f32[num_devices, per_device_db_size, feature_dim]
#   db_offset: i32[num_devices]
#
# Returns:
#   (f32[qy_size, num_devices, k], i32[qy_size, num_devices, k])
@functools.partial(
    jax.pmap,
    # static args: db_size, k, recall_target
    static_broadcasted_argnums=[3, 4, 5],
    out_axes=(1, 1))
def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
  dists = jax.lax.dot(qy, db.transpose())
  dists, neighbors = ann.approx_max_k(
      dists, k=k, recall_target=recall_target,
      reduction_input_size_override=db_size)
  return (dists, neighbors + db_offset)

# i32[qy_size, num_devices, k]
pmap_neighbors = pmap_mips(qy, db, db_offset, db_size, 10, 0.95)[1]
# i32[qy_size, num_devices * k]
neighbors = jax.lax.collapse(pmap_neighbors, start_dimension=1, stop_dimension=3)
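
The single-device example above compares the returned neighbors against the true neighbors. As a rough illustration of what such a recall figure means, the sketch below counts, per query, how many true neighbors appear in the result and divides by the total number of true neighbors. The name `recall_at_k` is hypothetical, and this is an assumed definition written with NumPy; the module's `ann_recall` may differ in detail.

```python
import numpy as np

def recall_at_k(result_neighbors, true_neighbors):
  """Fraction of true neighbors found in the results (hypothetical helper).

  Both arguments are integer arrays of shape [qy_size, k].
  """
  result_neighbors = np.asarray(result_neighbors)
  true_neighbors = np.asarray(true_neighbors)
  hits = 0
  for found, truth in zip(result_neighbors, true_neighbors):
    # Order within a row does not matter, only membership.
    hits += len(set(found.tolist()) & set(truth.tolist()))
  return hits / true_neighbors.size

found = np.array([[3, 7, 1], [4, 0, 9]])
truth = np.array([[3, 1, 5], [4, 9, 0]])
recall = recall_at_k(found, truth)  # 5 of the 6 true neighbors were found
```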

Todos:

* On-host top-k aggregation
* Inaccurate but fast differentiation
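
Until on-host aggregation is available, the per-device candidates from the multi-core example can be merged by hand. The following is one possible approach and not part of the module: it assumes scores and neighbor indices of shape [qy_size, num_devices, k], flattens the device axis, and keeps the k best candidates per query with NumPy. The name `aggregate_top_k` is hypothetical.

```python
import numpy as np

def aggregate_top_k(dists, neighbors, k):
  """Merges per-device candidates into the k largest per query (hypothetical helper).

  dists, neighbors: arrays of shape [qy_size, num_devices, k_per_device].
  """
  dists = np.asarray(dists)
  neighbors = np.asarray(neighbors)
  qy_size = dists.shape[0]
  # Flatten the device axis so each query has one row of candidates.
  flat_dists = dists.reshape(qy_size, -1)
  flat_neighbors = neighbors.reshape(qy_size, -1)
  # Descending order of score; keep the first k candidates per query.
  order = np.argsort(-flat_dists, axis=-1)[:, :k]
  return (np.take_along_axis(flat_dists, order, axis=-1),
          np.take_along_axis(flat_neighbors, order, axis=-1))

# One query, two devices, two candidates per device.
dists = np.array([[[0.9, 0.2], [0.7, 0.5]]])
neighbors = np.array([[[4, 1], [12, 10]]])
top_dists, top_neighbors = aggregate_top_k(dists, neighbors, k=2)
```

The neighbor indices must already be global offsets into the full database, as produced by adding `db_offset` in the multi-core example, so that they remain meaningful after merging.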

API

approx_max_k(operand, k[, ...])

Returns max k values and their indices of the operand.

approx_min_k(operand, k[, ...])

Returns min k values and their indices of the operand.
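
For testing, it can help to have an exact reference for what a min-k query returns. The sketch below is an assumption-labeled illustration, not the module's implementation: it computes the exact k smallest values by negating the input and calling `jax.lax.top_k`. The name `exact_min_k` is hypothetical.

```python
import jax
import jax.numpy as jnp

def exact_min_k(operand, k):
  """Exact k smallest values along the last axis, with indices (hypothetical helper)."""
  # top_k returns the largest entries, so negate going in and coming out.
  neg_values, indices = jax.lax.top_k(-operand, k)
  return -neg_values, indices

dists = jnp.array([[4.0, 1.0, 3.0, 2.0]])
values, indices = exact_min_k(dists, k=2)
```

Comparing the indices from an approximate call against such a reference is one way to check that the observed recall is close to the requested `recall_target`.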