jax.numpy.tensordot
Warning
This page was created from a pull request (#9655).
- jax.numpy.tensordot(a, b, axes=2, *, precision=None)
Compute tensor dot product along specified axes.
LAX-backend implementation of numpy.tensordot(). In addition to the original NumPy arguments listed below, it also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, meaning the backend's default precision; a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST); or a tuple of two Precision enums indicating a separate precision for each argument.
Original docstring below.
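The three accepted forms of precision can be sketched as follows. This is a minimal example assuming a recent JAX install where the enum is importable as jax.lax.Precision; the array shapes are arbitrary illustrative choices. On CPU all three calls give the same result; the setting mainly matters on accelerators such as TPUs, where the default float32 matmul precision may be reduced.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((3, 4, 5), dtype=jnp.float32)
b = jnp.ones((4, 5, 6), dtype=jnp.float32)

# precision=None (the default): use the backend's default precision.
c_default = jnp.tensordot(a, b, axes=2)

# A single Precision enum applies to both operands.
c_highest = jnp.tensordot(a, b, axes=2, precision=lax.Precision.HIGHEST)

# A tuple of two enums sets the precision separately for a and b.
c_mixed = jnp.tensordot(
    a, b, axes=2, precision=(lax.Precision.DEFAULT, lax.Precision.HIGHEST)
)

print(c_default.shape, c_highest.shape, c_mixed.shape)  # (3, 6) (3, 6) (3, 6)
```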
Given two tensors, a and b, and an array_like object containing two array_like objects, (a_axes, b_axes), sum the products of a’s and b’s elements (components) over the axes specified by a_axes and b_axes. The third argument can be a single non-negative integer_like scalar, N; if it is such, then the last N dimensions of a and the first N dimensions of b are summed over.
- Parameters
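The integer form of axes can be illustrated by comparing it with an explicit reshape followed by an ordinary matrix product: contracting the last N axes of a with the first N axes of b is the same as flattening those axes and multiplying. This is a sketch with arbitrary example shapes, not a description of how JAX implements the operation internally.

```python
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(2, 3, 4)
b = jnp.arange(60.0).reshape(3, 4, 5)

# N = 2: sum over the last two axes of a (sizes 3, 4)
# and the first two axes of b (also sizes 3, 4).
c = jnp.tensordot(a, b, axes=2)  # shape (2, 5)

# Equivalent contraction: flatten the summed-over axes in row-major order
# on both sides, then take an ordinary matrix product.
c_manual = a.reshape(2, 3 * 4) @ b.reshape(3 * 4, 5)

# For 2-D inputs, N = 1 is the matrix product and N = 0 is the outer product.
m, n = jnp.ones((2, 3)), jnp.ones((3, 4))
print(jnp.tensordot(m, n, axes=1).shape)  # (2, 4)
print(jnp.tensordot(m, n, axes=0).shape)  # (2, 3, 3, 4)
```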
a (array_like) – Tensors to “dot”.
b (array_like) – Tensors to “dot”.
axes (int or (2,) array_like) –
integer_like: If an int N, sum over the last N axes of a and the first N axes of b in order. The sizes of the corresponding axes must match.
(2,) array_like: Or, a list of axes to be summed over, the first sequence applying to a and the second to b. Both sequences must have the same length.
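A pair of sequences lets you contract arbitrary axes, not only trailing and leading ones, and the two sequences are paired element by element. The sketch below uses arbitrary example shapes and cross-checks the result against jnp.einsum; the remaining (free) axes of a come first in the output, followed by the free axes of b.

```python
import jax.numpy as jnp

a = jnp.arange(60.0).reshape(3, 4, 5)  # axes: i=3, j=4, k=5
b = jnp.arange(24.0).reshape(4, 3, 2)  # axes: j=4, i=3, l=2

# Pair a's axis 0 with b's axis 1 (both size 3),
# and a's axis 1 with b's axis 0 (both size 4).
c = jnp.tensordot(a, b, axes=([0, 1], [1, 0]))
print(c.shape)  # (5, 2): a's free axis k, then b's free axis l

# The same contraction in einsum index notation.
c_einsum = jnp.einsum("ijk,jil->kl", a, b)
```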
- Returns
output – The tensor dot product of the input.
- Return type