jax.lax.cond
Warning
This page was created from a pull request (#9655).
jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>)
Conditionally apply true_fun or false_fun.

cond() has equivalent semantics to this Python implementation:

```python
def cond(pred, true_fun, false_fun, *operands):
  if pred:
    return true_fun(*operands)
  else:
    return false_fun(*operands)
```
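The Python reference implementation describes what cond() computes. It cannot be used as-is inside `jax.jit`, because there `pred` is usually a traced value that a Python `if` cannot branch on. The sketch below contrasts the two. The function names `double_or_negate` and `python_if` are hypothetical. It assumes that converting a tracer to a Python bool raises `jax.errors.ConcretizationTypeError` or a subclass of it, which is the behavior of recent JAX releases.

```python
import jax
from jax import lax


def double_or_negate(x):
    # Under jax.jit, `x > 0` is a traced boolean scalar. lax.cond traces both
    # branch functions and selects which one to apply at run time.
    return lax.cond(x > 0,
                    lambda v: v * 2.0,  # true_fun: applied when pred is True
                    lambda v: -v,       # false_fun: applied when pred is False
                    x)


def python_if(x):
    # The plain-Python semantics of cond(); works eagerly, not under jit.
    if x > 0:
        return x * 2.0
    return -x


f = jax.jit(double_or_negate)
print(f(3.0))   # 6.0
print(f(-4.0))  # 4.0

try:
    jax.jit(python_if)(3.0)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    # A traced value cannot be converted to a Python bool during tracing.
    python_if_failed = True
print(python_if_failed)  # True
```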
pred must be a scalar type.

Parameters
- pred: Boolean scalar type, indicating which branch function to apply.
- true_fun (Callable): Function (A -> B), to be applied if pred is True.
- false_fun (Callable): Function (A -> B), to be applied if pred is False.
- operands: Operands (A) input to either branch depending on pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
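Because both branches receive the same `*operands`, each branch function must accept all of them, even one that ignores some. The sketch below passes two operands, one of which is a dict pytree. The names `params`, `apply_linear`, and `apply_zero` are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}  # a dict pytree operand
x = jnp.array([3.0, 4.0])                                     # an array operand


def apply_linear(p, v):
    # true_fun: called as true_fun(params, x)
    return jnp.dot(p["w"], v) + p["b"]


def apply_zero(p, v):
    # false_fun: must take the same operands, even though it ignores `p`
    return jnp.zeros((), dtype=v.dtype)


out_true = lax.cond(True, apply_linear, apply_zero, params, x)   # 1*3 + 2*4 + 0.5
out_false = lax.cond(False, apply_linear, apply_zero, params, x)
print(out_true, out_false)  # 11.5 0.0
```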
Returns
Value (B) of either true_fun(*operands) or false_fun(*operands), depending on the value of pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
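Both branches share the single output type B, so they have to return pytrees with the same structure, shapes, and dtypes. The sketch below returns a dict from both branches and then shows a shape mismatch being rejected. The helper names are hypothetical. The sketch also assumes the mismatch surfaces as a `TypeError` at trace time, which is the behavior of recent JAX versions with the default configuration (not under `jax.disable_jit`).

```python
import jax.numpy as jnp
from jax import lax


def stats_branch(v):
    # true_fun: returns a dict pytree of two scalars
    return {"mean": jnp.mean(v), "max": jnp.max(v)}


def default_branch(v):
    # false_fun: same pytree structure, shapes, and dtypes as stats_branch
    return {"mean": jnp.zeros((), v.dtype), "max": jnp.zeros((), v.dtype)}


v = jnp.array([1.0, 2.0, 6.0])
out = lax.cond(v.sum() > 0, stats_branch, default_branch, v)
print(out["mean"], out["max"])  # 3.0 6.0

# Branch outputs with different shapes ((3,) vs ()) are rejected.
try:
    lax.cond(True, lambda a: a, lambda a: a.sum(), v)
    mismatch_rejected = False
except TypeError:
    mismatch_rejected = True
print(mismatch_rejected)  # True
```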