jax.nn.initializers.he_uniform
Warning
This page was created from a pull request (#9655).
- jax.nn.initializers.he_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax._src.numpy.lax_numpy.float64'>)
Initializer capable of adapting its scale to the shape of the weights tensor.
With distribution="truncated_normal" or "normal", samples are drawn from a truncated/untruncated normal distribution with a mean of zero and a standard deviation (after truncation, if used) stddev = sqrt(scale / n), where n is:
- number of input units in the weights tensor, if mode="fan_in"
- number of output units, if mode="fan_out"
- average of the numbers of input and output units, if mode="fan_avg"
This initializer can be configured with in_axis, out_axis, and batch_axis to work with general convolutional or dense layers; axes that are not in any of those arguments are assumed to be the “receptive field” (convolution kernel spatial axes).
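The interaction between the axis arguments and n can be sketched in plain Python. The helpers `compute_fans` and `variance_scaling_stddev` below are hypothetical names written for illustration, not part of the JAX API, and the sketch assumes that the receptive-field size multiplies both the input and output sizes when counting units.

```python
import math

def compute_fans(shape, in_axis=-2, out_axis=-1, batch_axis=()):
    """Hypothetical helper: return (fan_in, fan_out) for a weights tensor shape."""
    def as_tuple(axis):
        return (axis,) if isinstance(axis, int) else tuple(axis)

    in_size = math.prod(shape[a] for a in as_tuple(in_axis))
    out_size = math.prod(shape[a] for a in as_tuple(out_axis))
    batch_size = math.prod(shape[a] for a in as_tuple(batch_axis))
    # Axes not named in any argument are treated as the receptive field
    # (assumption: its size scales both fan_in and fan_out).
    receptive_field = math.prod(shape) // (in_size * out_size * batch_size)
    return in_size * receptive_field, out_size * receptive_field

def variance_scaling_stddev(scale, mode, fan_in, fan_out):
    """Hypothetical helper: stddev = sqrt(scale / n) for the given mode."""
    n = {
        "fan_in": fan_in,
        "fan_out": fan_out,
        "fan_avg": (fan_in + fan_out) / 2,
    }[mode]
    return math.sqrt(scale / n)

# Dense layer weights laid out as (inputs, outputs).
dense_fans = compute_fans((256, 128))

# 3x3 convolution kernel laid out as (height, width, in_channels, out_channels).
conv_fans = compute_fans((3, 3, 16, 32))
conv_stddev = variance_scaling_stddev(2.0, "fan_in", *conv_fans)
```

For the convolution kernel, the two leading spatial axes are not listed in in_axis, out_axis, or batch_axis, so they count as the receptive field.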
With distribution="truncated_normal", the samples are truncated so that their absolute values lie below 2 standard deviations, where the standard deviation is that of the distribution before truncation.
With distribution="uniform", samples are drawn from:
- a uniform interval, if dtype is real
- a uniform disk, if dtype is complex

with a mean of zero and a standard deviation of stddev.
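For a real dtype, a zero-mean uniform interval [-a, a] has a standard deviation of a / sqrt(3), so a target of stddev implies a = sqrt(3) * stddev. The following is a minimal sketch that checks this numerically; the values of `scale` and `n` are illustrative assumptions, and the sampling is done directly with `jax.random.uniform` rather than through the initializer itself.

```python
import jax
import jax.numpy as jnp

scale, n = 2.0, 512              # illustrative values, chosen for this sketch
stddev = (scale / n) ** 0.5      # target standard deviation, sqrt(scale / n)
limit = (3.0 ** 0.5) * stddev    # U(-a, a) has std a / sqrt(3), so a = sqrt(3) * stddev

key = jax.random.PRNGKey(0)
samples = jax.random.uniform(key, (100_000,), minval=-limit, maxval=limit)

empirical_mean = float(jnp.mean(samples))
empirical_std = float(jnp.std(samples))
```

With this many samples, `empirical_std` should land close to `stddev` and `empirical_mean` close to zero.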
- Parameters
scale – scaling factor (positive float).
mode – one of "fan_in", "fan_out", and "fan_avg".
distribution – random distribution to use. One of "truncated_normal", "normal" and "uniform".
in_axis – axis or sequence of axes of the input dimension in the weights tensor.
out_axis – axis or sequence of axes of the output dimension in the weights tensor.
batch_axis – axis or sequence of axes in the weights tensor that should be ignored.
dtype – the dtype of the weights.
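A short usage sketch follows. The signature above exposes only in_axis, out_axis, batch_axis, and dtype, so this example assumes that he_uniform fixes the remaining settings to scale=2.0, mode="fan_in", and distribution="uniform", as in the usual He (Kaiming) uniform scheme. The layer shape is an arbitrary choice for illustration.

```python
import math

import jax
import jax.numpy as jnp
from jax.nn.initializers import he_uniform

key = jax.random.PRNGKey(0)

# Build the initializer, then call it with a PRNG key, a shape, and a dtype.
init = he_uniform()
w = init(key, (256, 128), jnp.float32)   # dense weights: 256 inputs, 128 outputs

# Assumed settings: scale=2.0 and mode="fan_in", so n = 256 here.
expected_std = math.sqrt(2.0 / 256)
limit = math.sqrt(3.0) * expected_std    # half-width of the uniform interval

max_abs = float(jnp.max(jnp.abs(w)))
sample_std = float(jnp.std(w))
```

Under the stated assumptions, every entry of `w` falls within the interval [-limit, limit] and `sample_std` is close to `expected_std`.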