Example #1
        def dropped_inputs():
            """Forward pass with dropout."""
            # Clip magnitude of dropout rate, where we get the dropout rate alpha from
            # the additive parameterization (Molchanov et al., 2017): for weight ~
            # Normal(mu, sigma**2), the variance `sigma**2 = alpha * mu**2`.
            mean = self.kernel.distribution.mean()
            log_variance = tf.math.log(self.kernel.distribution.variance())
            log_alpha = log_variance - tf.math.log(
                tf.square(mean) + tf.keras.backend.epsilon())
            log_alpha = tf.clip_by_value(log_alpha, -8., 8.)
            log_variance = log_alpha + tf.math.log(
                tf.square(mean) + tf.keras.backend.epsilon())

            if inputs.shape.ndims <= 2:
                means = tf.matmul(inputs, mean)
                stddevs = tf.sqrt(
                    tf.matmul(tf.square(inputs), tf.exp(log_variance)) +
                    tf.keras.backend.epsilon())
            else:
                means = tf.tensordot(inputs, mean, [[-1], [0]])
                stddevs = tf.sqrt(
                    tf.tensordot(tf.square(inputs), tf.exp(log_variance),
                                 [[-1], [0]]) + tf.keras.backend.epsilon())
            if self.use_bias:
                means = tf.nn.bias_add(means, self.bias)
            outputs = generated_random_variables.Normal(loc=means,
                                                        scale=stddevs)
            if self.activation is not None:
                outputs = self.activation(outputs)
            return outputs
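
Example #1 samples a dense layer's pre-activations rather than its weights: under the additive parameterization (Molchanov et al., 2017), sigma**2 = alpha * mu**2, so clipping log_alpha bounds the effective dropout rate, and outputs are drawn from Normal(x @ mu, sqrt(x**2 @ sigma**2)) in the style of the local reparameterization trick. A minimal self-contained sketch of the same computation in plain TensorFlow/TFP, with made-up shapes and names:

import tensorflow as tf
import tensorflow_probability as tfp

# Hypothetical variational parameters for a 3 -> 2 dense layer.
mu = tf.Variable(tf.random.normal([3, 2]))        # weight means
log_variance = tf.Variable(tf.fill([3, 2], -6.))  # weight log-variances
x = tf.random.normal([5, 3])                      # a batch of inputs

# log_alpha = log(sigma**2) - log(mu**2); clip it, then rebuild log(sigma**2).
log_alpha = log_variance - tf.math.log(tf.square(mu) + 1e-7)
log_alpha = tf.clip_by_value(log_alpha, -8., 8.)
clipped_log_variance = log_alpha + tf.math.log(tf.square(mu) + 1e-7)

# Sample pre-activations directly from their induced normal distribution.
means = tf.matmul(x, mu)
stddevs = tf.sqrt(tf.matmul(tf.square(x), tf.exp(clipped_log_variance)) + 1e-7)
outputs = tfp.distributions.Normal(loc=means, scale=stddevs).sample()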
Example #2
    def dropped_inputs():
      """Forward pass with dropout."""
      # Clip magnitude of dropout rate, where we get the dropout rate alpha from
      # the additive parameterization (Molchanov et al., 2017): for weight ~
      # Normal(mu, sigma**2), the variance `sigma**2 = alpha * mu**2`.
      mean = self.kernel.distribution.mean()
      log_variance = tf.math.log(self.kernel.distribution.variance())
      log_alpha = log_variance - tf.math.log(tf.square(mean) +
                                             tf.keras.backend.epsilon())
      log_alpha = tf.clip_by_value(log_alpha, -8., 8.)
      log_variance = log_alpha + tf.math.log(tf.square(mean) +
                                             tf.keras.backend.epsilon())

      means = self._convolution_op(inputs, mean)
      stddevs = tf.sqrt(
          self._convolution_op(tf.square(inputs), tf.exp(log_variance)) +
          tf.keras.backend.epsilon())
      if self.use_bias:
        if self.data_format == 'channels_first':
          means = tf.nn.bias_add(means, self.bias, data_format='NCHW')
        else:
          means = tf.nn.bias_add(means, self.bias, data_format='NHWC')
      outputs = generated_random_variables.Normal(loc=means, scale=stddevs)
      if self.activation is not None:
        outputs = self.activation(outputs)
      return outputs
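
Example #2 is the convolutional analogue of the same trick: convolving the squared inputs with the weight variances gives the pre-activation variances. A minimal plain-TF sketch with made-up shapes (the log-alpha clipping shown above is elided):

import tensorflow as tf
import tensorflow_probability as tfp

# Hypothetical 3x3 kernel mapping 1 -> 4 channels, NHWC inputs.
mu = tf.Variable(0.1 * tf.random.normal([3, 3, 1, 4]))
log_variance = tf.Variable(tf.fill([3, 3, 1, 4], -6.))
x = tf.random.normal([2, 8, 8, 1])

conv = lambda t, k: tf.nn.conv2d(t, k, strides=1, padding='SAME')
means = conv(x, mu)
stddevs = tf.sqrt(conv(tf.square(x), tf.exp(log_variance)) + 1e-7)
outputs = tfp.distributions.Normal(loc=means, scale=stddevs).sample()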
Example #3
 def __call__(self, shape, dtype=None):
   """Builds trainable parameters if needed, then returns the kernel as a normal random variable."""
   if not self.built:
     self.build(shape, dtype)
   return generated_random_variables.Independent(
       generated_random_variables.Normal(
           loc=self.mean, scale=self.stddev).distribution,
       reinterpreted_batch_ndims=len(shape))
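
The Independent wrapper is what lets the initializer return a single random variable for the whole kernel: reinterpreted_batch_ndims=len(shape) folds every batch dimension of the elementwise Normal into one event. A small illustration in plain TFP with a made-up shape:

import tensorflow as tf
import tensorflow_probability as tfp

base = tfp.distributions.Normal(loc=tf.zeros([3, 2]), scale=0.1)
dist = tfp.distributions.Independent(base, reinterpreted_batch_ndims=2)
print(base.batch_shape, base.event_shape)  # (3, 2) ()
print(dist.batch_shape, dist.event_shape)  # () (3, 2)
kernel = dist.sample()                     # one draw of the full [3, 2] kernel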
Example #4
 def __call__(self, x):
     """Computes regularization given an ed.Normal random variable as input."""
     if not isinstance(x, random_variable.RandomVariable):
         raise ValueError('Input must be an ed.RandomVariable.')
     prior = generated_random_variables.Independent(
         generated_random_variables.Normal(loc=x.distribution.mean(),
                                           scale=self.stddev).distribution,
         reinterpreted_batch_ndims=len(x.distribution.event_shape))
     regularization = x.distribution.kl_divergence(prior.distribution)
     return self.scale_factor * regularization
Example #5
 def call(self, inputs):
   """Computes regularization given an input ed.RandomVariable."""
   if not isinstance(inputs, random_variable.RandomVariable):
     raise ValueError('Input must be an ed.RandomVariable.')
   stddev = self.stddev
   if self.stddev_constraint:
     stddev = self.stddev_constraint(stddev)
   prior = generated_random_variables.Independent(
       generated_random_variables.Normal(
           loc=self.mean, scale=stddev).distribution,
       reinterpreted_batch_ndims=len(inputs.distribution.event_shape))
   regularization = inputs.distribution.kl_divergence(prior.distribution)
   return self.scale_factor * regularization
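
Examples #4 and #5 compute the same penalty: the KL divergence from the variational posterior over a weight tensor to a fixed normal prior, with Independent ensuring both sides share an event shape so the KL reduces to a scalar. A plain-TFP sketch with made-up shapes:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
posterior = tfd.Independent(
    tfd.Normal(loc=tf.random.normal([3, 2]), scale=0.05),
    reinterpreted_batch_ndims=2)
prior = tfd.Independent(
    tfd.Normal(loc=tf.zeros([3, 2]), scale=1.),
    reinterpreted_batch_ndims=2)
# Scalar penalty; scale_factor would rescale it (e.g., by 1/dataset_size).
penalty = posterior.kl_divergence(prior)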
Example #6
 def call(self, inputs):
     """Adds a KL penalty on the perturbed half of the batch and returns the original half."""
     if not isinstance(inputs, random_variable.RandomVariable):
         # Default to a unit normal, i.e., derived from mean squared error loss.
         inputs = generated_random_variables.Normal(loc=inputs, scale=1.)
     batch_size = tf.shape(inputs)[0] // 2
     # TODO(trandustin): Depend on github's ed2 for indexing RVs. This is a hack.
     # _, _ = inputs[:batch_size], inputs[batch_size:]
     original_inputs = random_variable.RandomVariable(
         inputs.distribution[:batch_size], value=inputs.value[:batch_size])
     perturbed_inputs = random_variable.RandomVariable(
         inputs.distribution[batch_size:], value=inputs.value[batch_size:])
     loss = tf.reduce_sum(
         tfp.distributions.Normal(self.mean, self.stddev).kl_divergence(
             perturbed_inputs.distribution))
     loss /= tf.cast(batch_size, dtype=tf.float32)
     self.add_loss(loss)
     return original_inputs
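
This call assumes the batch stacks original inputs in its first half and perturbed copies in its second; the added loss is the batch-averaged KL from a fixed output prior to the predictive distribution on the perturbed half, in the style of noise contrastive priors. A toy sketch of just the loss term, with made-up values:

import tensorflow as tf
import tensorflow_probability as tfp

batch_size = 4
perturbed = tfp.distributions.Normal(loc=tf.random.normal([batch_size]), scale=1.)
output_prior = tfp.distributions.Normal(0., 0.1)  # plays the role of self.mean, self.stddev
loss = tf.reduce_sum(output_prior.kl_divergence(perturbed)) / batch_size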
Example #7
 def call(self, inputs):
     """Computes the predictive Normal p(mean(ynew) | xnew) for the mean response."""
     if self.coeffs_mean is None and self.coeffs_precision_tril_op is None:
         # p(mean(ynew) | xnew) = Normal(ynew | mean = 0, variance = xnew xnew^T)
         predictive_mean = 0.
         predictive_variance = tf.reduce_sum(tf.square(inputs), axis=-1)
     else:
         # p(mean(ynew) | xnew, x, y) = Normal(ynew |
         #   mean = xnew (1/noise_variance) (1/noise_variance x^T x + I)^{-1}x^T y,
         #   variance = xnew (1/noise_variance x^T x + I)^{-1} xnew^T)
         predictive_mean = tf.einsum('nm,m->n', inputs, self.coeffs_mean)
         predictive_covariance = tf.matmul(
             inputs,
             self.coeffs_precision_tril_op.solve(
                 self.coeffs_precision_tril_op.solve(inputs,
                                                     adjoint_arg=True),
                 adjoint=True))
         predictive_variance = tf.linalg.tensor_diag_part(
             predictive_covariance)
     return generated_random_variables.Normal(
         loc=predictive_mean, scale=tf.sqrt(predictive_variance))
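
The second branch of Example #7 implements the formulas in its comments via the Cholesky factor of the posterior precision (1/noise_variance) x^T x + I, so both the coefficient mean and the predictive covariance come out of triangular solves. A worked sketch with made-up data, using plain tf.linalg in place of the LinearOperator:

import tensorflow as tf

n, d = 5, 3
noise_variance = 0.1
x = tf.random.normal([n, d])
y = tf.random.normal([n])
x_new = tf.random.normal([2, d])

# Posterior precision and its Cholesky factor.
precision = tf.matmul(x, x, transpose_a=True) / noise_variance + tf.eye(d)
chol = tf.linalg.cholesky(precision)

# coeffs_mean = precision^{-1} x^T y / noise_variance.
xty = tf.linalg.matvec(x, y, transpose_a=True) / noise_variance
coeffs_mean = tf.linalg.cholesky_solve(chol, xty[:, None])[:, 0]

# Predictive mean and standard deviation at the new points.
predictive_mean = tf.linalg.matvec(x_new, coeffs_mean)
covariance = tf.matmul(x_new, tf.linalg.cholesky_solve(chol, tf.transpose(x_new)))
predictive_stddev = tf.sqrt(tf.linalg.tensor_diag_part(covariance))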