jax.lax.custom_root
Warning
This page was created from a pull request (#9655).
jax.lax.custom_root(f, initial_guess, solve, tangent_solve, has_aux=False)
Differentiably solve for the roots of a function.
This is a low-level routine, mostly intended for internal use in JAX. Gradients of custom_root() are defined with respect to closed-over variables from the provided function f via the implicit function theorem: https://en.wikipedia.org/wiki/Implicit_function_theorem

Parameters
f: function for which to find a root. Should accept a single argument and return a tree of arrays with the same structure as its input.
initial_guess: initial guess for a zero of f.
solve: function to solve for the roots of f. Should take two positional arguments, f and initial_guess, and return a solution with the same structure as initial_guess such that f(solution) = 0. In other words, the following is assumed to be true (but not checked):

    solution = solve(f, initial_guess)
    error = f(solution)
    assert all(error == 0)

tangent_solve: function to solve the tangent system. Should take two positional arguments, a linear function g (the function f linearized at its root) and a tree of array(s) y with the same structure as initial_guess, and return a solution x such that g(x) = y:
- For scalar y, use lambda g, y: y / g(1.0).
- For vector y, you could use a linear solve with the Jacobian, if the dimensionality of y is not too large: lambda g, y: jnp.linalg.solve(jax.jacobian(g)(y), y).
has_aux: bool indicating whether the solve function returns auxiliary data, such as solver diagnostics, as a second output.
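The parameters above can be wired together as in the following minimal sketch, assuming a recent JAX install. The solvers newton_scalar and newton_vector and the wrappers sqrt_scalar and sqrt_vector are hypothetical names chosen for illustration, not part of JAX; the two tangent solves follow the scalar and vector recipes given for tangent_solve. Differentiating with respect to the closed-over a should reproduce d sqrt(a)/da = 1 / (2 sqrt(a)) without differentiating through the Newton iterations.

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical solvers (not part of JAX). custom_root calls each one as
# solve(f, initial_guess); both run a fixed number of Newton steps.
def newton_scalar(f, x0, num_steps=10):
    x = x0
    for _ in range(num_steps):
        # f maps a scalar to a scalar, so jax.grad(f) is its derivative.
        x = x - f(x) / jax.grad(f)(x)
    return x


def newton_vector(f, x0, num_steps=10):
    x = x0
    for _ in range(num_steps):
        # Newton step: solve J(x) @ step = f(x) with the dense Jacobian J.
        x = x - jnp.linalg.solve(jax.jacobian(f)(x), f(x))
    return x


# Tangent solves following the scalar and vector recipes for tangent_solve.
def scalar_tangent_solve(g, y):
    # g is linear in its argument, so g(1.0) is its slope.
    return y / g(1.0)


def vector_tangent_solve(g, y):
    # For a linear g the Jacobian is the same at every point; y only
    # supplies the evaluation point (and therefore the shape).
    return jnp.linalg.solve(jax.jacobian(g)(y), y)


def sqrt_scalar(a):
    # `a` is closed over by f, so custom_root defines gradients with respect to it.
    f = lambda x: x ** 2 - a
    return lax.custom_root(f, 1.0, newton_scalar, scalar_tangent_solve)


def sqrt_vector(a):
    # Elementwise residual, so the Jacobian of f is diagonal.
    f = lambda x: x ** 2 - a
    return lax.custom_root(f, jnp.ones_like(a), newton_vector, vector_tangent_solve)


print(sqrt_scalar(4.0))              # close to 2.0
print(jax.grad(sqrt_scalar)(4.0))    # close to 0.25 = 1 / (2 * sqrt(4))

a = jnp.array([4.0, 9.0])
print(sqrt_vector(a))                # close to [2. 3.]
print(jax.jacobian(sqrt_vector)(a))  # close to diag([0.25, 1/6])
```

A fixed number of Newton steps keeps the sketch short; a practical solver would normally test for convergence, for example inside lax.while_loop.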
Returns
The result of calling solve(f, initial_guess) with gradients defined via implicit differentiation assuming f(solve(f, initial_guess)) == 0.
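The implicit differentiation mentioned under Returns can be sketched with the implicit function theorem. Writing θ for the variables that f closes over and x*(θ) for the root returned by solve (notation introduced here for illustration only), differentiating the identity f(x*(θ), θ) = 0 gives:

$$
f\bigl(x^*(\theta), \theta\bigr) = 0
\quad\Longrightarrow\quad
\partial_x f \, \frac{\partial x^*}{\partial \theta} + \partial_\theta f = 0
\quad\Longrightarrow\quad
\dot{x}^* = -\left(\partial_x f\right)^{-1} \bigl[\partial_\theta f \, \dot{\theta}\bigr]
$$

with all partial derivatives evaluated at (x*(θ), θ). Applying the inverse of the linear map ∂_x f is the job of tangent_solve, whose argument g is f linearized in x at the root. For example, f(x, a) = x² − a has ∂_x f = 2x* and ∂_a f = −1, so the tangent of the root is ȧ / (2x*) = ȧ / (2√a).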