Example #1
def xavier_truncated_normal(shape: Tuple[int, ...],
                            lower: float = -2,
                            upper: float = 2,
                            gain: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using Xavier Glorot truncated normal initializer from
    `Understanding the difficulty of training deep feedforward neural networks
    <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.

    Args:
        shape: shape of the output tensor.
        lower: lower truncation of the normal.
        upper: upper truncation of the normal.
        gain: optional scaling factor.

    Returns:
        Tensor initialized with truncated normal random variables with standard
        deviation (gain * xavier_normal_gain) and support [lower, upper].
    """
    truncated_std = scipy.stats.truncnorm.std(a=lower,
                                              b=upper,
                                              loc=0.,
                                              scale=1)
    stddev = gain * xavier_normal_gain(shape) / truncated_std
    return random.truncated_normal(shape,
                                   stddev=stddev,
                                   lower=lower,
                                   upper=upper)
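
A minimal usage sketch for the initializer above. It assumes the snippet comes from objax (the objax.nn.init module), so the import path and the example shape are illustrative, not part of the original code.

# Usage sketch; assumes the initializer is exposed as objax.nn.init.xavier_truncated_normal.
from objax.nn import init

# Weight matrix for a 784 -> 256 dense layer; the Xavier gain is derived from the shape.
w = init.xavier_truncated_normal((784, 256), gain=1.0)
print(w.shape)  # (784, 256)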
Example #2
def kaiming_truncated_normal(shape: Tuple[int, ...],
                             lower: float = -2,
                             upper: float = 2,
                             gain: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using Kaiming He truncated normal initializer from
    `Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
    <https://arxiv.org/abs/1502.01852>`_.

    Args:
        shape: shape of the output tensor.
        lower: lower truncation of the normal.
        upper: upper truncation of the normal.
        gain: optional scaling factor.

    Returns:
        Tensor initialized with truncated normal random variables with standard
        deviation (gain * kaiming_normal_gain) and support [lower, upper].
    """
    truncated_std = scipy.stats.truncnorm.std(a=lower,
                                              b=upper,
                                              loc=0.,
                                              scale=1)
    stddev = gain * kaiming_normal_gain(shape) / truncated_std
    return random.truncated_normal(shape,
                                   stddev=stddev,
                                   lower=lower,
                                   upper=upper)
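
The Kaiming He variant is called the same way; a sketch under the same objax.nn.init assumption, using a convolution-style kernel shape where fan-in is conventionally taken from all but the last axis.

# Usage sketch; assumes the initializer is exposed as objax.nn.init.kaiming_truncated_normal.
from objax.nn import init

# 3x3 convolution kernel mapping 64 input channels to 128 output channels (shape is illustrative).
k = init.kaiming_truncated_normal((3, 3, 64, 128), gain=1.0)
print(k.shape)  # (3, 3, 64, 128)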
Example #3
def truncated_normal(shape: Tuple[int, ...],
                     lower: float = -2,
                     upper: float = 2,
                     stddev: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using truncated normal initialization.

    Args:
        shape: shape of the output tensor.
        lower: lower truncation of the normal.
        upper: upper truncation of the normal.
        stddev: expected standard deviation.

    Returns:
        Tensor initialized with truncated normal random variables with standard
        deviation stddev and support [lower, upper].
    """
    truncated_std = scipy.stats.truncnorm.std(a=lower,
                                              b=upper,
                                              loc=0.,
                                              scale=1)
    stddev /= truncated_std
    return random.truncated_normal(shape,
                                   stddev=stddev,
                                   lower=lower,
                                   upper=upper)
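
A quick check of the variance correction performed by the division above; a sketch that again assumes the function is available as objax.nn.init.truncated_normal.

# Sketch: the empirical standard deviation should match the requested one.
import numpy as np
from objax.nn import init

x = np.asarray(init.truncated_normal((1_000_000,), stddev=0.5))
# Truncating at the default [-2, 2] alone would shrink the sample std below 0.5 (to roughly 0.88 * 0.5);
# dividing by scipy.stats.truncnorm.std compensates, so the measured std comes out close to 0.5.
print(float(np.std(x)))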