jax.jvp
Warning
This page was created from a pull request (#9655).
- jax.jvp(fun, primals, tangents, has_aux=False)
Computes a (forward-mode) Jacobian-vector product of fun.
- Parameters
fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.
tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
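The structural requirement on primals and tangents can be illustrated with a small sketch. The function affine and its parameter names are hypothetical, chosen only for illustration. It takes two positional parameters, one of them a dict, so primals and tangents are each a 2-tuple whose entries mirror each other's tree structure and shapes.

```python
import jax
import jax.numpy as jnp

def affine(params, x):
    # Hypothetical function with two positional parameters:
    # a dict of arrays and a scalar array.
    return params["w"] * x + params["b"]

# One entry per positional parameter of affine.
primals = ({"w": jnp.array(2.0), "b": jnp.array(1.0)}, jnp.array(3.0))

# Same tree structure and shapes as primals; here only "w" is perturbed.
tangents = ({"w": jnp.array(1.0), "b": jnp.array(0.0)}, jnp.array(0.0))

y, dy = jax.jvp(affine, primals, tangents)
print(y)   # value of affine at the primals: 2 * 3 + 1
print(dy)  # directional derivative along tangents: x * dw = 3 * 1
```

Setting a tangent entry to zero holds that input fixed, so dy here is the derivative of the output with respect to w alone.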
- Return type
- Returns
If has_aux is False, returns a (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of fun evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out. If has_aux is True, returns a (primals_out, tangents_out, aux) tuple where aux is the auxiliary data returned by fun.
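A minimal sketch of the has_aux=True case, assuming a JAX version that accepts the has_aux argument described on this page. The function sin_with_aux and the "x_squared" key are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp

def sin_with_aux(x):
    # Returns a pair: the first element is differentiated,
    # the second is passed through as auxiliary data.
    y = jnp.sin(x)
    aux = {"x_squared": x ** 2}
    return y, aux

x = jnp.array(0.1)
t = jnp.array(0.2)

# With has_aux=True the result is a 3-tuple instead of a pair.
primals_out, tangents_out, aux = jax.jvp(sin_with_aux, (x,), (t,), has_aux=True)

print(primals_out)       # sin(0.1)
print(tangents_out)      # cos(0.1) * 0.2
print(aux["x_squared"])  # 0.1 ** 2, returned without a tangent
```

Only the first element of the returned pair contributes to tangents_out; the auxiliary data carries no tangent.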
For example:
>>> import jax
>>>
>>> y, v = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(y)
0.09983342
>>> print(v)
0.19900084
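The tangent printed above matches the analytic value cos(0.1) * 0.2. For a vector-valued function, passing a standard basis vector as the tangent yields one column of the Jacobian. The sketch below uses a hypothetical function g and compares the result against jax.jacfwd as a sanity check.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical function from R^2 to R^2.
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([1.0, 2.0])
e0 = jnp.array([1.0, 0.0])  # basis vector selecting the first input

# The tangent output is the Jacobian of g at x multiplied by e0,
# which is the first column of the Jacobian.
g_x, col0 = jax.jvp(g, (x,), (e0,))

# Full Jacobian computed by forward mode, for comparison.
jac = jax.jacfwd(g)(x)
print(col0)
print(jac[:, 0])
```

Each call to jax.jvp evaluates one such product, so building a full Jacobian this way takes one call per input dimension.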