jax.experimental.ann module
Warning
This page was created from a pull request (#9655).
ANN (Approximate Nearest Neighbor) computes top-k with a configurable recall rate.
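The page does not define "recall rate". The sketch below assumes the usual definition for approximate top-k search: the fraction of the exact top-k results that the approximate search also returns, averaged over queries. For $Q$ queries, with $\hat{N}_q$ the $k$ indices returned for query $q$ and $N_q$ the exact top-$k$ indices:

```latex
\text{recall} = \frac{1}{Q\,k} \sum_{q=1}^{Q} \left| \hat{N}_q \cap N_q \right|
```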
This package only optimizes the TPU backend. For other device types it falls back to an exact sort followed by a slice.
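To make "sort and slice" concrete, here is a minimal sketch of an exact max-k computed that way and checked against `jax.lax.top_k`. It illustrates the semantics of the fallback only; `sort_and_slice_max_k` is a hypothetical name, not the package's internal implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def sort_and_slice_max_k(operand, k):
    """Hypothetical helper: exact max-k along the last axis via a full sort."""
    # Sorting the negated values ascending orders the originals descending.
    indices = jnp.argsort(-operand, axis=-1)[..., :k]
    # Gather the values that sit at those indices.
    values = jnp.take_along_axis(operand, indices, axis=-1)
    return values, indices


x = jax.random.normal(jax.random.PRNGKey(0), (4, 50))
values, indices = sort_and_slice_max_k(x, k=5)
ref_values, ref_indices = jax.lax.top_k(x, 5)
```

A full sort costs O(n log n) per row, which is why a dedicated approximate kernel pays off on accelerators with large databases.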
Usage:
import functools
import jax
from jax.experimental import ann

# MIPS := maximal inner product search
# Inputs:
#   qy: f32[qy_size, feature_dim]
#   db: f32[db_size, feature_dim]
#
# Returns:
#   (f32[qy_size, k], i32[qy_size, k])
@functools.partial(jax.jit, static_argnames=["k", "recall_target"])
def mips(qy, db, k=10, recall_target=0.95):
    dists = jax.lax.dot(qy, db.transpose())
    # Computes max_k along the last dimension
    # returns (f32[qy_size, k], i32[qy_size, k])
    return ann.approx_max_k(dists, k=k, recall_target=recall_target)

# Obtains the top-10 dot products and their offsets in db.
# qy, db and true_neighbors are assumed to be defined by the caller.
dot_products, neighbors = mips(qy, db, k=10)
# Computes the recall against the true neighbors.
# ann_recall is not among the functions listed under API below;
# treat it as a user-supplied helper.
recall = ann_recall(neighbors, true_neighbors)
# Multi-core example
# Inputs:
#   qy: f32[num_devices, qy_size, feature_dim]
#   db: f32[num_devices, per_device_db_size, feature_dim]
#   db_offset: i32[num_devices]
#
# Returns:
#   (f32[qy_size, num_devices, k], i32[qy_size, num_devices, k])
@functools.partial(
    jax.pmap,
    # static args: db_size, k, recall_target
    static_broadcasted_argnums=[3, 4, 5],
    out_axes=(1, 1))
def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
    dists = jax.lax.dot(qy, db.transpose())
    dists, neighbors = ann.approx_max_k(
        dists, k=k, recall_target=recall_target,
        # Evaluate the recall target against the full database size
        # rather than this device's shard of it.
        reduction_input_size_override=db_size)
    # Shift per-device indices into global db indices.
    return (dists, neighbors + db_offset)

# i32[qy_size, num_devices, k]
pmap_neighbors = pmap_mips(qy, db, db_offset, db_size, 10, 0.95)[1]
# i32[qy_size, num_devices * k]
neighbors = jax.lax.collapse(pmap_neighbors, start_dimension=1, stop_dimension=3)
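The single-core listing assumes `qy`, `db` and `true_neighbors` already exist and calls an `ann_recall` helper. A self-contained sketch, assuming small random inputs, an exact reference from `jax.lax.top_k`, and a hypothetical `recall_at_k` helper in place of `ann_recall`, might look like this. The import fallback to `jax.lax.approx_max_k` is an assumption for newer JAX releases in which this experimental module may no longer be available.

```python
import functools

import jax
import numpy as np

try:
    # Module documented on this page.
    from jax.experimental.ann import approx_max_k
except ImportError:
    # Assumption: newer JAX releases expose the same operation in jax.lax.
    from jax.lax import approx_max_k


@functools.partial(jax.jit, static_argnames=["k", "recall_target"])
def mips(qy, db, k=10, recall_target=0.95):
    dists = jax.lax.dot(qy, db.transpose())
    return approx_max_k(dists, k=k, recall_target=recall_target)


def recall_at_k(neighbors, true_neighbors):
    """Hypothetical helper: fraction of true neighbors found, over all queries."""
    neighbors = np.asarray(neighbors)
    true_neighbors = np.asarray(true_neighbors)
    hits = sum(
        len(set(found.tolist()) & set(truth.tolist()))
        for found, truth in zip(neighbors, true_neighbors)
    )
    return hits / true_neighbors.size


key_q, key_db = jax.random.split(jax.random.PRNGKey(0))
qy = jax.random.normal(key_q, (8, 16))
db = jax.random.normal(key_db, (1000, 16))

dot_products, neighbors = mips(qy, db, k=10)
# Exact reference: all dot products, then an exact top-k.
_, true_neighbors = jax.lax.top_k(qy @ db.T, 10)
r = recall_at_k(neighbors, true_neighbors)
print(f"recall = {r:.3f}")
```

On a non-TPU backend the fallback is exact, so the printed recall there is expected to be at or near 1.0; on TPU it reflects `recall_target`.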
Todos:
* On-host top-k aggregation
* Inaccurate but fast differentiation
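On-host top-k aggregation means merging the `num_devices * k` candidates per query from the multi-core example into a final top-k. That step needs the scores as well as the indices, so it assumes both outputs of `pmap_mips` (each of shape `[qy_size, num_devices, k]`, with indices already shifted by `db_offset`). A minimal NumPy sketch; `aggregate_top_k` is a hypothetical helper, not part of this module:

```python
import numpy as np


def aggregate_top_k(dists, neighbors, k):
    """Hypothetical host-side merge of per-device top-k candidates.

    dists, neighbors: arrays of shape [qy_size, num_devices, k_per_device].
    Returns the k largest scores per query and their global db indices,
    each of shape [qy_size, k].
    """
    dists = np.asarray(dists)
    neighbors = np.asarray(neighbors)
    qy_size = dists.shape[0]
    # Same flattening as collapsing dimensions 1..3: [qy_size, num_devices * k].
    flat_dists = dists.reshape(qy_size, -1)
    flat_neighbors = neighbors.reshape(qy_size, -1)
    # Exact top-k over the merged candidates, largest score first.
    order = np.argsort(-flat_dists, axis=1)[:, :k]
    return (np.take_along_axis(flat_dists, order, axis=1),
            np.take_along_axis(flat_neighbors, order, axis=1))


# One query, two devices, two candidates per device.
dists = np.array([[[0.9, 0.5], [0.7, 0.6]]])
neighbors = np.array([[[3, 1], [12, 10]]])
top_dists, top_neighbors = aggregate_top_k(dists, neighbors, k=2)
# top_neighbors holds [[3, 12]] (scores 0.9 and 0.7).
```

Because each device already returned its own top-k, the merged candidate list is small, so an exact sort on the host is cheap.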
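API
* approx_max_k: Returns max k values and their indices of the operand in an approximate manner.
* approx_min_k: Returns min k values and their indices of the operand in an approximate manner.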
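`approx_max_k` fits maximal inner product search; `approx_min_k` fits distance-based search. A sketch, assuming Euclidean (L2) nearest neighbors: since $\|q - x\|^2 = \|q\|^2 - 2\,q \cdot x + \|x\|^2$ and $\|q\|^2$ is the same for every $x$ in a query's row, ranking by $\|x\|^2 - 2\,q \cdot x$ yields the same neighbors. `l2_search` is a hypothetical helper, and the import fallback to `jax.lax.approx_min_k` is an assumption for newer JAX releases.

```python
import jax
import jax.numpy as jnp
import numpy as np

try:
    from jax.experimental.ann import approx_min_k
except ImportError:
    # Assumption: newer JAX releases expose this operation in jax.lax.
    from jax.lax import approx_min_k


def l2_search(qy, db, k=5, recall_target=0.95):
    """Hypothetical helper: approximate k nearest neighbors by L2 distance.

    qy: f32[qy_size, feature_dim], db: f32[db_size, feature_dim].
    Returns (scores, indices), each of shape [qy_size, k], where each score
    is the squared distance minus the constant ||q||^2 of its query.
    """
    # ||x||^2 for every database row.
    db_sq_norms = jnp.sum(db * db, axis=1)
    # ||x||^2 - 2 q.x; the dropped ||q||^2 term does not change the ranking.
    scores = db_sq_norms[None, :] - 2.0 * jax.lax.dot(qy, db.transpose())
    return approx_min_k(scores, k=k, recall_target=recall_target)


key_q, key_db = jax.random.split(jax.random.PRNGKey(1))
qy = jax.random.normal(key_q, (4, 8))
db = jax.random.normal(key_db, (200, 8))
scores, nn = l2_search(qy, db, k=5)
```

Add `jnp.sum(qy * qy, axis=1)[:, None]` to `scores` if the actual squared distances are needed.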