Exemplo n.º 1
0
 def test_input_extra_var(self):
     """Check how an unconditioned ``Normal`` handles extra / supplied inputs.

     * Sampling with an unrelated input 'y' returns a dict keyed by both the
       sampled 'x' and the passed-through 'y'.
     * ``get_log_prob`` accepts an input dict with the extra 'y' key and
       returns a tensor of shape ``[1]``.
     * Sampling with 'x' already supplied returns a dict keyed by 'x' only.
     """
     normal = Normal(loc=0, scale=1)
     assert set(normal.sample({'y': torch.zeros(1)})) == {'x', 'y'}
     assert normal.get_log_prob({
         'y': torch.zeros(1),
         'x': torch.zeros(1)
     }).shape == torch.Size([1])
     # Set literal on purpose: the previous ``set(('x'))`` built a set from
     # the *string* 'x' (parentheses alone do not make a tuple) and only
     # worked because the variable name is a single character.
     assert set(normal.sample({'x': torch.zeros(1)})) == {'x'}
Exemplo n.º 2
0
 def test_get_log_prob_feature_dims2(self):
     """Check ``get_log_prob`` output shapes for several ``feature_dims``.

     Builds the product of a conditional Normal over 'x' (given 'y') and a
     Normal over 'y'. It samples with ``batch_n=4`` and
     ``sample_shape=(2, 3)`` set on the factors of 'y', then clears that
     option on the first 'y' factor. Finally it compares the summed
     log-prob shape against the expected shape for each ``feature_dims``
     value.
     """
     x_given_y = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
     y_prior = Normal(var=['y'], loc=0, scale=1)
     joint = x_given_y * y_prior

     joint.graph.set_option(dict(batch_n=4, sample_shape=(2, 3)), ['y'])
     drawn = joint.sample()
     assert drawn['y'].shape == torch.Size([2, 3, 4])

     # Wipe the option set above directly on the first factor owning 'y'
     # before evaluating log-probs.
     first_y_factor = list(joint.graph._factors_from_variable('y'))[0]
     first_y_factor.option = {}

     # (feature_dims argument, expected log-prob shape), checked in order.
     cases = (
         (None, [2]),
         ([-2], [2, 4]),
         ([0, 1], [4]),
         ([], [2, 3, 4]),
     )
     for dims, expected in cases:
         log_prob = joint.get_log_prob(drawn,
                                       sum_features=True,
                                       feature_dims=dims)
         assert log_prob.shape == torch.Size(expected), (dims, log_prob.shape)
Exemplo n.º 3
0
    def test_set_option(self):
        """Check that ``graph.set_option`` drives sampling shapes.

        Two joints are exercised, each a conditional Normal over 'x' (given
        'y') times a prior over 'y': first a Normal prior, then a
        FactorizedBernoulli prior. In each case the option is applied to
        the factors of 'y' and both sampled variables are shape-checked.
        The option is then reset with an empty dict and the log-prob shape
        is checked.
        """
        # Case 1: Normal prior over 'y'.
        x_given_y = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
        y_prior = Normal(var=['y'], loc=0, scale=1)
        joint = x_given_y * y_prior
        joint.graph.set_option(dict(batch_n=4, sample_shape=(2, 3)), ['y'])
        draws = joint.sample()
        for name in ('y', 'x'):
            assert draws[name].shape == torch.Size([2, 3, 4])
        joint.graph.set_option({}, ['y'])
        summed = joint.get_log_prob(draws, sum_features=True,
                                    feature_dims=None)
        assert summed.shape == torch.Size([2])
        unsummed = joint.get_log_prob(draws, sum_features=False)
        assert unsummed.shape == torch.Size([2, 3, 4])

        # Case 2: two-element FactorizedBernoulli prior over 'y'.
        x_given_y = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
        bernoulli_probs = torch.tensor([0.3, 0.8])
        y_prior = FactorizedBernoulli(var=['y'], probs=bernoulli_probs)
        joint = x_given_y * y_prior
        joint.graph.set_option(dict(batch_n=3, sample_shape=(4, )), ['y'])
        draws = joint.sample()
        for name in ('y', 'x'):
            assert draws[name].shape == torch.Size([4, 3, 2])
        joint.graph.set_option(dict(), ['y'])
        summed = joint.get_log_prob(draws, sum_features=True,
                                    feature_dims=[-1])
        assert summed.shape == torch.Size([4, 3])