Example #1
def call(self, inputs):
    if not isinstance(inputs, random_variable.RandomVariable):
        # Default to a unit normal, i.e., derived from mean squared error loss.
        inputs = generated_random_variables.Normal(loc=inputs, scale=1.)
    # The batch is assumed to be the original inputs stacked with their
    # perturbed copies along axis 0, so each half has batch_size rows.
    batch_size = tf.shape(inputs)[0] // 2
    # TODO(trandustin): Depend on github's ed2 for indexing RVs. This is a hack.
    # _, _ = inputs[:batch_size], inputs[batch_size:]
    original_inputs = random_variable.RandomVariable(
        inputs.distribution[:batch_size], value=inputs.value[:batch_size])
    perturbed_inputs = random_variable.RandomVariable(
        inputs.distribution[batch_size:], value=inputs.value[batch_size:])
    # Penalize the perturbed half with a KL term against the output prior
    # Normal(self.mean, self.stddev), averaged over the batch.
    loss = tf.reduce_sum(
        tfp.distributions.Normal(self.mean, self.stddev).kl_divergence(
            perturbed_inputs.distribution))
    loss /= tf.cast(batch_size, dtype=tf.float32)
    self.add_loss(loss)
    # Only the unperturbed half continues through the model.
    return original_inputs
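The method above is only a fragment: it expects an enclosing Keras layer that defines self.mean and self.stddev and receives the original and perturbed batches concatenated along axis 0. Below is a minimal sketch of such a wrapper; the class name, constructor signature, defaults, and import paths are assumptions for illustration, not the exact edward2 API.

import tensorflow as tf
import tensorflow_probability as tfp
from edward2.tensorflow import generated_random_variables
from edward2.tensorflow import random_variable


class NCPNormalOutput(tf.keras.layers.Layer):
    """Sketch of an output layer adding a KL penalty on the perturbed half.

    `mean` and `stddev` (assumed constructor arguments) parameterize the
    output prior Normal(mean, stddev) used in the KL term of `call` above.
    """

    def __init__(self, mean=0., stddev=1., **kwargs):
        super().__init__(**kwargs)
        self.mean = mean
        self.stddev = stddev

    # call(self, inputs) is exactly the method shown in Example #1.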
Example #2
def TransformedRandomVariable(rv,  # pylint: disable=invalid-name
                              reversible_layer,
                              name=None,
                              sample_shape=(),
                              value=None):
    """Random variable for f(x), where x ~ p(x) and f is reversible."""
    return random_variable.RandomVariable(
        distribution=TransformedDistribution(rv.distribution,
                                             reversible_layer,
                                             name=name),
        sample_shape=sample_shape,
        value=value)
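TransformedRandomVariable wraps the base random variable's distribution in a TransformedDistribution driven by a reversible layer while keeping the ed2 RandomVariable interface (.distribution, .value). The snippet below is a toy usage sketch; the reversible-layer interface it assumes (a forward call plus reverse and log_det_jacobian methods) and the ed.Normal alias are assumptions for illustration, not a guaranteed part of the library's API.

import edward2 as ed
import tensorflow as tf


class Shift(tf.keras.layers.Layer):
    """Toy reversible layer implementing y = x + 2."""

    def call(self, x):
        return x + 2.

    def reverse(self, y):
        return y - 2.

    def log_det_jacobian(self, y):
        # An elementwise shift has unit Jacobian, so log|det J| = 0.
        return tf.zeros_like(y)


base = ed.Normal(loc=tf.zeros(3), scale=1.)          # x ~ N(0, 1)
shifted = TransformedRandomVariable(base, Shift())   # f(x) = x + 2
sample = shifted.value                                # a draw of f(x)
log_prob = shifted.distribution.log_prob(sample)      # density of f(x) at that draw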