def test_bias_selu_norm(self):
  rng = random.PRNGKey(0)
  key1, key2 = random.split(rng)
  x = random.normal(key1, (100000, 128))
  y, _ = activations.BiasSELUNorm.create(
      key2,
      x,
      features=128,
      bias_init=jax.nn.initializers.normal(stddev=.5),
      scale_init=normal(mean=new_initializers.inv_softplus(1), stddev=.5),
  )
  mean = jnp.mean(y, axis=0)
  std = jnp.std(y, axis=0)
  onp.testing.assert_allclose(mean, jnp.zeros_like(mean), atol=1e-1)
  onp.testing.assert_allclose(std, jnp.ones_like(mean), atol=1e-1)
def scale_prior():
  if config.scale_prior == 'normal':
    if config.activation_f in [
        'bias_scale_relu_norm', 'bias_scale_SELU_norm'
    ]:
      if config.softplus_scale:
        mean = new_initializers.inv_softplus(1.0)
      else:
        mean = 1.0
      return functools.partial(
          new_regularizers.normal_prior_regularizer,
          scale=config.scale_prior_scale / train_size,
          mean=mean)
    else:
      return functools.partial(
          new_regularizers.normal_prior_regularizer,
          scale=config.scale_prior_scale / train_size,
          mean=1.0)
  elif config.scale_prior == 'none':
    return lambda x: 0.
  else:
    raise ValueError('Invalid value "%s" for config.scale_prior' %
                     config.scale_prior)
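# Hedged usage sketch (not from the source): scale_prior() returns a
# one-argument callable that maps a scale-parameter array to a scalar penalty.
# The snippet below mirrors the 'normal' + softplus_scale branch with a
# stand-in regularizer; `dummy_normal_prior_regularizer`, `dummy_config`, and
# `dummy_train_size` are hypothetical, and the (param, scale, mean) signature
# of `new_regularizers.normal_prior_regularizer` is an assumption.
import functools
import types

import jax.numpy as jnp


def dummy_normal_prior_regularizer(param, scale=1.0, mean=0.0):
  # Quadratic penalty pulling `param` towards `mean`, weighted by `scale`.
  return scale * jnp.sum((param - mean)**2)


dummy_config = types.SimpleNamespace(
    scale_prior='normal',
    activation_f='bias_scale_SELU_norm',
    softplus_scale=True,
    scale_prior_scale=1.0)
dummy_train_size = 50000

reg_fn = functools.partial(
    dummy_normal_prior_regularizer,
    scale=dummy_config.scale_prior_scale / dummy_train_size,
    mean=float(jnp.log(jnp.expm1(1.0))))  # inv_softplus(1.0) ~= 0.5413.

print(reg_fn(jnp.ones((128,))))  # Penalty on an all-ones scale parameter.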
def apply(self,
          inputs,
          features,
          bias=True,
          scale=False,
          dtype=jnp.float32,
          precision=None,
          bias_init=initializers.zeros,
          scale_init=None,
          softplus=True):
  if scale_init is None:
    if softplus:
      scale_init = new_initializers.init_softplus_ones
    else:
      scale_init = initializers.ones
  if bias:
    bias = self.param('bias', (features,), bias_init)
    bias = jnp.asarray(bias, dtype)
  else:
    bias = 0.
  if scale:
    scale = self.param('scale', (features,), scale_init)
    scale = jnp.asarray(scale, dtype)
  else:
    scale = float(new_initializers.inv_softplus(1.0)) if softplus else 1.0
  if softplus:
    scale = nn.softplus(scale)

  y = inputs
  y *= scale
  y = y + bias
  relu_threshold = 0.0
  y = jnp.maximum(relu_threshold, y)

  # Normalize y analytically.
  mean = bias
  std = scale
  var = std**2
  # Kaiming initialized weights + bias + TLU
  # = mixture of delta peak + left-truncated gaussian.
  # https://en.wikipedia.org/wiki/Mixture_distribution#Moments
  # https://en.wikipedia.org/wiki/Truncated_normal_distribution#One_sided_truncation_(of_lower_tail)
  norm = jax.scipy.stats.norm
  t = (relu_threshold - mean) / std
  # If the distribution lies 4 stdev below the threshold, cap at t=4.
  t = jnp.minimum(4, t)
  z = 1 - norm.cdf(t)
  new_mean_non_cut = mean + (std * norm.pdf(t)) / z
  new_var_non_cut = var * (1 + t * norm.pdf(t) / z - (norm.pdf(t) / z)**2)

  # Psi function.
  # Compute mixture mean.
  new_mean = new_mean_non_cut * z + relu_threshold * norm.cdf(t)
  # Compute mixture variance.
  new_var = z * (new_var_non_cut + new_mean_non_cut**2 - new_mean**2)
  new_var += (1 - z) * (0 + relu_threshold**2 - new_mean**2)
  new_std = jnp.sqrt(new_var + 1e-8)
  new_std = jnp.maximum(0.01, new_std)

  # Normalize y.
  y_norm = y
  y_norm -= new_mean
  y_norm /= new_std
  return y_norm
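# Hedged sanity check (not from the source): the analytic mixture moments used
# above can be compared against Monte-Carlo estimates of relu(scale * x + bias)
# with x ~ N(0, 1). All names below are local to this sketch.
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

bias, scale = 0.3, 1.2
x = jax.random.normal(jax.random.PRNGKey(0), (1_000_000,))
y = jnp.maximum(0., scale * x + bias)

# Mixture of a delta peak at 0 (weight cdf(t)) and a left-truncated Gaussian
# (weight z = 1 - cdf(t)), with t = (0 - bias) / scale.
t = (0. - bias) / scale
z = 1 - norm.cdf(t)
mean_trunc = bias + scale * norm.pdf(t) / z
var_trunc = scale**2 * (1 + t * norm.pdf(t) / z - (norm.pdf(t) / z)**2)
mean_mix = z * mean_trunc  # The delta peak at 0 contributes nothing.
var_mix = z * (var_trunc + mean_trunc**2) - mean_mix**2

print(jnp.mean(y), mean_mix)  # Both approximately 0.64.
print(jnp.var(y), var_mix)    # Both approximately 0.64.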
def apply(self,
          inputs,
          features,
          bias=True,
          scale=True,
          dtype=jnp.float32,
          precision=None,
          bias_init=initializers.zeros,
          scale_init=None,
          softplus=True,
          norm_grad_block=False):
  if scale_init is None:
    if softplus:
      scale_init = new_initializers.init_softplus_ones
    else:
      scale_init = initializers.ones
  norm = jax.scipy.stats.norm
  erf = jax.scipy.special.erf  # Error function.
  if bias:
    bias = self.param('bias', (features,), bias_init)
    bias = jnp.asarray(bias, dtype)
  else:
    bias = 0.
  if scale:
    scale = self.param('scale', (features,), scale_init)
    scale = jnp.asarray(scale, dtype)
  else:
    scale = float(new_initializers.inv_softplus(1.0)) if softplus else 1.0
  if softplus:
    scale = nn.softplus(scale)

  pre = inputs
  pre *= scale
  pre = pre + bias
  y = jax.nn.selu(pre)

  # Compute moments based on the learned scale/bias.
  if norm_grad_block:
    scale = jax.lax.stop_gradient(scale)
    bias = jax.lax.stop_gradient(bias)
  std = scale
  mean = bias
  var = std**2

  # SELU magic numbers from the SELU paper [2] and jax.nn.selu.
  alpha = 1.6732632423543772848170429916717
  selu_scale = 1.0507009873554804934193349852946
  selu_threshold = 0

  # Compute moments of the left and right side of the split Gaussian,
  # i.e. for x <= 0 and x > 0.
  t = (selu_threshold - mean) / std
  # Clip t to [-3, 3] to avoid numerical issues in the tails.
  t = jnp.maximum(-3, jnp.minimum(3, t))
  z = 1 - norm.cdf(t)
  new_mean_right = mean + (std * norm.pdf(t)) / z
  new_var_right = var * (1 + t * norm.pdf(t) / z - (norm.pdf(t) / z)**2)

  l_scale = jnp.exp(mean)  # Log-normal scale parameter = exp(mean).
  log_scale = mean
  min_log = -5
  # Compute truncated log-normal statistics for the left part of the SELU.
  # TODO(basv): improve numerical errors with np.expm1?
  a1 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (-var + min_log - log_scale)
  a2 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (
      var + log_scale - selu_threshold)
  a3 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (min_log - log_scale)
  a4 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (-selu_threshold + log_scale)
  a5 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (
      -2 * var + min_log - log_scale)
  a6 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (
      2 * var + log_scale - selu_threshold)
  e_a1 = erf(a1)
  e_a2 = erf(a2)
  e_a3 = erf(a3)
  e_a4 = erf(a4)
  e_a5 = erf(a5)
  e_a6 = erf(a6)
  exp_var = jnp.exp(var)

  # Equation 18 [1].
  trunc_lognorm_mean = (l_scale * jnp.exp(.5 * var) *
                        (e_a1 + e_a2)) / (e_a3 + e_a4 + 1e-5)
  trunc_lognorm_mean_m1 = trunc_lognorm_mean - 1  # SELU uses e^x - 1.
  # Equation 20 [1].
  n = exp_var * (e_a3 * e_a5 * exp_var + e_a3 * e_a6 * exp_var +
                 e_a4 * e_a5 * exp_var + e_a4 * e_a6 * exp_var - e_a1**2 -
                 2 * e_a1 * e_a2 - e_a2**2) * l_scale**2
  # Equation 19 [1].
  trunc_lognorm_var = n / ((e_a3 + e_a4 + 1e-5)**2)

  selu_mean = alpha * trunc_lognorm_mean_m1
  selu_var = alpha**2 * trunc_lognorm_var

  # Compute mixture mean (multiplied by selu_scale below).
  new_mean = selu_mean * (1 - z) + new_mean_right * z
  # Compute mixture variance.
  new_var = z * (new_var_right + new_mean_right**2 - new_mean**2)
  new_var += (1 - z) * (selu_var + selu_mean**2 - new_mean**2)
  new_mean = selu_scale * new_mean
  new_std = jnp.sqrt(new_var + 1e-5) * selu_scale
  new_var *= selu_scale**2

  if norm_grad_block:
    new_mean = jax.lax.stop_gradient(new_mean)
    new_std = jax.lax.stop_gradient(new_std)
  new_std = jnp.maximum(1e-3, new_std)

  # Normalize y.
  y_norm = y
  y_norm -= new_mean
  y_norm /= new_std
  return y_norm
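# Hedged sketch (not from the source): the analytic new_mean / new_std above
# are meant to track the moments of selu(scale * x + bias) for x ~ N(0, 1).
# A Monte-Carlo reference for those moments, with hypothetical bias/scale:
import jax
import jax.numpy as jnp

bias, scale = 0.2, 0.9
x = jax.random.normal(jax.random.PRNGKey(0), (1_000_000,))
y = jax.nn.selu(scale * x + bias)

# With an accurate approximation, (y - new_mean) / new_std has roughly zero
# mean and unit variance; the empirical targets are:
print(jnp.mean(y), jnp.std(y))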