def test_input_extra_var(self):
    """Check which variables appear in the dict returned by ``Expectation.eval``.

    ``eval(..., return_dict=True)`` yields a pair whose second element is the
    variable dict. An unrelated input key ('w') must survive alongside the
    inputs and the sampled 'z'. When 'z' is supplied explicitly, the dict holds
    exactly the three model variables.
    """
    q_zx = Normal(var=['z'], cond_var=['x'], loc='x', scale=1)
    p_yz = Normal(var=['y'], cond_var=['z'], loc='z', scale=1)
    expectation = Expectation(q_zx, p_yz.log_prob())

    # Extra, unused key 'w' is passed through; 'z' comes from sampling q.
    with_extra = {'y': torch.zeros(1), 'x': torch.zeros(1), 'w': torch.zeros(1)}
    _, out_vars = expectation.eval(with_extra, return_dict=True)
    assert set(out_vars) == {'w', 'x', 'y', 'z'}

    # 'z' provided directly: no extra keys should appear.
    with_z = {'y': torch.zeros(1), 'x': torch.zeros(1), 'z': torch.zeros(1)}
    _, out_vars = expectation.eval(with_z, return_dict=True)
    assert set(out_vars) == {'x', 'y', 'z'}
def test_input_var(self):
    """``Expectation.input_var`` should be {'x', 'y'} and ``eval`` on them yields shape [1].

    The expectation is over q(z|x) of p(y|z).log_prob(). The test asserts that
    the sampled variable 'z' does not appear among the required inputs.
    """
    q_zx = Normal(var=['z'], cond_var=['x'], loc='x', scale=1)
    p_yz = Normal(var=['y'], cond_var=['z'], loc='z', scale=1)
    expectation = Expectation(q_zx, p_yz.log_prob())

    assert set(expectation.input_var) == {'x', 'y'}

    inputs = {'y': torch.zeros(1), 'x': torch.zeros(1)}
    value = expectation.eval(inputs)
    assert value.shape == torch.Size([1])
def test_sample_mean(self):
    """Evaluate an expectation of N(0, 1)'s own log-density with ``sample_mean=True``.

    Previously this test only checked that the call did not raise. It now also
    requires the result to be finite: the log-density of a standard normal at
    any real sample is finite, so NaN/inf would indicate a broken reduction.
    """
    p = Normal(loc=0, scale=1)
    f = p.log_prob()
    e = Expectation(p, f)
    result = e.eval({}, sample_mean=True)
    # NOTE(review): assumes eval returns a torch tensor here (not a tuple),
    # since return_dict defaults to off — confirm against Loss.eval.
    assert bool(torch.isfinite(result).all())