jax.vjp
Warning
This page was created from a pull request (#9655).
- jax.vjp(fun: Callable[..., jax._src.api.T], *primals: Any, has_aux: Literal[False] = False, reduce_axes: Sequence[Any] = ()) -> Tuple[jax._src.api.T, Callable]
- jax.vjp(fun: Callable[..., Tuple[jax._src.api.T, jax._src.api.U]], *primals: Any, has_aux: Literal[True], reduce_axes: Sequence[Any] = ()) -> Tuple[jax._src.api.T, Callable, jax._src.api.U]
- jax.vjp(fun: Callable[..., jax._src.api.T], *primals: Any) -> Tuple[jax._src.api.T, Callable]
- jax.vjp(fun: Callable[..., Any], *primals: Any, has_aux: bool, reduce_axes: Sequence[Any] = ()) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]
Compute a (reverse-mode) vector-Jacobian product of fun.

grad() is implemented as a special case of vjp().
- Parameters
  - fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
  - primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The length of primals should be equal to the number of positional parameters to fun. Each primal value should be an array, a scalar, or a standard Python container thereof.
  - has_aux (bool) – Optional. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Defaults to False.
  - reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the VJP will be per-example over named axes. For example, if 'batch' is a named batch axis, vjp(f, *args, reduce_axes=('batch',)) will create a VJP function that sums over the batch, while vjp(f, *args) will create a per-example VJP.
- Return type
  Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]
- Returns
  If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same shape as primals, representing the vector-Jacobian product of fun evaluated at primals. If has_aux is True, returns a (primals_out, vjpfun, aux) tuple, where aux is the auxiliary data returned by fun.
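A minimal sketch of the has_aux=True form, assuming a hypothetical loss_with_aux function whose first argument is a dict of parameters. It illustrates three points from the description above: the call returns a 3-tuple, vjpfun returns one cotangent per primal with the same structure as that primal, and aux is passed through without being differentiated.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss function: params is a dict (a pytree), x is a plain array.
def loss_with_aux(params, x):
    pred = params["w"] * x + params["b"]
    loss = jnp.sum(pred ** 2)
    # The second element is auxiliary data: returned as-is, not differentiated.
    return loss, {"pred": pred}

params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
x = jnp.array([1.0, 2.0, 3.0])

# With has_aux=True, vjp returns (primals_out, vjpfun, aux) instead of a pair.
loss, vjpfun, aux = jax.vjp(loss_with_aux, params, x, has_aux=True)

# The cotangent must match primals_out, which here is the scalar loss only.
params_bar, x_bar = vjpfun(jnp.ones_like(loss))

# params_bar is a dict with the same keys as params; x_bar has x's shape (3,).
# For these inputs: loss == 83.0, params_bar["w"] == 68.0, params_bar["b"] == 30.0,
# x_bar == [12.0, 20.0, 28.0], and aux["pred"] == [3.0, 5.0, 7.0].
```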
>>> import jax
>>>
>>> def f(x, y):
...   return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals_out, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((-0.7, 0.3))
>>> print(xbar)
-0.61430776
>>> print(ybar)
-0.2524413
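The statement that grad() is a special case of vjp() can be checked numerically. In the sketch below, g and h are hypothetical example functions. It pulls back the cotangent 1.0 through the scalar-valued g and compares the result with jax.grad, then pulls back each standard basis vector through the vector-valued h, batched with jax.vmap, to recover the rows of the Jacobian and compares them with jax.jacrev. This illustrates the relationship only; it is not a description of how grad() or jacrev() are implemented internally.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical scalar-valued function of a vector input.
    return jnp.sum(jnp.sin(x))

x = jnp.array([0.0, 0.5, 1.0])

# For a scalar output, pulling back the cotangent 1.0 yields the gradient.
y, g_vjp = jax.vjp(g, x)
(grad_via_vjp,) = g_vjp(jnp.ones_like(y))
grad_direct = jax.grad(g)(x)
# Both equal cos(x) elementwise.

def h(x):
    # Hypothetical vector-valued function from R^3 to R^2.
    return jnp.stack([x[0] * x[1], jnp.sum(x ** 2)])

y_h, h_vjp = jax.vjp(h, x)

# Pulling back each standard basis vector of the output space gives one row
# of the Jacobian; vmap applies h_vjp to all basis vectors in one call.
basis = jnp.eye(y_h.shape[0])
(jac_rows,) = jax.vmap(h_vjp)(basis)
jac_direct = jax.jacrev(h)(x)
# jac_rows has shape (2, 3) and matches jac_direct.
```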