jax.numpy.einsum
Warning
This page was created from a pull request (#9655).
- jax.numpy.einsum(*operands, out=None, optimize='optimal', precision=None, _use_xeinsum=False)
Evaluates the Einstein summation convention on the operands.
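As a worked illustration (the subscripts are chosen here for illustration and are not taken from the original notes), the string 'ij,jk->ik' labels the axes of two matrices A and B. The repeated label j is summed over, and the labels after '->' name the output axes, so the string denotes the matrix product:

```latex
C_{ik} = \sum_{j} A_{ij} B_{jk}
```

By the same rule, 'ii->' denotes the trace \sum_i A_{ii}, and 'ij->ji' denotes the transpose.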
LAX-backend implementation of numpy.einsum().
In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means the default precision for the backend; a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST); or a tuple of two Precision enums giving separate precisions for the two arguments. A tuple precision does not necessarily map to the operands of einsum(); rather, the specified precision is forwarded to each dot_general call used in the implementation, where it applies to that call's two arguments.
Original docstring below.
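A minimal sketch of the precision argument, assuming a JAX version whose jnp.einsum accepts precision as shown in the signature above. The matrices are arbitrary. Precision only matters on backends whose default matrix-multiplication precision is reduced (for example TPUs). With the small integer-valued inputs used here, all three results should agree.

```python
import jax.numpy as jnp
from jax import lax

# Two small float32 matrices; values are arbitrary, chosen for illustration.
a = jnp.arange(6.0, dtype=jnp.float32).reshape(2, 3)
b = jnp.arange(12.0, dtype=jnp.float32).reshape(3, 4)

# precision=None (the default): the backend's default precision.
c_default = jnp.einsum('ij,jk->ik', a, b)

# A single Precision enum is forwarded to every dot_general call einsum emits.
c_highest = jnp.einsum('ij,jk->ik', a, b, precision=lax.Precision.HIGHEST)

# A tuple sets the precision of the lhs and rhs of each dot_general call,
# which is not necessarily the first and second einsum operand.
c_mixed = jnp.einsum('ij,jk->ik', a, b,
                     precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT))
```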
Using the Einstein summation convention, many common multi-dimensional, linear algebraic array operations can be represented in a simple fashion. In implicit mode einsum computes these values.
In explicit mode, einsum provides further flexibility to compute other array operations that might not be considered classical Einstein summation operations, by disabling, or forcing summation over specified subscript labels.
See the notes and examples for clarification.
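The two modes can be sketched with a few subscript strings. A call is in explicit mode exactly when its subscripts contain '->'. The arrays below are arbitrary.

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)
m = jnp.arange(9.0).reshape(3, 3)

# Implicit mode (no '->'): repeated labels are summed, and the remaining
# labels form the output in alphabetical order, so 'ij,jk' yields axes (i, k).
matmul_implicit = jnp.einsum('ij,jk', a, b)
trace = jnp.einsum('ii', m)  # i is repeated, so it is summed: the trace

# Explicit mode: the output labels are spelled out after '->'.
matmul_explicit = jnp.einsum('ij,jk->ik', a, b)

# Explicit mode can order output axes differently from implicit mode,
# here producing the transpose of the product.
matmul_t = jnp.einsum('ij,jk->ki', a, b)

# Disable summation over a repeated label: keep the diagonal instead.
diag = jnp.einsum('ii->i', m)

# Force summation over labels that are not repeated: sum of all entries.
total = jnp.einsum('ij->', a)
```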
- Parameters
operands (list of array_like) – These are the arrays for the operation.
optimize ({False, True, 'greedy', 'optimal'}, optional) – Controls whether intermediate optimization should occur. No optimization will occur if False, and True will default to the 'greedy' algorithm. Also accepts an explicit contraction list from the np.einsum_path function. See np.einsum_path for more details. Defaults to False in NumPy; the signature above shows that this implementation defaults to 'optimal'.
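A sketch comparing two contraction strategies on a three-operand chain. The shapes are arbitrary. The strategy changes only the order of pairwise contractions, and therefore the cost of the intermediates. It does not change the result beyond floating-point rounding.

```python
import jax.numpy as jnp

# Three operands with mismatched sizes, so contraction order affects cost.
x = jnp.ones((8, 64))
y = jnp.ones((64, 4))
z = jnp.ones((4, 32))

# Default strategy ('optimal' per the signature above).
r_optimal = jnp.einsum('ij,jk,kl->il', x, y, z)

# 'greedy' picks a cheap pairwise contraction at each step; it is faster to
# plan for many operands but is not guaranteed to find the best order.
r_greedy = jnp.einsum('ij,jk,kl->il', x, y, z, optimize='greedy')
```

Both calls sum 64 × 4 products of ones per output entry, so every entry of the (8, 32) result equals 256.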
- Returns
output – The calculation based on the Einstein summation convention.
- Return type