예제 #1
0
    def test_construction(self):
        """Check the attributes of a `VariationalInference` built from two scalars.

        The object is created without an explicit axis; the test then
        verifies the type of each objective attribute and that the two
        scalars given to the constructor are readable back unchanged.
        """
        log_joint = T.float_scalar(1.)
        latent_log_joint = T.float_scalar(2.)
        vi = VariationalInference(log_joint, latent_log_joint)

        # no axis was passed to the constructor
        self.assertIsNone(vi.axis)

        # each objective attribute must be an instance of its own class
        expected_types = (
            ('training', VariationalTrainingObjectives),
            ('lower_bound', VariationalLowerBounds),
            ('evaluation', VariationalEvaluation),
        )
        for attr_name, expected_cls in expected_types:
            self.assertIsInstance(getattr(vi, attr_name), expected_cls)

        # the constructor arguments are exposed under these names
        assert_equal(vi.log_joint, 1.)
        assert_equal(vi.latent_log_joint, 2.)
예제 #2
0
    def test_log_joint_arg(self):
        """Explicit `log_joint` / `latent_log_joint` must be used as given.

        When both values are passed to `VariationalChain`, they should be
        visible on the chain and on `chain.vi`, and neither of the
        log-prob callables returned by `prepare_model()` may report
        `.called`.
        """
        p_log_probs, p, q_log_probs, q = self.prepare_model()

        explicit_log_joint = T.float_scalar(-1.)
        explicit_latent_log_joint = T.float_scalar(-2.)
        chain = VariationalChain(p,
                                 q,
                                 log_joint=explicit_log_joint,
                                 latent_log_joint=explicit_latent_log_joint)

        # values as seen on the chain itself
        assert_allclose(chain.log_joint, -1.)
        assert_allclose(chain.latent_log_joint, -2.)

        # ... and on the inference object attached to the chain
        vi = chain.vi
        assert_allclose(vi.log_joint, -1.)
        assert_allclose(vi.latent_log_joint, -2.)

        # presumably mocks: their `.called` flag must still be falsy
        for log_probs_fn in (p_log_probs, q_log_probs):
            self.assertFalse(log_probs_fn.called)
예제 #3
0
 def test_errors(self):
     """Objectives that need multiple samples must raise when `axis=None`."""
     vi = VariationalInference(
         T.float_scalar(0.), T.float_scalar(0.), axis=None)

     # (expected message regex, call that should raise); the attribute
     # lookups are deferred into the lambdas so they run inside `raises`.
     failing_calls = (
         ('`monte_carlo_objective` requires to take '
          'multiple samples',
          lambda: vi.lower_bound.monte_carlo_objective()),
         ('`iwae_estimator` requires to take multiple '
          'samples',
          lambda: vi.training.iwae()),
         ('`importance_sampling_log_likelihood` '
          'requires to take[^@]*multiple samples',
          lambda: vi.evaluation.importance_sampling_log_likelihood()),
     )

     for message_pattern, invoke in failing_calls:
         with pytest.raises(Exception, match=message_pattern):
             _ = invoke()
예제 #4
0
    def test_chain(self):
        """Check what `BayesianNet.chain` passes to the builder and records.

        A small net with observed `x` and latent `z`, `y` is chained to a
        mocked builder; for each option (`latent_names`, `latent_axis`,
        `observed`) the test inspects the builder's last call and the
        attributes of the returned chain.
        """
        q_net = BayesianNet({'x': T.ones([1])})
        q_net.add('z', Normal(q_net.observed['x'], T.float_scalar(1.)))
        q_net.add('y', Normal(q_net.observed['x'] * 2, T.float_scalar(2.)))

        def build_generative_net(observed):
            generative = BayesianNet(observed)
            z = generative.add('z', UnitNormal([1]))
            y = generative.add('y', Normal(T.zeros([1]), T.full([1], 2.)))
            generative.add('x', Normal(z.tensor + y.tensor, T.ones([1])))
            return generative

        # wrap the builder so its call arguments can be inspected
        net_builder = mock.Mock(wraps=build_generative_net)

        def assert_last_builder_call(expected_observed):
            # the builder must have been called with one positional dict
            self.assertEqual(net_builder.call_args,
                             ((expected_observed, ), ))

        # default parameters: every latent of `q_net` is forwarded
        default_chain = q_net.chain(net_builder)
        assert_last_builder_call({
            'y': q_net['y'],
            'z': q_net['z']
        })
        self.assertEqual(default_chain.latent_names, ['z', 'y'])
        self.assertIsNone(default_chain.latent_axis)

        # explicit `latent_names` restricts what is forwarded
        named_chain = q_net.chain(net_builder, latent_names=['y'])
        assert_last_builder_call({'y': q_net['y']})
        self.assertEqual(named_chain.latent_names, ['y'])

        # `latent_axis`: a single int is reported as a one-element list
        int_axis_chain = q_net.chain(net_builder, latent_axis=-1)
        self.assertEqual(int_axis_chain.latent_axis, [-1])

        list_axis_chain = q_net.chain(net_builder, latent_axis=[-1, 2])
        self.assertEqual(list_axis_chain.latent_axis, [-1, 2])

        # explicit `observed` is forwarded together with the latents
        observed_chain = q_net.chain(net_builder, observed=q_net.observed)
        assert_last_builder_call({
            'x': q_net.observed['x'],
            'y': q_net['y'],
            'z': q_net['z']
        })
        self.assertEqual(observed_chain.latent_names, ['z', 'y'])
