jax.nn.initializers package

Warning

This page was created from a pull request (#9655).

Initializers

This module provides common neural network layer initializers, consistent with definitions used in Keras and Sonnet.

zeros(key, shape[, dtype])

An initializer that returns a constant array full of zeros.

ones(key, shape[, dtype])

An initializer that returns a constant array full of ones.
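As a minimal sketch of the calling convention, zeros and ones can be called directly with a PRNG key, a shape, and a dtype. The variable names and shapes below are illustrative assumptions, not part of the library.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

# The key argument is part of the common initializer signature;
# constant initializers accept it but do not draw random numbers from it.
key = jax.random.PRNGKey(0)

w_zeros = initializers.zeros(key, (2, 3), jnp.float32)  # hypothetical weight shape
b_ones = initializers.ones(key, (3,), jnp.float32)      # hypothetical bias shape

print(w_zeros.shape, b_ones.shape)
```

Because every initializer shares the (key, shape, dtype) call pattern, a layer can accept any of them interchangeably.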

uniform([scale, dtype])

Builds an initializer that returns random arrays drawn uniformly from the range [0, scale).

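normal([stddev, dtype])

Builds an initializer that returns random arrays drawn from a normal distribution with mean 0 and standard deviation stddev.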
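Unlike zeros and ones, uniform and normal are builders: calling them returns an initializer function, which is then called with (key, shape, dtype). The sketch below assumes illustrative values for scale, stddev, and the weight shape.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key_u, key_n = jax.random.split(jax.random.PRNGKey(0))

# Step 1: build the initializers (the scale and stddev values are illustrative).
init_uniform = initializers.uniform(scale=0.05)
init_normal = initializers.normal(stddev=0.02)

# Step 2: call them with a key, a shape, and a dtype.
w_u = init_uniform(key_u, (64, 32), jnp.float32)
w_n = init_normal(key_n, (64, 32), jnp.float32)

print(float(w_u.min()), float(w_u.max()))  # samples fall in [0, scale)
print(float(w_n.std()))                    # should be close to stddev
```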

variance_scaling(scale, mode, distribution)

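Builds an initializer that adapts its scale to the shape of the weights tensor. Samples have variance scale / n, where n is the fan-in, the fan-out, or their average, depending on mode.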
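To illustrate how the scale adapts to the weight shape, the following sketch builds a variance_scaling initializer and compares the empirical standard deviation with sqrt(scale / fan_in). The shape and the choice of scale, mode, and distribution are assumptions for illustration only.

```python
import math

import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(0)

# Hypothetical dense layer: 256 inputs, 128 outputs.
# With the default axes, fan_in is the second-to-last dimension.
shape = (256, 128)
init = initializers.variance_scaling(scale=2.0, mode="fan_in", distribution="normal")
w = init(key, shape, jnp.float32)

# With mode="fan_in", the target variance is scale / fan_in.
expected_std = math.sqrt(2.0 / shape[0])
empirical_std = float(jnp.std(w))
print(expected_std, empirical_std)
```

Changing mode to "fan_out" or "fan_avg" only changes which count divides scale; the call pattern stays the same.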

glorot_uniform([in_axis, out_axis, ...])

Builds a Glorot uniform initializer (also known as the Xavier uniform initializer). Equivalent to variance_scaling with scale=1.0, mode="fan_avg", and distribution="uniform".

glorot_normal([in_axis, out_axis, ...])

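Builds a Glorot normal initializer (also known as the Xavier normal initializer). Equivalent to variance_scaling with scale=1.0, mode="fan_avg", and distribution="truncated_normal".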
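The Glorot initializers scale by the average of fan-in and fan-out. As a rough sketch, with an assumed layer size of 300 inputs and 100 outputs, the uniform variant should stay within +/- sqrt(6 / (fan_in + fan_out)) and the normal variant should have a standard deviation near sqrt(2 / (fan_in + fan_out)).

```python
import math

import jax
import jax.numpy as jnp
from jax.nn import initializers

fan_in, fan_out = 300, 100  # hypothetical layer size
key_u, key_n = jax.random.split(jax.random.PRNGKey(42))

w_uniform = initializers.glorot_uniform()(key_u, (fan_in, fan_out), jnp.float32)
w_normal = initializers.glorot_normal()(key_n, (fan_in, fan_out), jnp.float32)

# Target variance is 1 / fan_avg = 2 / (fan_in + fan_out).
limit = math.sqrt(6.0 / (fan_in + fan_out))        # bound for the uniform variant
expected_std = math.sqrt(2.0 / (fan_in + fan_out))  # std for the normal variant

max_abs = float(jnp.max(jnp.abs(w_uniform)))
normal_std = float(jnp.std(w_normal))
print(limit, max_abs)
print(expected_std, normal_std)
```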

lecun_uniform([in_axis, out_axis, ...])

Builds a LeCun uniform initializer. Equivalent to variance_scaling with scale=1.0, mode="fan_in", and distribution="uniform".

lecun_normal([in_axis, out_axis, ...])

Builds a LeCun normal initializer. Equivalent to variance_scaling with scale=1.0, mode="fan_in", and distribution="truncated_normal".
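The LeCun initializers can be read as presets of variance_scaling with scale=1.0 and mode="fan_in". The sketch below checks this for lecun_normal by drawing from both initializers with the same key; the shape is an illustrative assumption, and the two arrays are expected to match.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(7)
shape = (128, 64)  # hypothetical (fan_in, fan_out)

w_lecun = initializers.lecun_normal()(key, shape, jnp.float32)

# The same preset spelled out explicitly through variance_scaling.
explicit = initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
w_explicit = explicit(key, shape, jnp.float32)

# Reusing one key makes both draws deterministic and directly comparable.
same = bool(jnp.allclose(w_lecun, w_explicit))
print(same)
```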

he_uniform([in_axis, out_axis, batch_axis, ...])

Builds a He uniform initializer (also known as the Kaiming uniform initializer). Equivalent to variance_scaling with scale=2.0, mode="fan_in", and distribution="uniform".

he_normal([in_axis, out_axis, batch_axis, dtype])

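Builds a He normal initializer (also known as the Kaiming normal initializer). Equivalent to variance_scaling with scale=2.0, mode="fan_in", and distribution="truncated_normal".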
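For tensors with more than two dimensions, the fan-in also counts the receptive field. The sketch below applies he_normal to a convolution kernel, assuming a (height, width, in_channels, out_channels) layout so that the default in_axis=-2 and out_axis=-1 pick out the channel dimensions. The kernel size is an illustrative assumption.

```python
import math

import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(0)

# Hypothetical 3x3 convolution kernel, 16 input channels, 32 output channels.
kernel_shape = (3, 3, 16, 32)
w = initializers.he_normal()(key, kernel_shape, jnp.float32)

# fan_in = in_channels * receptive field size = 16 * (3 * 3).
fan_in = 16 * 3 * 3
expected_std = math.sqrt(2.0 / fan_in)  # He scaling uses scale=2.0
empirical_std = float(jnp.std(w))
print(expected_std, empirical_std)
```

If a kernel uses a different layout, in_axis and out_axis can be passed to the builder so that the fan computation uses the right dimensions.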