Example #1
0
    def testSingleTensor(self, bijector, dtype):
        if not tf.executing_eagerly():
            return
        base_mean = tf.convert_to_tensor(value=[1., 0], dtype=dtype)
        base_cov = tf.convert_to_tensor(value=[[1, 0.5], [0.5, 1]],
                                        dtype=dtype)

        base_dist = tfd.MultivariateNormalFullCovariance(
            loc=base_mean, covariance_matrix=base_cov)
        target_dist = bijector(base_dist)

        def debug_fn(*args):
            del args
            debug_fn.count += 1

        debug_fn.count = 0

        kernel = neutra.NeuTra(
            target_log_prob_fn=target_dist.log_prob,
            state_shape=2,
            num_step_size_adaptation_steps=800,
            num_train_steps=1000,
            train_batch_size=64,
            learning_rate=tf.convert_to_tensor(value=1e-2, dtype=dtype),
            seed=tfp_test_util.test_seed(),
            train_debug_fn=debug_fn,
            unconstraining_bijector=bijector,
        )

        chain = tfp.mcmc.sample_chain(num_results=1000,
                                      num_burnin_steps=1000,
                                      current_state=tf.zeros([16, 2], dtype),
                                      kernel=kernel,
                                      trace_fn=None,
                                      parallel_iterations=1)
        self.assertEqual(1000, debug_fn.count)

        sample_mean = tf.reduce_mean(input_tensor=chain, axis=[0, 1])
        sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

        true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())

        true_mean = tf.reduce_mean(input_tensor=true_samples, axis=0)
        true_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

        self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
        self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
Example #2
0
    def testNested(self, bijector):
        if not tf.executing_eagerly():
            return
        base_mean = tf.constant([1., 0])
        base_cov = tf.constant([[1, 0.5], [0.5, 1]])

        dist_2d = tfd.MultivariateNormalFullCovariance(
            loc=base_mean, covariance_matrix=base_cov)
        dist_4d = tfd.MultivariateNormalDiag(scale_diag=tf.ones(4))

        target_dist = tfd.JointDistributionSequential([
            bijector(dist_2d),
            tfb.Reshape([2, 2])(dist_4d),
        ])

        kernel = neutra.NeuTra(
            target_log_prob_fn=lambda x, y: target_dist.log_prob((x, y)),
            state_shape=target_dist.event_shape,
            num_step_size_adaptation_steps=800,
            num_train_steps=1000,
            train_batch_size=64,
            seed=tfp_test_util.test_seed(),
            unconstraining_bijector=[bijector, tfb.Identity()],
        )

        chain_2d, chain_4d = tfp.mcmc.sample_chain(
            num_results=1000,
            num_burnin_steps=1000,
            current_state=tf.nest.map_structure(
                lambda s: tf.zeros([16] + s.as_list()),
                target_dist.event_shape),
            kernel=kernel,
            trace_fn=None,
            parallel_iterations=1)

        sample_mean_2d = tf.reduce_mean(input_tensor=chain_2d, axis=[0, 1])
        sample_mean_4d = tf.reduce_mean(input_tensor=chain_4d, axis=[0, 1])

        true_samples_2d, true_samples_4d = target_dist.sample(
            4096, seed=tfp_test_util.test_seed())

        true_mean_2d = tf.reduce_mean(input_tensor=true_samples_2d, axis=0)
        true_mean_4d = tf.reduce_mean(input_tensor=true_samples_4d, axis=0)

        self.assertAllClose(true_mean_2d, sample_mean_2d, rtol=0.1, atol=0.1)
        self.assertAllClose(true_mean_4d, sample_mean_4d, rtol=0.1, atol=0.1)