jax.numpy.inner
Warning
This page was created from a pull request (#9655).
- jax.numpy.inner(a, b, *, precision=None)
Inner product of two arrays.
LAX-backend implementation of numpy.inner().
In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two Precision enums indicating separate precision for each argument.
Original docstring below.
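The precision argument described above is specific to the JAX wrapper. Here is a minimal sketch of its three accepted forms, assuming the enum referred to as Precision is jax.lax.Precision. The inputs are small integer-valued floats so that even reduced-precision matmul paths (for example on TPU) should reproduce them exactly. On CPU the three calls are expected to agree.

```python
import jax.numpy as jnp
from jax import lax

# Small integer-valued float32 inputs: exactly representable at reduced
# precision, so every precision setting should give the same answer here.
a = jnp.arange(6.0, dtype=jnp.float32).reshape(2, 3)   # shape (2, 3)
b = jnp.arange(12.0, dtype=jnp.float32).reshape(4, 3)  # shape (4, 3)

# precision=None (the default): backend default precision.
out_default = jnp.inner(a, b)

# A single Precision enum applies to both operands.
out_highest = jnp.inner(a, b, precision=lax.Precision.HIGHEST)

# A tuple of two enums sets precision separately for a and for b.
out_mixed = jnp.inner(a, b, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT))

print(out_highest.shape)  # (2, 4)
```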
Ordinary inner product of vectors for 1-D arrays (without complex conjugation); in higher dimensions, a sum product over the last axes.
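The sketch below contrasts inner with jnp.vdot, which does conjugate its first argument. It also checks the higher-dimensional case against an einsum spelling of the same sum product over the last axes. The einsum subscripts are one illustrative way to write the contraction, not a description of how inner is implemented.

```python
import jax.numpy as jnp

# 1-D complex vectors: inner() multiplies elementwise and sums, with no conjugation.
x = jnp.array([1 + 1j, 2 + 0j])
y = jnp.array([1j, 1 + 0j])

no_conj = jnp.inner(x, y)   # (1+1j)*1j + 2*1 = 1+1j
with_conj = jnp.vdot(x, y)  # conj(1+1j)*1j + 2*1 = 3+1j (vdot conjugates x)

# Higher dimensions: sum product over the last axis of each argument,
# out[i, p, q] = sum_k a[i, k] * b[p, q, k]
a = jnp.ones((2, 3))
b = jnp.arange(24.0).reshape(4, 2, 3)
out = jnp.inner(a, b)
same = jnp.einsum("ik,pqk->ipq", a, b)

print(out.shape)  # (2, 4, 2)
```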
- Parameters
a (array_like) – If a and b are nonscalar, their last dimensions must match.
b (array_like) – If a and b are nonscalar, their last dimensions must match.
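Here is a short sketch of the parameter rules and the shapes they produce, following the out.shape rule given in the Returns section. The scalar-operand case reflects NumPy's convention that a scalar argument is simply multiplied in. The exception type raised for mismatched last dimensions is an assumption, so both TypeError and ValueError are caught.

```python
import jax.numpy as jnp

v = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([4.0, 5.0, 6.0])

# Both 1-D with matching last dimensions: the result is 0-d (a scalar).
s = jnp.inner(v, w)  # 1*4 + 2*5 + 3*6 = 32
print(s.shape)       # ()

# N-D operands: out.shape = (*a.shape[:-1], *b.shape[:-1])
nd = jnp.inner(jnp.ones((2, 3, 4)), jnp.ones((5, 4)))
print(nd.shape)      # (2, 3, 5)

# A scalar operand has no last dimension to match; it scales the other array.
scaled = jnp.inner(2.0, v)  # [2., 4., 6.]

# Non-scalar operands whose last dimensions differ (3 vs 4) are rejected.
try:
    jnp.inner(v, jnp.ones(4))
except (TypeError, ValueError):
    mismatch_raised = True
else:
    mismatch_raised = False
```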
- Returns
out – If a and b are both scalars or both 1-D arrays, then a scalar is returned; otherwise an array is returned.
out.shape = (*a.shape[:-1], *b.shape[:-1])
- Return type