def xavier_truncated_normal(shape: Tuple[int, ...],
                            lower: float = -2,
                            upper: float = 2,
                            gain: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using Xavier Glorot truncated normal initializer from
    `Understanding the difficulty of training deep feedforward neural networks
    <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.

    The requested standard deviation is ``gain * xavier_normal_gain(shape)``. Sampling, together
    with the correction for the change in spread caused by truncation, is delegated to
    ``truncated_normal`` in this module so that the rescaling logic lives in a single place.

    Args:
        shape: shape of the output tensor.
        lower: lower truncation of the normal.
        upper: upper truncation of the normal.
        gain: optional scaling factor.

    Returns:
        Tensor of the given shape produced by ``truncated_normal`` with requested standard
        deviation (gain * xavier_normal_gain) and truncation bounds ``lower`` and ``upper``.
    """
    # Same arithmetic as the former inline version: (gain * xavier gain) / truncated_std,
    # where the division happens inside truncated_normal.
    return truncated_normal(shape, lower=lower, upper=upper, stddev=gain * xavier_normal_gain(shape))
def kaiming_truncated_normal(shape: Tuple[int, ...],
                             lower: float = -2,
                             upper: float = 2,
                             gain: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using Kaiming He truncated normal initializer from
    `Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
    <https://arxiv.org/abs/1502.01852>`_.

    The requested standard deviation is ``gain * kaiming_normal_gain(shape)``. Sampling, together
    with the correction for the change in spread caused by truncation, is delegated to
    ``truncated_normal`` in this module so that the rescaling logic lives in a single place.

    Args:
        shape: shape of the output tensor.
        lower: lower truncation of the normal.
        upper: upper truncation of the normal.
        gain: optional scaling factor.

    Returns:
        Tensor of the given shape produced by ``truncated_normal`` with requested standard
        deviation (gain * kaiming_normal_gain) and truncation bounds ``lower`` and ``upper``.
    """
    # Same arithmetic as the former inline version: (gain * kaiming gain) / truncated_std,
    # where the division happens inside truncated_normal.
    return truncated_normal(shape, lower=lower, upper=upper, stddev=gain * kaiming_normal_gain(shape))
def truncated_normal(shape: Tuple[int, ...],
                     lower: float = -2,
                     upper: float = 2,
                     stddev: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using truncated normal initialization.

    Truncating a normal distribution changes its spread, so the requested ``stddev`` is divided
    by the standard deviation of a unit normal truncated to ``[lower, upper]`` before it is
    handed to the sampler.

    Args:
        shape: shape of the output tensor.
        lower: lower truncation of the normal.
        upper: upper truncation of the normal.
        stddev: expected standard deviation.

    Returns:
        Tensor of the given shape drawn by ``random.truncated_normal`` using the corrected
        standard deviation and the truncation bounds ``lower`` and ``upper``.
    """
    # Standard deviation of N(0, 1) restricted to [lower, upper].
    unit_truncated_std = scipy.stats.truncnorm.std(a=lower, b=upper, loc=0., scale=1)
    # NOTE(review): the correction assumes random.truncated_normal samples a unit normal
    # truncated to [lower, upper] and then scales by `stddev` -- confirm against that helper.
    corrected_stddev = stddev / unit_truncated_std
    return random.truncated_normal(shape, stddev=corrected_stddev, lower=lower, upper=upper)