예제 #5
0
        def do_test(low, high, dtype):
            """Exercise `T.random.truncated_normal` and its log-pdf for one setup.

            Relies on `mean`, `std`, `logstd`, `n_samples`, `log_zero`,
            `log_prob` and `do_check_log_prob` from the enclosing scope,
            none of which are defined inside this function.

            Args:
                low: Lower truncation bound or None.  The range assertions
                    below compare samples against `low * std + mean`, i.e.
                    the bound is expressed in units of `std` around `mean`.
                high: Upper truncation bound or None, compared against
                    `high * std + mean` in the same way.
                dtype: dtype used to build the parameter tensors and
                    expected of the drawn samples.
            """
            # --- sampling with an explicit `n_samples` ---
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            t = T.random.truncated_normal(mean_t,
                                          std_t,
                                          n_samples=n_samples,
                                          low=low,
                                          high=high)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            # the sample axis is prepended to the (2, 3, 4) parameter shape
            self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

            # every sample must lie strictly inside the de-normalized
            # bounds, with a 1e-7 slack on either side
            x = T.to_numpy(t)
            if low is not None:
                np.testing.assert_array_less(
                    (low * std + mean - 1e-7) * np.ones_like(x), x)
            if high is not None:
                np.testing.assert_array_less(
                    x,
                    np.ones_like(x) * high * std + mean + 1e-7)

            # compare the backend log-pdf with the reference `log_prob`,
            # first on the drawn samples ...
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(
                                  T.random.truncated_normal_log_pdf,
                                  mean=mean_t,
                                  std=std_t,
                                  logstd=logstd_t,
                                  low=low,
                                  high=high,
                                  log_zero=log_zero,
                              ),
                              np_log_prob=log_prob(x, low, high))
            # ... then on scaled samples to hit the out-of-range branch
            do_check_log_prob(
                given=t *
                10.,  # where the majority is out of [low, high] range
                batch_ndims=len(x.shape),
                Z_log_prob_fn=partial(
                    T.random.truncated_normal_log_pdf,
                    mean=mean_t,
                    std=std_t,
                    logstd=logstd_t,
                    low=low,
                    high=high,
                    log_zero=log_zero,
                ),
                np_log_prob=log_prob(x * 10., low, high))

            # --- sampling without `n_samples` (argument omitted) ---
            # Same checks as above, except that the sample shape is not
            # asserted here.
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            t = T.random.truncated_normal(mean_t, std_t, low=low, high=high)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())

            # samples must again lie inside the de-normalized bounds
            x = T.to_numpy(t)
            if low is not None:
                np.testing.assert_array_less(
                    (low * std + mean - 1e-7) * np.ones_like(x), x)
            if high is not None:
                np.testing.assert_array_less(
                    x,
                    np.ones_like(x) * high * std + mean + 1e-7)

            # log-pdf on the samples and on the scaled samples
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(
                                  T.random.truncated_normal_log_pdf,
                                  mean=mean_t,
                                  std=std_t,
                                  logstd=logstd_t,
                                  low=low,
                                  high=high,
                                  log_zero=log_zero,
                              ),
                              np_log_prob=log_prob(x, low, high))
            do_check_log_prob(
                given=t *
                10.,  # where the majority is out of [low, high] range
                batch_ndims=len(x.shape),
                Z_log_prob_fn=partial(
                    T.random.truncated_normal_log_pdf,
                    mean=mean_t,
                    std=std_t,
                    logstd=logstd_t,
                    low=low,
                    high=high,
                    log_zero=log_zero,
                ),
                np_log_prob=log_prob(x * 10., low, high))

            # --- reparameterized sampling: gradients reach mean and std ---
            # `w` weights the sample so that the expected gradient w.r.t.
            # `mean_t` is `w` itself (asserted below).
            w = np.random.randn(2, 3, 4)

            # NOTE(review): `w_t` is built without `dtype`, unlike `mean_t`
            # and `std_t` -- confirm the mixed dtypes are intended.
            w_t = T.requires_grad(T.as_tensor(w))
            mean_t = T.requires_grad(T.as_tensor(mean, dtype))
            std_t = T.requires_grad(T.as_tensor(std, dtype))
            # NOTE(review): `low` / `high` are not passed to this draw, so
            # the bounds under test are not used here -- confirm.
            t = w_t * T.random.truncated_normal(mean_t, std_t)
            [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                           [T.ones_like(t)])
            assert_allclose(mean_grad, w, rtol=1e-4)
            # expected std gradient: w * (sample - mean) / std, summed
            # over axis 0
            assert_allclose(std_grad,
                            np.sum(T.to_numpy((t - w_t * mean_t) / std_t),
                                   axis=0),
                            rtol=1e-4)

            # --- non-reparameterized sampling: gradients must be null ---
            w_t = T.requires_grad(T.as_tensor(w))
            mean_t = T.requires_grad(T.as_tensor(mean, dtype))
            std_t = T.requires_grad(T.as_tensor(std, dtype))
            t = w_t * T.random.truncated_normal(
                mean_t, std_t, reparameterized=False)
            # `allow_unused=True` because no gradient path to the
            # parameters is expected in this mode
            [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                           [T.ones_like(t)],
                                           allow_unused=True)
            self.assertTrue(T.is_null_grad(mean_t, mean_grad))
            self.assertTrue(T.is_null_grad(std_t, std_grad))

            # `given` is a scalar, i.e. of lower rank than the parameters;
            # the result is compared with the reference evaluated at 0.
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            assert_allclose(T.random.truncated_normal_log_pdf(
                T.float_scalar(0.),
                mean_t,
                std_t,
                logstd_t,
                low=low,
                high=high,
                log_zero=log_zero),
                            log_prob(0., low=low, high=high),
                            rtol=1e-4)

            # mean and std of different dtypes must be rejected
            with pytest.raises(Exception, match='`mean.dtype` != `std.dtype`'):
                _ = T.random.truncated_normal(T.as_tensor(mean, T.float32),
                                              T.as_tensor(std, T.float64),
                                              low=low,
                                              high=high)

            # numeric validation: a zero std (and the log of it) should
            # make the log-pdf fail when `validate_tensors=True`.
            # NOTE(review): the sample comes from `T.random.normal`, not
            # `truncated_normal`, and `low` / `high` / `log_zero` are not
            # passed to the log-pdf call -- confirm this is intended.
            mean_t = T.as_tensor(mean)
            std_t = T.zeros_like(mean_t)
            logstd_t = T.as_tensor(T.log(std_t))
            t = T.random.normal(mean_t, std_t)
            with pytest.raises(Exception,
                               match='Infinity or NaN value encountered'):
                _ = T.random.truncated_normal_log_pdf(t,
                                                      mean_t,
                                                      std_t,
                                                      logstd_t,
                                                      validate_tensors=True)
