jax.value_and_grad
- jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]
Create a function that evaluates both fun and the gradient of fun.
- Parameters
  - fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.).
  - argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  - has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  - holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
  - allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer-valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
  - reduce_axes (Sequence[Any]) – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, value_and_grad(f, reduce_axes=('batch',)) will create a function that computes the total gradient while value_and_grad(f) will create one that computes the per-example gradient.
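To illustrate the argnums and has_aux parameters described above, here is a minimal sketch; the loss function and its auxiliary dictionary are hypothetical, chosen only to show the calling convention:

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar loss that also returns auxiliary diagnostics.
def loss(w, x):
    pred = w * x
    err = jnp.sum((pred - 1.0) ** 2)
    return err, {"pred": pred}

# argnums=(0, 1): differentiate with respect to both w and x.
# has_aux=True: the second element of loss's output is auxiliary data.
(val, aux), grads = jax.value_and_grad(loss, argnums=(0, 1), has_aux=True)(2.0, 3.0)
# val is the scalar loss; grads is a tuple (d_loss/d_w, d_loss/d_x);
# aux carries the auxiliary dict untouched by differentiation.
```

With argnums a sequence, grads mirrors the structure of the selected positional arguments; with has_aux=True, the value slot becomes the pair (value, auxiliary_data).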
- Return type
  Callable
- Returns
  A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.
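A basic usage sketch of the returned pair, assuming a simple sum-of-squares function:

```python
import jax
import jax.numpy as jnp

# A scalar-valued function of an array argument.
def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# The transformed function returns (value, gradient) in one call,
# avoiding a separate forward pass compared to calling f and jax.grad(f).
value, grad = jax.value_and_grad(f)(x)
# value is f(x) = 14.0; grad has the same shape as x: [2., 4., 6.]
```

This is typically preferred over calling f and jax.grad(f) separately in a training loop, since the forward computation is shared.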