jax.numpy.tril
Warning
This page was created from a pull request (#9655).
- jax.numpy.tril(m, k=0)
Lower triangle of an array.
LAX-backend implementation of numpy.tril(). Original docstring below.
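Because this is a LAX-backend implementation of NumPy's function, the two should agree on ordinary inputs. The sketch below checks that on a small integer matrix and also spells out the "zero everything above the k-th diagonal" rule as an explicit boolean mask. `tril_via_mask` is a hypothetical helper written for illustration only; it uses public `jax.numpy` functions (`jnp.tri`, `jnp.where`) and is not claimed to be how the LAX backend actually implements `tril`.

```python
import numpy as np
import jax.numpy as jnp

def tril_via_mask(m, k=0):
    """Hypothetical reference: keep entries on/below the k-th diagonal, zero the rest."""
    m = jnp.asarray(m)
    rows, cols = m.shape[-2:]
    # jnp.tri(N, M, k) is True on and below the k-th diagonal, False above it.
    mask = jnp.tri(rows, cols, k=k, dtype=bool)
    # The 2-D mask broadcasts over any leading (batch) axes of m.
    return jnp.where(mask, m, jnp.zeros_like(m))

x = np.arange(1, 13).reshape(3, 4)
jax_result = jnp.tril(x)
np_result = np.tril(x)

print(np.asarray(jax_result).tolist())  # [[1, 0, 0, 0], [5, 6, 0, 0], [9, 10, 11, 0]]
print(np.array_equal(np.asarray(jax_result), np_result))        # True
print(np.array_equal(np.asarray(tril_via_mask(x)), np_result))  # True
```

Note that with JAX's default 32-bit mode the int64 NumPy input becomes an int32 JAX array, so the comparison above is on values rather than dtypes.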
Return a copy of an array with elements above the k-th diagonal zeroed. For arrays with ndim exceeding 2, tril will apply to the final two axes.
- Parameters
m (array_like, shape (..., M, N)) – Input array.
k (int, optional) – Diagonal above which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it, and k > 0 is above it.
- Returns
tril – Lower triangle of m, of the same shape and data-type as m.
- Return type
ndarray, shape (..., M, N)
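To make the k parameter and the "final two axes" rule concrete, the sketch below applies `jnp.tril` to a 3x3 matrix with three diagonal offsets, then to a stack of two such matrices. It also checks that the result keeps the input's shape and dtype, as the Returns entry states. The matrix values are arbitrary and chosen only so the zeroed positions are easy to read.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.arange(1, 10).reshape(3, 3)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# k = 0 (default): keep the main diagonal and everything below it.
print(np.asarray(jnp.tril(a, k=0)).tolist())   # [[1, 0, 0], [4, 5, 0], [7, 8, 9]]
# k = -1: the cut moves below the main diagonal, so the diagonal is zeroed too.
print(np.asarray(jnp.tril(a, k=-1)).tolist())  # [[0, 0, 0], [4, 0, 0], [7, 8, 0]]
# k = 1: the cut moves above the main diagonal, keeping the first superdiagonal.
print(np.asarray(jnp.tril(a, k=1)).tolist())   # [[1, 2, 0], [4, 5, 6], [7, 8, 9]]

# ndim > 2: only the final two axes form the matrix; axis 0 is a batch of two.
batch = jnp.stack([a, 10 * a])
out = jnp.tril(batch)
print(out.shape, out.dtype == batch.dtype)     # (2, 3, 3) True
print(np.asarray(out[1]).tolist())             # [[10, 0, 0], [40, 50, 0], [70, 80, 90]]
```

Each matrix in the batch is processed independently, which is the same result as calling `jnp.tril` on each 2-D slice separately.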