예제 #6
0
    def test_normal(self):
        """Check sampling, log-pdf, gradients and validation of `T.random.normal`.

        Parameters are a random `mean` of shape (2, 3, 4) and a random
        `logstd` of shape (3, 4).  For every dtype in `float_dtypes` the
        test draws samples in three ways (manually expanded parameters,
        `n_samples=...`, no `n_samples`), then checks gradients with and
        without reparameterization, scalar `given` values, the dtype
        mismatch error and the `validate_tensors` error.
        """
        mean = np.random.randn(2, 3, 4)
        logstd = np.random.randn(3, 4)
        std = np.exp(logstd)

        def log_prob(given):
            # Mathematically the same as
            #   np.log(np.exp(-(given - mean) ** 2 / (2. * std ** 2)) /
            #          (np.sqrt(2 * np.pi) * std))
            # but written in terms of `logstd`.
            half_precision = 0.5 * np.exp(-2. * logstd)
            quadratic = -(given - mean)**2 * half_precision
            return quadratic - np.log(np.sqrt(2 * np.pi)) - logstd

        def check_draw(t, dtype, mean_t, logstd_t, has_sample_axis):
            # dtype / device of the drawn tensor
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            if has_sample_axis:
                self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

            x = T.to_numpy(t)
            if has_sample_axis:
                # the empirical mean over the sample axis must stay within
                # 5 * std / sqrt(n_samples) of the true mean
                x_mean = np.mean(x, axis=0)
                tolerance = np.tile(
                    np.expand_dims(5 * std / np.sqrt(n_samples), axis=0),
                    [2, 1, 1])
                np.testing.assert_array_less(np.abs(x_mean - mean),
                                             tolerance)

            # backend log-pdf versus the NumPy reference
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                    mean=mean_t,
                                                    logstd=logstd_t),
                              np_log_prob=log_prob(x))

        # n_samples emulated by manually expanding the parameter shapes
        for dtype in float_dtypes:
            mean_t = T.cast(T.expand(T.as_tensor(mean), [n_samples, 2, 3, 4]),
                            dtype)
            std_t = T.cast(T.expand(T.as_tensor(std), [n_samples, 1, 3, 4]),
                           dtype)
            logstd_t = T.cast(
                T.expand(T.as_tensor(logstd), [n_samples, 1, 3, 4]), dtype)
            drawn = T.random.normal(mean_t, std_t)
            check_draw(drawn, dtype, mean_t, logstd_t, has_sample_axis=True)

        # n_samples passed to the sampler
        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            drawn = T.random.normal(mean_t, std_t, n_samples=n_samples)
            check_draw(drawn, dtype, mean_t, logstd_t, has_sample_axis=True)

        # no n_samples at all: shape and sample mean are not checked
        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            drawn = T.random.normal(mean_t, std_t)
            check_draw(drawn, dtype, mean_t, logstd_t, has_sample_axis=False)

        # weights applied to the sample in the gradient checks below
        w = np.random.randn(2, 3, 4)

        def make_grad_inputs(dtype):
            # fresh tensors that require gradients; `w` keeps the backend's
            # default dtype as no dtype is passed for it
            weight = T.requires_grad(T.as_tensor(w))
            mu = T.requires_grad(T.as_tensor(mean, dtype))
            sigma = T.requires_grad(T.as_tensor(std, dtype))
            return weight, mu, sigma

        # reparameterized: gradients flow back to mean and std
        for dtype in float_dtypes:
            w_t, mean_t, std_t = make_grad_inputs(dtype)
            t = w_t * T.random.normal(mean_t, std_t)
            mean_grad, std_grad = T.grad([t], [mean_t, std_t],
                                         [T.ones_like(t)])
            assert_allclose(mean_grad, w, rtol=1e-4)
            expected_std_grad = np.sum(
                T.to_numpy((t - w_t * mean_t) / std_t), axis=0)
            assert_allclose(std_grad, expected_std_grad, rtol=1e-4)

        # not reparameterized: both gradients must be null
        for dtype in float_dtypes:
            w_t, mean_t, std_t = make_grad_inputs(dtype)
            t = w_t * T.random.normal(mean_t, std_t, reparameterized=False)
            mean_grad, std_grad = T.grad([t], [mean_t, std_t],
                                         [T.ones_like(t)],
                                         allow_unused=True)
            self.assertTrue(T.is_null_grad(mean_t, mean_grad))
            self.assertTrue(T.is_null_grad(std_t, std_grad))

        # `given` is a scalar, i.e. of lower rank than the parameters
        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            for val in (0., 1., -1.):
                actual = T.random.normal_log_pdf(
                    T.float_scalar(val), mean_t, logstd_t)
                assert_allclose(actual, log_prob(val), rtol=1e-4)

        # mean and std of different dtypes must be rejected
        with pytest.raises(Exception, match='`mean.dtype` != `std.dtype`'):
            _ = T.random.normal(T.as_tensor(mean, T.float32),
                                T.as_tensor(std, T.float64))

        # numeric validation: zero std and the log of it must trip the
        # `validate_tensors` check of the log-pdf
        mean_t = T.as_tensor(mean)
        zero_std_t = T.zeros_like(mean_t)
        degenerate_logstd_t = T.as_tensor(T.log(zero_std_t))
        t = T.random.normal(mean_t, zero_std_t)
        with pytest.raises(Exception,
                           match='Infinity or NaN value encountered'):
            _ = T.random.normal_log_pdf(t,
                                        mean_t,
                                        degenerate_logstd_t,
                                        validate_tensors=True)
예제 #7
0
 def q_log_probs(names):
     """Return one `T.float_scalar` per entry of `names`.

     Only the names 'a' and 'b' are known; any other name raises
     `KeyError` from the table lookup.
     """
     fixed_values = {'a': 1., 'b': 2.}
     scalars = []
     for name in names:
         scalars.append(T.float_scalar(fixed_values[name]))
     return scalars
예제 #8
0
 def p_log_probs(names):
     """Return one `T.float_scalar` per entry of `names`.

     Only the names 'c' and 'd' are known; any other name raises
     `KeyError` from the table lookup.
     """
     fixed_values = {'c': 3., 'd': 4.}
     scalars = []
     for name in names:
         scalars.append(T.float_scalar(fixed_values[name]))
     return scalars