Source code for jax.experimental.ann

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ANN (Approximate Nearest Neighbor) computes top-k with a configurable recall rate.

This package only optimizes the TPU backend. For other device types it fallbacks
to sort and slice.

Usage::

  import functools
  import jax
  from jax.experimental import ann

  # MIPS := maximal inner product search
  # Inputs:
  #   qy: f32[qy_size, feature_dim]
  #   db: f32[db_size, feature_dim]
  #
  # Returns:
  #   (f32[qy_size, k], i32[qy_size, k])
  @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
  def mips(qy, db, k=10, recall_target=0.95):
    dists = jax.lax.dot(qy, db.transpose())
    # Computes max_k along the last dimension
    # returns (f32[qy_size, k], i32[qy_size, k])
    return ann.approx_max_k(dists, k=k, recall_target=recall_target)

  # Obtains the top-10 dot products and its offsets in db.
  dot_products, neighbors = mips(qy, db, k=10)
  # Computes the recall against the true neighbors.
  recall = ann.ann_recall(neighbors, true_neighbors)

  # Multi-core example
  # Inputs:
  #   qy: f32[num_devices, qy_size, feature_dim]
  #   db: f32[num_devices, per_device_db_size, feature_dim]
  #   db_offset: i32[num_devices]
  #
  # Returns:
  #   (f32[qy_size, num_devices, k], i32[qy_size, num_devices, k])
  @functools.partial(
      jax.pmap,
      # static args: db_size, k, recall_target
      static_broadcasted_argnums=[3, 4, 5],
      out_axes=(1, 1))
  def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
    dists = jax.lax.dot(qy, db.transpose())
    dists, neighbors = ann.approx_max_k(
        dists, k=k, recall_target=recall_target,
        reduction_input_size_override=db_size)
    return (dists, neighbors + db_offset)

  # i32[qy_size, num_devices, k]
  pmap_neighbors = pmap_mips(qy, db, db_offset, db_size, 10, 0.95)[1]
  # i32[qy_size, num_devices * k]
  neighbors = jax.lax.collapse(pmap_neighbors, start_dimension=1, stop_dimension=3)

Todos::

  * On host top-k aggregation
  * Inaccurate but fast differentiation

