def dropped_inputs():
  """Forward pass with dropout."""
  # Clip magnitude of dropout rate, where we get the dropout rate alpha from
  # the additive parameterization (Molchanov et al., 2017): for weight ~
  # Normal(mu, sigma**2), the variance `sigma**2 = alpha * mu**2`.
  mean = self.kernel.distribution.mean()
  log_variance = tf.math.log(self.kernel.distribution.variance())
  log_alpha = log_variance - tf.math.log(
      tf.square(mean) + tf.keras.backend.epsilon())
  log_alpha = tf.clip_by_value(log_alpha, -8., 8.)
  log_variance = log_alpha + tf.math.log(
      tf.square(mean) + tf.keras.backend.epsilon())

  if inputs.shape.ndims <= 2:
    means = tf.matmul(inputs, mean)
    stddevs = tf.sqrt(
        tf.matmul(tf.square(inputs), tf.exp(log_variance)) +
        tf.keras.backend.epsilon())
  else:
    means = tf.tensordot(inputs, mean, [[-1], [0]])
    stddevs = tf.sqrt(
        tf.tensordot(tf.square(inputs), tf.exp(log_variance), [[-1], [0]]) +
        tf.keras.backend.epsilon())
  if self.use_bias:
    means = tf.nn.bias_add(means, self.bias)
  outputs = generated_random_variables.Normal(loc=means, scale=stddevs)
  if self.activation is not None:
    outputs = self.activation(outputs)
  return outputs
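# Hedged, self-contained sketch (not part of the layer): the forward pass
# above uses the local reparameterization trick, sampling pre-activations
# from Normal(x mu, x**2 sigma**2) rather than sampling weights directly.
# `W_mean` and `W_var` below are illustrative stand-ins for the kernel
# posterior's mean and variance, not attributes of the layer.
import tensorflow as tf

x = tf.random.normal([32, 64])                      # batch of 32 examples
W_mean = tf.random.normal([64, 10])                 # posterior mean
W_var = tf.nn.softplus(tf.random.normal([64, 10]))  # posterior variance

pre_act_mean = tf.matmul(x, W_mean)
pre_act_std = tf.sqrt(tf.matmul(tf.square(x), W_var) +
                      tf.keras.backend.epsilon())
pre_act = pre_act_mean + pre_act_std * tf.random.normal(tf.shape(pre_act_mean))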
def dropped_inputs():
  """Forward pass with dropout."""
  # Clip magnitude of dropout rate, where we get the dropout rate alpha from
  # the additive parameterization (Molchanov et al., 2017): for weight ~
  # Normal(mu, sigma**2), the variance `sigma**2 = alpha * mu**2`.
  mean = self.kernel.distribution.mean()
  log_variance = tf.math.log(self.kernel.distribution.variance())
  log_alpha = log_variance - tf.math.log(
      tf.square(mean) + tf.keras.backend.epsilon())
  log_alpha = tf.clip_by_value(log_alpha, -8., 8.)
  log_variance = log_alpha + tf.math.log(
      tf.square(mean) + tf.keras.backend.epsilon())

  means = self._convolution_op(inputs, mean)
  stddevs = tf.sqrt(
      self._convolution_op(tf.square(inputs), tf.exp(log_variance)) +
      tf.keras.backend.epsilon())
  if self.use_bias:
    if self.data_format == 'channels_first':
      means = tf.nn.bias_add(means, self.bias, data_format='NCHW')
    else:
      means = tf.nn.bias_add(means, self.bias, data_format='NHWC')
  outputs = generated_random_variables.Normal(loc=means, scale=stddevs)
  if self.activation is not None:
    outputs = self.activation(outputs)
  return outputs
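# Hedged sketch of the convolutional case (an assumption for illustration, not
# the layer's API): the same trick applies, convolving the inputs with the
# kernel posterior's mean and the squared inputs with its variance. All names
# below are made up for the example.
import tensorflow as tf

images = tf.random.normal([8, 28, 28, 3])                     # NHWC batch
kernel_mean = tf.random.normal([3, 3, 3, 16])                 # posterior mean
kernel_var = tf.nn.softplus(tf.random.normal([3, 3, 3, 16]))  # posterior variance

means = tf.nn.conv2d(images, kernel_mean, strides=1, padding='SAME')
stddevs = tf.sqrt(
    tf.nn.conv2d(tf.square(images), kernel_var, strides=1, padding='SAME') +
    tf.keras.backend.epsilon())
sampled_outputs = means + stddevs * tf.random.normal(tf.shape(means))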
def __call__(self, shape, dtype=None):
  if not self.built:
    self.build(shape, dtype)
  return generated_random_variables.Independent(
      generated_random_variables.Normal(
          loc=self.mean, scale=self.stddev).distribution,
      reinterpreted_batch_ndims=len(shape))
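# Hedged sketch of what the initializer returns: an Independent Normal over
# the full weight shape, so a layer built with it gets a kernel that is itself
# a random variable whose distribution can later be queried (e.g.
# kernel.distribution.mean()). A rough TFP-only equivalent, with illustrative
# names rather than the initializer's real attributes:
import tensorflow as tf
import tensorflow_probability as tfp

shape = (64, 10)
mean = tf.Variable(tf.zeros(shape))
stddev = tf.Variable(0.1 * tf.ones(shape))
kernel_dist = tfp.distributions.Independent(
    tfp.distributions.Normal(loc=mean, scale=stddev),
    reinterpreted_batch_ndims=len(shape))
kernel_sample = kernel_dist.sample()  # one draw used as the kernel's value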
def __call__(self, x):
  """Computes regularization given an ed.Normal random variable as input."""
  if not isinstance(x, random_variable.RandomVariable):
    raise ValueError('Input must be an ed.RandomVariable.')
  prior = generated_random_variables.Independent(
      generated_random_variables.Normal(
          loc=x.distribution.mean(), scale=self.stddev).distribution,
      reinterpreted_batch_ndims=len(x.distribution.event_shape))
  regularization = x.distribution.kl_divergence(prior.distribution)
  return self.scale_factor * regularization
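# Hedged worked example: because the prior's mean is tied to the posterior's
# mean, this KL penalizes only the scales (how far the posterior variance is
# from stddev**2), not where the weights are centered. The names below are
# illustrative.
import tensorflow_probability as tfp

posterior = tfp.distributions.Normal(loc=[0.3, -1.2], scale=[0.5, 2.0])
prior = tfp.distributions.Normal(loc=posterior.mean(), scale=1.0)
kl = tfp.distributions.kl_divergence(posterior, prior)  # depends on scales only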
def call(self, inputs):
  """Computes regularization given an input ed.RandomVariable."""
  if not isinstance(inputs, random_variable.RandomVariable):
    raise ValueError('Input must be an ed.RandomVariable.')
  stddev = self.stddev
  if self.stddev_constraint:
    stddev = self.stddev_constraint(stddev)
  prior = generated_random_variables.Independent(
      generated_random_variables.Normal(
          loc=self.mean, scale=stddev).distribution,
      reinterpreted_batch_ndims=len(inputs.distribution.event_shape))
  regularization = inputs.distribution.kl_divergence(prior.distribution)
  return self.scale_factor * regularization
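# Hedged standalone sketch of the quantity this regularizer contributes to the
# loss: scale_factor * KL(q(w) || Normal(mean, stddev)), reduced over the
# weight's event shape. The shapes and scale_factor below are made up for the
# example.
import tensorflow as tf
import tensorflow_probability as tfp

q = tfp.distributions.Independent(
    tfp.distributions.Normal(loc=tf.zeros([64, 10]), scale=0.1),
    reinterpreted_batch_ndims=2)
p = tfp.distributions.Independent(
    tfp.distributions.Normal(loc=tf.zeros([64, 10]), scale=1.0),
    reinterpreted_batch_ndims=2)
scale_factor = 1. / 50000.  # e.g., 1 / dataset size for a per-example ELBO
penalty = scale_factor * q.kl_divergence(p)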
def call(self, inputs):
  if not isinstance(inputs, random_variable.RandomVariable):
    # Default to a unit normal, i.e., derived from mean squared error loss.
    inputs = generated_random_variables.Normal(loc=inputs, scale=1.)
  batch_size = tf.shape(inputs)[0] // 2
  # TODO(trandustin): Depend on github's ed2 for indexing RVs. This is a hack.
  # _, _ = inputs[:batch_size], inputs[batch_size:]
  original_inputs = random_variable.RandomVariable(
      inputs.distribution[:batch_size],
      value=inputs.value[:batch_size])
  perturbed_inputs = random_variable.RandomVariable(
      inputs.distribution[batch_size:],
      value=inputs.value[batch_size:])
  loss = tf.reduce_sum(
      tfp.distributions.Normal(self.mean, self.stddev).kl_divergence(
          perturbed_inputs.distribution))
  loss /= tf.cast(batch_size, dtype=tf.float32)
  self.add_loss(loss)
  return original_inputs
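# Hedged sketch of the intended data flow (an assumption based on noise
# contrastive priors, Hafner et al., 2018): the incoming batch is expected to
# be doubled, with the second half being perturbed copies of the first. The
# layer returns predictions for the original half and adds a KL penalty that
# pulls the predictions on the perturbed half toward the output prior
# Normal(mean, stddev). The preprocessing below is illustrative only.
import tensorflow as tf

x = tf.random.normal([128, 4])
x_perturbed = x + 0.1 * tf.random.normal(tf.shape(x))  # input-prior noise
doubled_batch = tf.concat([x, x_perturbed], axis=0)    # fed through the model
# The model's output layer (above) then keeps only the first 128 predictions.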
def call(self, inputs):
  if self.coeffs_mean is None and self.coeffs_precision_tril_op is None:
    # p(mean(ynew) | xnew) = Normal(ynew | mean = 0, variance = xnew xnew^T)
    predictive_mean = 0.
    predictive_variance = tf.reduce_sum(tf.square(inputs), axis=-1)
  else:
    # p(mean(ynew) | xnew, x, y) = Normal(ynew |
    #   mean = xnew (1/noise_variance) (1/noise_variance x^T x + I)^{-1} x^T y,
    #   variance = xnew (1/noise_variance x^T x + I)^{-1} xnew^T)
    predictive_mean = tf.einsum('nm,m->n', inputs, self.coeffs_mean)
    predictive_covariance = tf.matmul(
        inputs,
        self.coeffs_precision_tril_op.solve(
            self.coeffs_precision_tril_op.solve(inputs, adjoint_arg=True),
            adjoint=True))
    predictive_variance = tf.linalg.tensor_diag_part(predictive_covariance)
  return generated_random_variables.Normal(
      loc=predictive_mean, scale=tf.sqrt(predictive_variance))
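# Hedged worked example of the posterior predictive used above. With Cholesky
# factor L such that L L^T = (1/noise_variance) x^T x + I, the nested
# triangular solves apply (L L^T)^{-1}, so the predictive covariance is
# xnew (L L^T)^{-1} xnew^T. All names below are illustrative, not the layer's
# attributes.
import tensorflow as tf

noise_variance = 0.5
X = tf.random.normal([100, 3])
y = tf.random.normal([100])
precision = tf.matmul(X, X, transpose_a=True) / noise_variance + tf.eye(3)
precision_tril_op = tf.linalg.LinearOperatorLowerTriangular(
    tf.linalg.cholesky(precision))

xty = tf.matmul(X, y[:, None], transpose_a=True) / noise_variance
coeffs_mean = precision_tril_op.solve(
    precision_tril_op.solve(xty), adjoint=True)[:, 0]

x_new = tf.random.normal([5, 3])
predictive_mean = tf.einsum('nm,m->n', x_new, coeffs_mean)
predictive_covariance = tf.matmul(
    x_new,
    precision_tril_op.solve(
        precision_tril_op.solve(x_new, adjoint_arg=True), adjoint=True))
predictive_variance = tf.linalg.tensor_diag_part(predictive_covariance)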