Beispiel #1
0
 def testInvertMutuallyConsistent(self):
     """Check that an inverted MaskedAutoregressiveFlow samples consistently.

     Wraps a `MaskedAutoregressiveFlow` (built with `validate_args=True` plus
     the fixture's `_autoregressive_flow_kwargs`) in `Invert`. The wrapped
     bijector transforms a Normal(loc=0, scale=1) base into a distribution
     with a 4-element event shape. The result is handed to
     `run_test_sample_consistent_log_prob`, which is defined outside this
     view (presumably it compares empirical samples against `log_prob`
     inside a ball of the given radius around `center` -- confirm).
     """
     event_size = 4
     with self.cached_session() as session:
         # Construction order matches the nested-call form: flow, Invert,
         # base Normal, then the transformed distribution.
         flow = MaskedAutoregressiveFlow(
             validate_args=True, **self._autoregressive_flow_kwargs)
         inverted_flow = Invert(flow)
         base = normal_lib.Normal(loc=0.0, scale=1.0)
         transformed = transformed_distribution_lib.TransformedDistribution(
             distribution=base,
             bijector=inverted_flow,
             event_shape=[event_size],
             validate_args=True)
         self.run_test_sample_consistent_log_prob(
             sess_run_fn=session.run,
             dist=transformed,
             num_samples=10**5,
             radius=1.0,
             center=0.0,
             rtol=0.02)
 def testInvertMutuallyConsistent(self):
     """Check that an inverted RealNVP bijector samples consistently.

     Wraps a `RealNVP` bijector (num_masked=3, `validate_args=True`, plus the
     fixture's `_real_nvp_kwargs`) in `Invert`. The wrapped bijector
     transforms a Normal(loc=0, scale=1) base into a distribution with a
     4-element event shape. The result is handed to
     `run_test_sample_consistent_log_prob`, which is defined outside this
     view (presumably it compares empirical samples against `log_prob`
     inside a ball of the given radius around `center` -- confirm).
     """
     dims = 4
     # `test_session` is deprecated in TensorFlow's test utilities;
     # `cached_session` is its supported replacement.
     with self.cached_session() as sess:
         # num_masked=3 with dims=4: presumably the first 3 event units are
         # masked (conditioning inputs) and 1 unit is transformed -- confirm
         # against the RealNVP documentation.
         nvp = Invert(
             RealNVP(num_masked=3,
                     validate_args=True,
                     **self._real_nvp_kwargs))
         dist = transformed_distribution_lib.TransformedDistribution(
             distribution=normal_lib.Normal(loc=0., scale=1.),
             bijector=nvp,
             event_shape=[dims],
             validate_args=True)
         self.run_test_sample_consistent_log_prob(sess_run_fn=sess.run,
                                                  dist=dist,
                                                  num_samples=int(1e5),
                                                  radius=1.,
                                                  center=0.,
                                                  rtol=0.02)
 def testInvertMutuallyConsistent(self):
     """Check that an inverted BatchNormalization bijector samples consistently.

     The BatchNorm bijector is only mutually consistent when training=False,
     so it is built in that mode on top of a `normalization.BatchNormalization`
     layer with epsilon=0. The inverted bijector transforms a
     Normal(loc=0, scale=1) base into a distribution with a 4-element event
     shape. The result is handed to `run_test_sample_consistent_log_prob`,
     which is defined outside this view (presumably it compares empirical
     samples against `log_prob` inside a ball of the given radius around
     `center` -- confirm).
     """
     event_size = 4
     with self.cached_session() as session:
         # epsilon=0.0 presumably removes the variance stabilizer so the
         # bijector is an exact affine map -- confirm.
         norm_layer = normalization.BatchNormalization(epsilon=0.0)
         bn_bijector = BatchNormalization(batchnorm_layer=norm_layer,
                                          training=False)
         inverted_bn = Invert(bn_bijector)
         base = normal_lib.Normal(loc=0.0, scale=1.0)
         transformed = transformed_distribution_lib.TransformedDistribution(
             distribution=base,
             bijector=inverted_bn,
             event_shape=[event_size],
             validate_args=True)
         self.run_test_sample_consistent_log_prob(
             sess_run_fn=session.run,
             dist=transformed,
             num_samples=10**5,
             radius=2.0,
             center=0.0,
             rtol=0.02)