"""

from functools import partial
from typing import (Any, Tuple)

import numpy as np
from jax import lax, core
from jax._src.lib import xla_client as xc
from jax._src import ad_util, dtypes

from jax.interpreters import ad, xla, batching

Array = Any


[docs]def approx_max_k(operand: Array, k: int, reduction_dimension: int = -1, recall_target: float = 0.95, reduction_input_size_override: int = -1, aggregate_to_topk: bool = True) -> Tuple[Array, Array]: """Returns max ``k`` values and their indices of the ``operand``. Args: operand : Array to search for max-k. k : Specifies the number of max-k. reduction_dimension : Integer dimension along which to search. Default: -1. recall_target : Recall target for the approximation. reduction_input_size_override : When set to a positive value, it overrides the size determined by operands[reduction_dim] for evaluating the recall. This option is useful when the given operand is only a subset of the overall computation in SPMD or distributed pipelines, where the true input size cannot be deferred by the operand shape. aggregate_to_topk: When true, aggregates approximate results to top-k. When false, returns the approximate results. Returns: Tuple[Array, Array] : Max k values and their indices of the inputs. """ if xc._version < 45: aggregate_to_topk = True return approx_top_k_p.bind( operand, k=k, reduction_dimension=reduction_dimension, recall_target=recall_target, is_max_k=True, reduction_input_size_override=reduction_input_size_override, aggregate_to_topk=aggregate_to_topk)
[docs]def approx_min_k(operand: Array, k: int, reduction_dimension: int = -1, recall_target: float = 0.95, reduction_input_size_override: int = -1, aggregate_to_topk: bool = True) -> Tuple[Array, Array]: """Returns min ``k`` values and their indices of the ``operand``. Args: operand : Array to search for min-k. k : Specifies the number of min-k. reduction_dimension: Integer dimension along which to search. Default: -1. recall_target: Recall target for the approximation. reduction_input_size_override : When set to a positive value, it overrides the size determined by operands[reduction_dim] for evaluating the recall. This option is useful when the given operand is only a subset of the overall computation in SPMD or distributed pipelines, where the true input size cannot be deferred by the operand shape. aggregate_to_topk: When true, aggregates approximate results to top-k. When false, returns the approximate results. Returns: Tuple[Array, Array] : Least k values and their indices of the inputs. """ if xc._version < 45: aggregate_to_topk = True return approx_top_k_p.bind( operand, k=k, reduction_dimension=reduction_dimension, recall_target=recall_target, is_max_k=False, reduction_input_size_override=reduction_input_size_override, aggregate_to_topk=aggregate_to_topk)
def ann_recall(result_neighbors: Array, ground_truth_neighbors: Array) -> float: """Computes the recall of an approximate nearest neighbor search. Note, this is NOT a JAX-compatible op. This is a host function that only takes numpy arrays as the inputs. Args: result_neighbors: int32 numpy array of the shape [num_queries, neighbors_per_query] where the values are the indices of the dataset. ground_truth_neighbors: int32 numpy array of with shape [num_queries, ground_truth_neighbors_per_query] where the values are the indices of the dataset. Example: # shape [num_queries, 100] ground_truth_neighbors = ... # pre-computed top 100 sorted neighbors # shape [num_queries, 200] result_neighbors = ... # Retrieves 200 neighbors from some ann algorithms. # Computes the recall with k = 100 ann_recall(result_neighbors, ground_truth_neighbors) # Computes the recall with k = 10 ann_recall(result_neighbors, ground_truth_neighbors[:,0:10]) Returns: The recall. """ assert len(result_neighbors.shape) == 2, 'shape = [num_queries, neighbors_per_query]' assert len(ground_truth_neighbors.shape) == 2, 'shape = [num_queries, ground_truth_neighbors_per_query]' assert result_neighbors.shape[0] == ground_truth_neighbors.shape[0] gt_sets = [set(np.asarray(x)) for x in ground_truth_neighbors] hits = sum( len(list(x for x in nn_per_q if x.item() in gt_sets[q])) for q, nn_per_q in enumerate(result_neighbors)) return hits / ground_truth_neighbors.size def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, recall_target, is_max_k, reduction_input_size_override, aggregate_to_topk): if k <= 0: raise ValueError('k must be positive, got {}'.format(k)) if len(operand.shape) == 0: raise TypeError('approx_top_k operand must have >= 1 dimension, got {}'.format( operand.shape)) dims = list(operand.shape) if dims[reduction_dimension] < k: raise ValueError( 'k must be smaller than the size of reduction_dim {}, got {}'.format( dims[reduction_dimension], k)) if xc._version >= 45: reduction_input_size = dims[reduction_dimension] dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize( reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, reduction_input_size_override)[0] else: dims[reduction_dimension] = k return (operand.update( shape=dims, dtype=operand.dtype, weak_type=operand.weak_type), operand.update(shape=dims, dtype=np.dtype(np.int32))) def _comparator_builder(operand, op_type, is_max_k): c = xc.XlaBuilder( 'top_k_{}_comparator'.format('gt' if is_max_k else 'lt')) p0 = xla.parameter(c, 0, xc.Shape.scalar_shape(op_type)) p1 = xla.parameter(c, 1, xc.Shape.scalar_shape(op_type)) xla.parameter(c, 2, xc.Shape.scalar_shape(np.dtype(np.int32))) xla.parameter(c, 3, xc.Shape.scalar_shape(np.dtype(np.int32))) if is_max_k: cmp_result = xc.ops.Gt(p0, p1) else: cmp_result = xc.ops.Lt(p0, p1) return c.build(cmp_result) def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k, reduction_dimension, recall_target, is_max_k, reduction_input_size_override, aggregate_to_topk): c = ctx.builder op_shape = c.get_shape(operand) if not op_shape.is_array(): raise ValueError('operand must be an array, but was {}'.format(op_shape)) op_dims = op_shape.dimensions() op_type = op_shape.element_type() if reduction_dimension < 0: reduction_dimension = len(op_dims) + reduction_dimension comparator = _comparator_builder(operand, op_type, is_max_k) if is_max_k: if dtypes.issubdtype(op_type, np.floating): init_literal = np.array(np.NINF, dtype=op_type) else: init_literal = np.iinfo(op_type).min() else: if dtypes.issubdtype(op_type, np.floating): init_literal = np.array(np.Inf, dtype=op_type) else: init_literal = np.iinfo(op_type).max() iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims), reduction_dimension) init_val = xc.ops.Constant(c, init_literal) init_arg = xc.ops.Constant(c, np.int32(-1)) out = xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k, reduction_dimension, comparator, recall_target, aggregate_to_topk, reduction_input_size_override) return xla.xla_destructure(c, out) def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k, reduction_dimension, recall_target, is_max_k, reduction_input_size_override, aggregate_to_topk): c = ctx.builder op_shape = c.get_shape(operand) if not op_shape.is_array(): raise ValueError('operand must be an array, but was {}'.format(op_shape)) op_dims = op_shape.dimensions() op_type = op_shape.element_type() if reduction_dimension < 0: reduction_dimension = len(op_dims) + reduction_dimension comparator = _comparator_builder(operand, op_type, is_max_k) iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims), reduction_dimension) val_arg = xc.ops.Sort(c, [operand, iota], comparator, reduction_dimension) vals = xc.ops.GetTupleElement(val_arg, 0) args = xc.ops.GetTupleElement(val_arg, 1) sliced_vals = xc.ops.SliceInDim(vals, 0, avals_out[0].shape[reduction_dimension], 1, reduction_dimension) sliced_args = xc.ops.SliceInDim(args, 0, avals_out[0].shape[reduction_dimension], 1, reduction_dimension) return sliced_vals, sliced_args def _approx_top_k_batch_rule(batched_args, batch_dims, *, k, reduction_dimension, recall_target, is_max_k, reduction_input_size_override, aggregate_to_topk): prototype_arg, new_bdim = next( (a, b) for a, b in zip(batched_args, batch_dims) if b is not None) new_args = [] for arg, bdim in zip(batched_args, batch_dims): if bdim is None: dims = np.delete(np.arange(prototype_arg.ndim), new_bdim) new_args.append(lax.broadcast_in_dim(arg, prototype_arg.shape, dims)) else: new_args.append(batching.moveaxis(arg, bdim, new_bdim)) new_reduction_dim = reduction_dimension + (new_bdim <= reduction_dimension) bdims = (new_bdim,) * len(new_args) return (approx_top_k_p.bind( *new_args, k=k, reduction_dimension=new_reduction_dim, recall_target=recall_target, is_max_k=False, reduction_input_size_override=reduction_input_size_override, aggregate_to_topk=aggregate_to_topk), bdims) # Slow jvp implementation using gather. # # TODO(fchern): Some optimization ideas # 1. ApproxTopK is internally a variadic reduce, so we can simply call # ApproxTopK(operand, tangent, iota) for jvp. # 2. vjp cannot benefit from the algorithm above. We must run scatter to # distribute the output cotangent to input cotangent. A reasonable way to do # this is to run it on CPU. def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension, recall_target, is_max_k, reduction_input_size_override, aggregate_to_topk): operand, = primals tangent, = tangents if is_max_k: val_out, arg_out = approx_max_k(operand, k, reduction_dimension, recall_target, reduction_input_size_override, aggregate_to_topk) else: val_out, arg_out = approx_min_k(operand, k, reduction_dimension, recall_target, reduction_input_size_override, aggregate_to_topk) if type(tangent) is ad_util.Zero: tangent_out = ad_util.Zero.from_value(val_out) else: arg_shape = arg_out.shape rank = len(arg_shape) if reduction_dimension < 0: reduction_dimension += rank iotas = [ lax.broadcasted_iota(arg_out.dtype, arg_shape, i) for i in range(rank) ] idx = tuple( arg_out if i == reduction_dimension else iotas[i] for i in range(rank)) tangent_out = tangent[idx] return (val_out, arg_out), (tangent_out, ad_util.Zero.from_value(arg_out)) approx_top_k_p = core.Primitive('approx_top_k') approx_top_k_p.multiple_results = True approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p)) approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval) xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation) xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation, platform='tpu') batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule ad.primitive_jvps[approx_top_k_p] = _approx_top_k_jvp