Example #1
 def testLogProb(self, event_shape, event_dims, training):
     with self.test_session() as sess:
         training = tf.placeholder_with_default(training, (), "training")
         layer = normalization.BatchNormalization(axis=event_dims,
                                                  epsilon=0.)
         batch_norm = tfb.BatchNormalization(batchnorm_layer=layer,
                                             training=training)
         base_dist = distributions.MultivariateNormalDiag(
             loc=np.zeros(np.prod(event_shape), dtype=np.float32))
         # Reshape the events.
         if isinstance(event_shape, int):
             event_shape = [event_shape]
         base_dist = transformed_distribution_lib.TransformedDistribution(
             distribution=base_dist,
             bijector=tfb.Reshape(event_shape_out=event_shape))
         dist = transformed_distribution_lib.TransformedDistribution(
             distribution=base_dist,
             bijector=batch_norm,
             validate_args=True)
         samples = dist.sample(int(1e5))
         # No volume distortion since, with training=False, the bijector is
         # initialized to the identity transformation.
         base_log_prob = base_dist.log_prob(samples)
         dist_log_prob = dist.log_prob(samples)
         tf.global_variables_initializer().run()
         base_log_prob_, dist_log_prob_ = sess.run(
             [base_log_prob, dist_log_prob])
         self.assertAllClose(base_log_prob_, dist_log_prob_)
 def testMaximumLikelihoodTraining(self):
     # Test maximum-likelihood training with the default bijector.
     with self.cached_session() as sess:
         base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
         batch_norm = BatchNormalization(training=True)
         dist = transformed_distribution_lib.TransformedDistribution(
             distribution=base_dist, bijector=batch_norm)
         target_dist = distributions.MultivariateNormalDiag(loc=[1., 2.])
         target_samples = target_dist.sample(100)
         dist_samples = dist.sample(3000)
         loss = -math_ops.reduce_mean(dist.log_prob(target_samples))
         with ops.control_dependencies(batch_norm.batchnorm.updates):
             train_op = adam.AdamOptimizer(1e-2).minimize(loss)
             moving_mean = array_ops.identity(
                 batch_norm.batchnorm.moving_mean)
             moving_var = array_ops.identity(
                 batch_norm.batchnorm.moving_variance)
         variables.global_variables_initializer().run()
         for _ in range(3000):
             sess.run(train_op)
         [dist_samples_, moving_mean_,
          moving_var_] = sess.run([dist_samples, moving_mean, moving_var])
         self.assertAllClose([1., 2.],
                             np.mean(dist_samples_, axis=0),
                             atol=5e-2)
         self.assertAllClose([1., 2.], moving_mean_, atol=5e-2)
         self.assertAllClose([1., 1.], moving_var_, atol=5e-2)
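
Two details in testMaximumLikelihoodTraining are easy to miss: the loss is the
negative mean log-likelihood of the target samples under the transformed
distribution, and the optimizer step is wrapped in a control dependency on
batch_norm.batchnorm.updates so the layer's moving statistics are refreshed on
every step. A minimal sketch of that update pattern in isolation, using the
public TF1-style API instead of the internal modules above (the layer, data,
and loss here are illustrative stand-ins):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x = tf.random.normal([64, 2])                # stand-in training batch
layer = tf.keras.layers.BatchNormalization()
y = layer(x, training=True)
loss = tf.reduce_mean(tf.square(y))
# Without this control dependency the moving mean/variance never update, and
# inference-mode (training=False) behavior stays at its initialization.
with tf.control_dependencies(layer.updates):
  train_op = tf.train.AdamOptimizer(1e-2).minimize(loss)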
Example #3
 def testCompareToBijector(self):
     """Demonstrates equivalence between TD, Bijector approach and AR dist."""
     sample_shape = np.int32([4, 5])
     batch_shape = np.int32([])
     event_size = np.int32(2)
     with self.cached_session() as sess:
         batch_event_shape = np.concatenate([batch_shape, [event_size]],
                                            axis=0)
         sample0 = array_ops.zeros(batch_event_shape)
         affine = Affine(scale_tril=self._random_scale_tril(event_size))
         ar = autoregressive_lib.Autoregressive(self._normal_fn(affine),
                                                sample0,
                                                validate_args=True)
         ar_flow = MaskedAutoregressiveFlow(
             is_constant_jacobian=True,
             shift_and_log_scale_fn=lambda x: [None, affine.forward(x)],
             validate_args=True)
         td = transformed_distribution_lib.TransformedDistribution(
             distribution=normal_lib.Normal(loc=0., scale=1.),
             bijector=ar_flow,
             event_shape=[event_size],
             batch_shape=batch_shape,
             validate_args=True)
         x_shape = np.concatenate([sample_shape, batch_shape, [event_size]],
                                  axis=0)
         x = 2. * self._rng.random_sample(x_shape).astype(np.float32) - 1.
         td_log_prob_, ar_log_prob_ = sess.run(
             [td.log_prob(x), ar.log_prob(x)])
         self.assertAllClose(td_log_prob_, ar_log_prob_, atol=0., rtol=1e-6)
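
The equivalence above leans on ar_flow being a proper bijector: its inverse
must undo its forward map. That can be checked directly, reusing ar_flow and x
from the test (a sketch, not part of the original):

y = ar_flow.forward(x)
x_recovered = ar_flow.inverse(y)
# sess.run(x_recovered) should reproduce x up to floating-point error.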
Example #4
 def testDocstringExample(self):
   with self.test_session():
     exp_gamma_distribution = (
         transformed_distribution_lib.TransformedDistribution(
             distribution=gamma_lib.Gamma(concentration=1., rate=2.),
             bijector=tfb.Invert(tfb.Exp())))
     self.assertAllEqual([], tf.shape(exp_gamma_distribution.sample()).eval())
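
tfb.Invert(tfb.Exp()) is the log transform, so exp_gamma_distribution is the
distribution of log(X) for X ~ Gamma(concentration=1., rate=2.). Under the
change of variables y = log(x), log p_Y(y) = log p_X(exp(y)) + y, which is
exactly what TransformedDistribution.log_prob computes here. A sketch of that
identity, reusing the aliases above (not part of the original test):

y = tf.constant([-1., 0., 1.])
gamma = gamma_lib.Gamma(concentration=1., rate=2.)
manual_log_prob = gamma.log_prob(tf.exp(y)) + y
# manual_log_prob should match exp_gamma_distribution.log_prob(y) up to
# floating-point error.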
Example #5
 def testMutuallyConsistent(self):
     dims = 4
     with self.cached_session() as sess:
         ma = MaskedAutoregressiveFlow(validate_args=True,
                                       **self._autoregressive_flow_kwargs)
         dist = transformed_distribution_lib.TransformedDistribution(
             distribution=normal_lib.Normal(loc=0., scale=1.),
             bijector=ma,
             event_shape=[dims],
             validate_args=True)
         self.run_test_sample_consistent_log_prob(sess_run_fn=sess.run,
                                                  dist=dist,
                                                  num_samples=int(1e5),
                                                  radius=1.,
                                                  center=0.,
                                                  rtol=0.02)
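
run_test_sample_consistent_log_prob (a tfp test helper) checks that sample and
log_prob agree by comparing two Monte Carlo estimates of the probability of a
ball of the given radius around center: one from the empirical fraction of
samples landing in the ball, one derived from log_prob. The sample-side
estimate looks roughly like this (a simplified sketch, not the actual helper):

import numpy as np

def empirical_ball_prob(samples, center, radius):
  # Fraction of samples inside the L2 ball of the given radius -- the
  # "sample" half of the consistency check.
  distances = np.linalg.norm(samples - center, axis=-1)
  return np.mean(distances <= radius)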
Example #6
 def testMutuallyConsistent(self):
   dims = 4
   with self.test_session() as sess:
     nvp = tfb.RealNVP(
         num_masked=3, validate_args=True, **self._real_nvp_kwargs)
     dist = transformed_distribution_lib.TransformedDistribution(
         distribution=tf.distributions.Normal(loc=0., scale=1.),
         bijector=nvp,
         event_shape=[dims],
         validate_args=True)
     self.run_test_sample_consistent_log_prob(
         sess_run_fn=sess.run,
         dist=dist,
         num_samples=int(1e5),
         radius=1.,
         center=0.,
         rtol=0.02)
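
This is the same consistency check as Example #5, with RealNVP in place of the
masked autoregressive flow. In tfp's RealNVP, num_masked=3 with dims=4 means
the first three coordinates pass through unchanged and parameterize the
shift/log-scale applied to the remaining coordinate, which is what makes the
coupling invertible. A sketch of that masking behavior, reusing nvp from the
test (not part of the original):

x = tf.ones([1, 4])
y = nvp.forward(x)
# y[..., :3] equals x[..., :3]; only the last coordinate is transformed.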
Example #7
 def testLogProb(self):
   with self.test_session() as sess:
     layer = normalization.BatchNormalization(epsilon=0.)
     batch_norm = BatchNormalization(batchnorm_layer=layer, training=False)
     base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
     dist = transformed_distribution_lib.TransformedDistribution(
         distribution=base_dist,
         bijector=batch_norm,
         validate_args=True)
     samples = dist.sample(int(1e5))
      # No volume distortion since, with training=False, the bijector is
      # initialized to the identity transformation.
     base_log_prob = base_dist.log_prob(samples)
     dist_log_prob = dist.log_prob(samples)
     variables.global_variables_initializer().run()
     base_log_prob_, dist_log_prob_ = sess.run([base_log_prob, dist_log_prob])
     self.assertAllClose(base_log_prob_, dist_log_prob_)
 def testInvertMutuallyConsistent(self):
     # BatchNorm bijector is only mutually consistent when training=False.
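     # (In training mode the bijector normalizes with the current batch's
     # statistics, so sample and log_prob would see different transforms.)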
     dims = 4
     with self.cached_session() as sess:
         layer = normalization.BatchNormalization(epsilon=0.)
         batch_norm = Invert(
             BatchNormalization(batchnorm_layer=layer, training=False))
         dist = transformed_distribution_lib.TransformedDistribution(
             distribution=normal_lib.Normal(loc=0., scale=1.),
             bijector=batch_norm,
             event_shape=[dims],
             validate_args=True)
         self.run_test_sample_consistent_log_prob(sess_run_fn=sess.run,
                                                  dist=dist,
                                                  num_samples=int(1e5),
                                                  radius=2.,
                                                  center=0.,
                                                  rtol=0.02)
Example #9
def quadrature_scheme_lognormal_quantiles(loc,
                                          scale,
                                          quadrature_size,
                                          validate_args=False,
                                          name=None):
    """Use LogNormal quantiles to form quadrature on positive-reals.

  Args:
    loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
      the LogNormal prior.
    scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
      the LogNormal prior.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    grid: (Batch of) length-`quadrature_size` vectors representing the
      `log_rate` parameters of a `Poisson`.
    probs: (Batch of) length-`quadrature_size` vectors representing the
      weight associate with each `grid` value.
  """
    with ops.name_scope(name, "quadrature_scheme_lognormal_quantiles",
                        [loc, scale]):
        # Create a LogNormal distribution.
        dist = transformed_lib.TransformedDistribution(
            distribution=normal_lib.Normal(loc=loc, scale=scale),
            bijector=Exp(),
            validate_args=validate_args)
        batch_ndims = dist.batch_shape.ndims
        if batch_ndims is None:
            batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0]

        def _compute_quantiles():
            """Helper to build quantiles."""
            # Omit {0, 1} since they might lead to Inf/NaN.
            zero = array_ops.zeros([], dtype=dist.dtype)
            edges = math_ops.linspace(zero, 1., quadrature_size + 3)[1:-1]
            # Expand edges so it broadcasts across batch dims.
            edges = array_ops.reshape(
                edges,
                shape=array_ops.concat(
                    [[-1],
                     array_ops.ones([batch_ndims], dtype=dtypes.int32)],
                    axis=0))
            quantiles = dist.quantile(edges)
            # Cyclically permute left by one.
            perm = array_ops.concat([math_ops.range(1, 1 + batch_ndims), [0]],
                                    axis=0)
            quantiles = array_ops.transpose(quantiles, perm)
            return quantiles

        quantiles = _compute_quantiles()

        # Compute grid as quantile midpoints.
        grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
        # Set shape hints.
        grid.set_shape(dist.batch_shape.concatenate([quadrature_size]))

        # By construction probs is constant, i.e., `1 / quadrature_size`. This
        # is important because non-constant probs would lead to
        # non-reparameterizable samples.
        probs = array_ops.fill(dims=[quadrature_size],
                               value=1. /
                               math_ops.cast(quadrature_size, dist.dtype))

        return grid, probs
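
A hypothetical usage sketch (the values are illustrative): build a 10-point
quadrature for a scalar LogNormal(loc=0, scale=1) prior.

grid, probs = quadrature_scheme_lognormal_quantiles(
    loc=0., scale=1., quadrature_size=10)
# grid: shape [10], midpoints of consecutive LogNormal quantiles.
# probs: shape [10], uniform weights of 1/10 each.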