예제 #1
0
    def call(self, inputs, **kwargs):
        """Run the dense transformation with Flipout weight perturbations.

        The kernel posterior is parameterised by ``kernel_mu`` and
        ``kernel_rho`` with ``sigma = softplus(rho)``. One perturbation
        ``delta = sigma * eps`` (``eps`` from ``K.random_normal``) is drawn
        per call. Sign tensors sampled from ``rademacher`` (presumably +/-1;
        verify) are applied on the input side and on the output side of
        ``delta``.

        Args:
            inputs: Input tensor; its last axis is contracted with the kernel.
            **kwargs: Accepted for Keras ``call`` compatibility; unused here.

        Returns:
            ``self.activation(inputs . kernel_mu + flipout_term + bias)``.
            Output is stochastic on every call, including at inference.
        """
        # Kernel posterior sample: delta = softplus(rho) * N(0, 1) noise.
        kernel_sigma = K.softplus(self.kernel_rho)
        kernel_perturb = kernel_sigma * K.random_normal(self.kernel_mu.shape)
        sampled_kernel = self.kernel_mu + kernel_perturb

        # The bias is sampled only when a bias posterior is configured;
        # otherwise the point estimate bias_mu is used as-is.
        use_bias_posterior = self.bias_distribution
        if not use_bias_posterior:
            bias = self.bias_mu
        else:
            bias_sigma = K.softplus(self.bias_rho)
            bias_eps = K.random_normal(self.bias_mu.shape)
            bias = self.bias_mu + bias_sigma * bias_eps

        # Complexity penalty from the sampled parameters (kl_loss takes
        # sample, mu, sigma). K.in_train_phase substitutes 0.0 outside the
        # training phase.
        kl_term = self.kl_loss(sampled_kernel, self.kernel_mu, kernel_sigma)
        if use_bias_posterior:
            kl_term += self.kl_loss(bias, self.bias_mu, bias_sigma)
        self.add_loss(K.in_train_phase(kl_term, 0.0))

        # Flipout sign flips: the input-side signs match the input's shape.
        # The output-side signs use the input's leading dims plus `units`.
        in_shape = K.shape(inputs)
        in_signs = rademacher.sample(in_shape)
        out_signs = rademacher.sample(
            K.concatenate([in_shape[:-1], K.expand_dims(self.units, 0)],
                          axis=0))
        flip_term = K.dot(inputs * in_signs, kernel_perturb) * out_signs

        mean_term = K.dot(inputs, self.kernel_mu)
        return self.activation(mean_term + flip_term + bias)
예제 #2
0
    def call(self, inputs, **kwargs):
        """Apply a dense layer whose kernel and bias are sampled per call.

        Each parameter is drawn as ``mu + softplus(rho) * eps`` with ``eps``
        from ``K.random_normal``. The KL penalty of both samples is
        registered with ``add_loss`` and zeroed outside the training phase
        via ``K.in_train_phase``.

        Args:
            inputs: Input tensor contracted with the sampled kernel.
            **kwargs: Accepted for Keras ``call`` compatibility; unused here.

        Returns:
            ``self.activation(inputs . sampled_kernel + sampled_bias)``.
            Output is stochastic on every call.
        """
        mu_w, mu_b = self.kernel_mu, self.bias_mu

        # Kernel is drawn first, then bias (same draw order as before).
        sigma_w = K.softplus(self.kernel_rho)
        sampled_w = mu_w + sigma_w * K.random_normal(mu_w.shape)

        sigma_b = K.softplus(self.bias_rho)
        sampled_b = mu_b + sigma_b * K.random_normal(mu_b.shape)

        kernel_kl = self.kl_loss(sampled_w, mu_w, sigma_w)
        bias_kl = self.kl_loss(sampled_b, mu_b, sigma_b)
        self.add_loss(K.in_train_phase(kernel_kl + bias_kl, 0.0))

        pre_activation = K.dot(inputs, sampled_w) + sampled_b
        return self.activation(pre_activation)
예제 #3
0
    def call(self, inputs):
        """Average softmax probabilities over Gaussian samples of the logits.

        Args:
            inputs: A list or tuple ``(logit_mean, logit_var)``.
                ``logit_var`` is turned into a standard deviation by
                ``self.preprocess_variance_input``. The way ``logit_shape``
                is built below assumes ``logit_mean`` is (batch, classes);
                confirm against callers.

        Returns:
            Class probabilities averaged over ``self.num_samples`` draws.
            They are divided by ``self.temperature`` before the softmax and
            renormalised so each row sums to one.

        Raises:
            ValueError: If ``inputs`` is not a list/tuple of exactly two
                elements.
        """
        # Explicit check instead of `assert`: asserts vanish under
        # `python -O`. A bare tensor whose leading dimension is 2 would also
        # pass `len(inputs) == 2` and be unpacked along its batch axis.
        if not isinstance(inputs, (list, tuple)) or len(inputs) != 2:
            raise ValueError(
                "This layer requires exactly two inputs (mean and variance logits)")

        logit_mean, logit_var = inputs
        logit_std = self.preprocess_variance_input(logit_var)
        logit_shape = (K.shape(logit_mean)[0], self.num_samples,
                       K.shape(logit_mean)[-1])

        # Insert a samples axis at position 1 and tile mean/std along it.
        logit_mean = K.expand_dims(logit_mean, axis=1)
        logit_mean = K.repeat_elements(logit_mean, self.num_samples, axis=1)

        logit_std = K.expand_dims(logit_std, axis=1)
        logit_std = K.repeat_elements(logit_std, self.num_samples, axis=1)

        logit_samples = K.random_normal(logit_shape,
                                        mean=logit_mean,
                                        stddev=logit_std)

        # Apply max normalization for numerical stability
        logit_samples = logit_samples - K.max(
            logit_samples, axis=-1, keepdims=True)

        # Apply temperature scaling to logits
        logit_samples = logit_samples / self.temperature

        prob_samples = K.softmax(logit_samples, axis=-1)
        probs = K.mean(prob_samples, axis=1)

        # Renormalise to absorb floating-point error; without it the
        # averaged probabilities can sum to e.g. 1.01 or 0.99.
        probs = probs / K.sum(probs, axis=-1, keepdims=True)

        return probs
예제 #4
0
 def sample_perturbation(self):
     """Draw a zero-mean Gaussian perturbation of shape ``self.shape``.

     ``self.std`` is passed as the standard deviation.
     """
     # Use a scalar mean rather than K.zeros(self.shape). The Keras backend
     # documents K.zeros as instantiating a variable, so the old form
     # allocated a new variable on every draw, which can also fail inside
     # tf.function. A scalar 0.0 broadcasts to the same result.
     return K.random_normal(self.shape, mean=0.0, stddev=self.std)
예제 #5
0
 def sample(self):
     """Draw one Gaussian sample of shape ``self.shape``.

     ``self.mean`` is the mean and ``self.std`` the standard deviation.
     """
     shape, mu, sigma = self.shape, self.mean, self.std
     return K.random_normal(shape, mean=mu, stddev=sigma)