jax.lax.custom_root
- jax.lax.custom_root(f, initial_guess, solve, tangent_solve, has_aux=False)
Differentiably solve for a root of a function.
This is a low-level routine, mostly intended for internal use in JAX. Gradients of custom_root() are defined with respect to closed-over variables from the provided function f via the implicit function theorem: https://en.wikipedia.org/wiki/Implicit_function_theorem
- Parameters
f – function for which to find a root. Should accept a single argument and return a tree of arrays with the same structure as its input.
initial_guess – initial guess for a zero of f.
solve – function to solve for the roots of f. Should take two positional arguments, f and initial_guess, and return a solution with the same structure as initial_guess such that f(solution) = 0. In other words, the following is assumed to be true (but not checked):
  solution = solve(f, initial_guess)
  error = f(solution)
  assert all(error == 0)
tangent_solve – function to solve the tangent system. Should take two positional arguments, a linear function g (the function f linearized at its root) and a tree of array(s) y with the same structure as initial_guess, and return a solution x such that g(x) = y:
  For scalar y, use lambda g, y: y / g(1.0).
  For vector y, you could use a linear solve with the Jacobian, if the dimensionality of y is not too large: lambda g, y: np.linalg.solve(jacobian(g)(y), y).
has_aux – bool indicating whether the solve function returns auxiliary data like solver diagnostics as a second argument.
- Returns
The result of calling solve(f, initial_guess) with gradients defined via implicit differentiation assuming f(solve(f, initial_guess)) == 0.
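
A minimal worked sketch (not from the JAX documentation): compute a square root as the root of f(y) = y ** 2 - c, using a hand-rolled Newton iteration as solve and the scalar tangent_solve suggested above. The names sqrt_newton, newton_solve, and scalar_tangent_solve are illustrative, not part of the API; gradients with respect to the closed-over c come from the implicit function theorem as described above.

import jax
from jax import lax

def sqrt_newton(c):
    def f(y):
        # Root of f is y = sqrt(c); gradients w.r.t. the closed-over c
        # are defined implicitly, not by differentiating the solver loop.
        return y ** 2 - c

    def newton_solve(f, y0):
        # A fixed number of Newton steps; custom_root only assumes f(solution) == 0.
        y = y0
        for _ in range(20):
            y = y - f(y) / jax.grad(f)(y)
        return y

    def scalar_tangent_solve(g, y):
        # Scalar case: g is linear, so g(x) = y implies x = y / g(1.0).
        return y / g(1.0)

    return lax.custom_root(f, 1.0, newton_solve, scalar_tangent_solve)

print(sqrt_newton(2.0))            # ~1.4142135
print(jax.grad(sqrt_newton)(2.0))  # ~0.3535534 == 0.5 / sqrt(2)

Because the gradient is defined via implicit differentiation of f at the returned solution, the iteration inside solve does not itself need to be differentiable.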
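
For a vector-valued y, the tangent system can be solved by materializing the Jacobian of the linearized function, as suggested above. The sketch below (illustrative names cube_root, newton_solve, vector_tangent_solve; not from the JAX documentation) computes elementwise cube roots this way.

import jax
import jax.numpy as jnp
from jax import lax

def vector_tangent_solve(g, y):
    # Solve g(x) = y by materializing the (small) Jacobian of the linear map g.
    return jnp.linalg.solve(jax.jacobian(g)(y), y)

def cube_root(b):
    def f(x):
        # Elementwise root: x = b ** (1/3).
        return x ** 3 - b

    def newton_solve(f, x0):
        x = x0
        for _ in range(25):
            x = x - f(x) / (3.0 * x ** 2)  # elementwise Newton step
        return x

    return lax.custom_root(f, jnp.ones_like(b), newton_solve, vector_tangent_solve)

b = jnp.array([1.0, 8.0, 27.0])
print(cube_root(b))                                # ~[1. 2. 3.]
print(jax.grad(lambda b: cube_root(b).sum())(b))   # ~1 / (3 * b ** (2/3))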