Example #1
  def test_bias_selu_norm(self):
    rng = random.PRNGKey(0)
    key1, key2 = random.split(rng)
    x = random.normal(key1, (100000, 128))
    y, _ = activations.BiasSELUNorm.create(
        key2,
        x,
        features=128,
        bias_init=jax.nn.initializers.normal(stddev=.5),
        scale_init=normal(mean=new_initializers.inv_softplus(1), stddev=.5),
    )
    mean = jnp.mean(y, axis=0)
    std = jnp.std(y, axis=0)

    onp.testing.assert_allclose(mean, jnp.zeros_like(mean), atol=1e-1)
    onp.testing.assert_allclose(std, jnp.ones_like(mean), atol=1e-1)
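For reference, new_initializers.inv_softplus used throughout these snippets is the inverse of softplus. A minimal standalone sketch consistent with how it is used here (choosing a pre-softplus scale so the effective scale starts at 1); the function body is an assumption, not the project's own implementation:

import numpy as onp

def inv_softplus(y):
  # Inverse of softplus(x) = log(1 + exp(x)); defined for y > 0.
  return onp.log(onp.expm1(y))

# softplus(inv_softplus(1.0)) == 1.0, so softplus-parameterized scales start at 1.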
 def scale_prior():
   if config.scale_prior == 'normal':
     if config.activation_f in [
         'bias_scale_relu_norm', 'bias_scale_SELU_norm'
     ]:
       if config.softplus_scale:
         mean = new_initializers.inv_softplus(1.0)
       else:
         mean = 1.0
       return functools.partial(
           new_regularizers.normal_prior_regularizer,
           scale=config.scale_prior_scale / train_size,
           mean=mean)
     else:
       return functools.partial(
           new_regularizers.normal_prior_regularizer,
           scale=config.scale_prior_scale / train_size,
           mean=1.0)
   elif config.scale_prior == 'none':
     return lambda x: 0.
   else:
     raise ValueError('Invalid value "%s" for config.scale_prior' %
                      config.scale_prior)
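A hedged usage sketch of the factory above: scale_prior() returns a callable mapping a parameter array to a scalar penalty (for scale_prior == 'none' it is lambda x: 0.). The per-feature parameter array below is an illustrative assumption; config and train_size come from the surrounding scope as in the code above.

import jax.numpy as jnp

scale_regularizer = scale_prior()
# Hypothetical per-feature (pre-softplus) scale parameters of one layer.
scale_params = jnp.full((128,), new_initializers.inv_softplus(1.0))
penalty = scale_regularizer(scale_params)  # scalar added to the training loss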
    def apply(self,
              inputs,
              features,
              bias=True,
              scale=False,
              dtype=jnp.float32,
              precision=None,
              bias_init=initializers.zeros,
              scale_init=None,
              softplus=True):
        if scale_init is None:
            if softplus:
                scale_init = new_initializers.init_softplus_ones
            else:
                scale_init = initializers.ones
        if bias:
            bias = self.param('bias', (features, ), bias_init)
            bias = jnp.asarray(bias, dtype)
        else:
            bias = 0.

        if scale:
            scale = self.param('scale', (features, ), scale_init)
            scale = jnp.asarray(scale, dtype)
        else:
            scale = float(
                new_initializers.inv_softplus(1.0)) if softplus else 1.0

        if softplus:
            scale = nn.softplus(scale)

        # Affine transform followed by a thresholded linear unit (ReLU at 0).
        relu_threshold = 0.0
        y = inputs * scale + bias
        y = jnp.maximum(relu_threshold, y)

        # Normalize y analytically: for (approximately) standard-normal
        # inputs, the pre-activation is N(bias, scale^2).
        mean = bias
        std = scale
        var = std**2
        # Kaiming-initialized weights + bias + TLU give a mixture of a delta
        # peak at the threshold and a left-truncated Gaussian.
        # https://en.wikipedia.org/wiki/Mixture_distribution#Moments
        # https://en.wikipedia.org/wiki/Truncated_normal_distribution#One_sided_truncation_(of_lower_tail)
        norm = jax.scipy.stats.norm
        t = (relu_threshold - mean) / std

        # If the mean lies more than 4 std below the threshold, cap t at 4 so
        # that z = 1 - cdf(t) stays bounded away from zero.
        t = jnp.minimum(4, t)
        z = 1 - norm.cdf(t)

        new_mean_non_cut = mean + (std * norm.pdf(t)) / z
        new_var_non_cut = (var) * (1 + t * norm.pdf(t) / z -
                                   (norm.pdf(t) / z)**2)

        # Psi function.
        # Compute mixture mean.
        new_mean = new_mean_non_cut * z + relu_threshold * norm.cdf(t)
        # Compute mixture variance.
        new_var = z * (new_var_non_cut + new_mean_non_cut**2 - new_mean**2)
        new_var += (1 - z) * (0 + relu_threshold**2 - new_mean**2)
        new_std = jnp.sqrt(new_var + 1e-8)
        new_std = jnp.maximum(0.01, new_std)

        # Normalize y.
        y_norm = (y - new_mean) / new_std
        return y_norm
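A minimal standalone sketch (Python/JAX only) that checks the analytic ReLU moments used above against Monte Carlo estimates; the helper name relu_moments and the test values for bias and scale are illustrative assumptions.

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def relu_moments(mean, std, threshold=0.0):
  # max(threshold, N(mean, std^2)) is a mixture of a delta peak at the
  # threshold (weight cdf(t)) and a left-truncated Gaussian (weight 1 - cdf(t)).
  var = std ** 2
  t = (threshold - mean) / std
  z = 1 - norm.cdf(t)
  m_trunc = mean + std * norm.pdf(t) / z
  v_trunc = var * (1 + t * norm.pdf(t) / z - (norm.pdf(t) / z) ** 2)
  m = z * m_trunc + (1 - z) * threshold
  v = z * (v_trunc + m_trunc ** 2 - m ** 2) + (1 - z) * (threshold ** 2 - m ** 2)
  return m, jnp.sqrt(v)

bias, scale = 0.3, 1.2
x = jax.random.normal(jax.random.PRNGKey(0), (1_000_000,))
y = jnp.maximum(0.0, scale * x + bias)
m_analytic, s_analytic = relu_moments(bias, scale)
print(m_analytic, jnp.mean(y))  # analytic vs. Monte Carlo mean, ~equal
print(s_analytic, jnp.std(y))   # analytic vs. Monte Carlo std, ~equal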
    def apply(self,
              inputs,
              features,
              bias=True,
              scale=True,
              dtype=jnp.float32,
              precision=None,
              bias_init=initializers.zeros,
              scale_init=None,
              softplus=True,
              norm_grad_block=False):
        if scale_init is None:
            if softplus:
                scale_init = new_initializers.init_softplus_ones
            else:
                scale_init = initializers.ones

        norm = jax.scipy.stats.norm
        erf = jax.scipy.special.erf  # Error function.

        if bias:
            bias = self.param('bias', (features, ), bias_init)
            bias = jnp.asarray(bias, dtype)
        else:
            bias = 0.

        if scale:
            scale = self.param('scale', (features, ), scale_init)
            scale = jnp.asarray(scale, dtype)
        else:
            scale = float(
                new_initializers.inv_softplus(1.0)) if softplus else 1.0

        if softplus:
            scale = nn.softplus(scale)

        # Affine transform followed by SELU.
        pre = inputs * scale + bias
        y = jax.nn.selu(pre)

        # Compute moments based on the learned scale/bias.
        if norm_grad_block:
            scale = jax.lax.stop_gradient(scale)
            bias = jax.lax.stop_gradient(bias)
        std = scale
        mean = bias
        var = std**2

        # SELU magic numbers from the SELU paper [2] and jax.nn.selu.
        alpha = 1.6732632423543772848170429916717
        selu_scale = 1.0507009873554804934193349852946
        selu_threshold = 0

        # Compute moments of the left and right sides of the split Gaussian
        # (x <= 0 and x > 0).
        t = (selu_threshold - mean) / std
        # Clip t to [-3, 3] so that z = 1 - cdf(t) stays bounded away from
        # zero and one.
        t = jnp.maximum(-3, jnp.minimum(3, t))
        z = 1 - norm.cdf(t)
        new_mean_right = (mean + (std * norm.pdf(t)) / z)
        new_var_right = (var) * (1 + t * norm.pdf(t) / z -
                                 (norm.pdf(t) / z)**2)

        l_scale = jnp.exp(mean)  # Log normal scale parameter = exp(mean)
        log_scale = mean
        min_log = -5

        # Compute truncated log normal statistics for left part of SELU.
        # TODO(basv): improve numerical errors with np.expm1?
        a1 = .5 * (1. /
                   (std + 1e-5)) * jnp.sqrt(2) * (-var + min_log - log_scale)
        a2 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (var + log_scale -
                                                       selu_threshold)
        a3 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (min_log - log_scale)
        a4 = .5 * (1. /
                   (std + 1e-5)) * jnp.sqrt(2) * (-selu_threshold + log_scale)
        a5 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (-2 * var + min_log -
                                                       log_scale)
        a6 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (2 * var + log_scale -
                                                       selu_threshold)
        e_a1 = erf(a1)
        e_a2 = erf(a2)
        e_a3 = erf(a3)
        e_a4 = erf(a4)
        e_a5 = erf(a5)
        e_a6 = erf(a6)
        exp_var = jnp.exp(var)

        # Equation 18 [1].
        trunc_lognorm_mean = (l_scale * jnp.exp(.5 * var) *
                              (e_a1 + e_a2)) / (e_a3 + e_a4 + 1e-5)
        trunc_lognorm_mean_m1 = trunc_lognorm_mean - 1  # selu uses e^x - 1
        # Equation 20 [1].
        n = exp_var * (e_a3 * e_a5 * exp_var + e_a3 * e_a6 * exp_var +
                       e_a4 * e_a5 * exp_var + e_a4 * e_a6 * exp_var -
                       e_a1**2 - 2 * e_a1 * e_a2 - e_a2**2) * l_scale**2
        # Equation 19 [1].
        trunc_lognorm_var = n / ((e_a3 + e_a4 + 1e-5)**2)

        selu_mean = alpha * trunc_lognorm_mean_m1
        selu_var = alpha**2 * trunc_lognorm_var

        # Compute mixture mean (selu_scale is applied below).
        new_mean = (selu_mean * (1 - z) + new_mean_right * z)

        # Compute mixture variance.
        new_var = z * (new_var_right + new_mean_right**2 - new_mean**2)
        new_var += (1 - z) * (selu_var + selu_mean**2 - new_mean**2)
        new_mean = selu_scale * new_mean
        new_std = jnp.sqrt(new_var + 1e-5) * selu_scale
        new_var *= selu_scale**2

        if norm_grad_block:
            new_mean = jax.lax.stop_gradient(new_mean)
            new_std = jax.lax.stop_gradient(new_std)

        new_std = jnp.maximum(1e-3, new_std)

        # Normalize y.
        y_norm = (y - new_mean) / new_std
        return y_norm
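A hedged usage sketch mirroring the test at the top of this section: assuming the SELU apply() above is the body of activations.BiasSELUNorm and that create() forwards keyword arguments to apply() (as the test does with bias_init and scale_init), the normalized output should be roughly standardized per feature at initialization.

import jax.numpy as jnp
from jax import random

rng = random.PRNGKey(42)
key1, key2 = random.split(rng)
x = random.normal(key1, (100000, 128))
# norm_grad_block=True stops gradients through the analytic moments.
y, _ = activations.BiasSELUNorm.create(
    key2, x, features=128, norm_grad_block=True)
print(jnp.mean(y, axis=0))  # approximately 0 per feature
print(jnp.std(y, axis=0))   # approximately 1 per feature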