コード例 #1
0
ファイル: test_iteration.py プロジェクト: masa-su/pixyz
 def test_input_extra_var(self):
     q = Normal(var=['z'], cond_var=['x'], loc='x', scale=1)
     p = Normal(var=['y'], cond_var=['z'], loc='z', scale=1)
     e = Expectation(q, p.log_prob())
     assert set(e.eval({'y': torch.zeros(1), 'x': torch.zeros(1),
                        'w': torch.zeros(1)}, return_dict=True)[1]) == set(('w', 'x', 'y', 'z'))
     assert set(e.eval({'y': torch.zeros(1), 'x': torch.zeros(1),
                        'z': torch.zeros(1)}, return_dict=True)[1]) == set(('x', 'y', 'z'))
コード例 #2
0
ファイル: test_iteration.py プロジェクト: masa-su/pixyz
 def test_input_var(self):
     q = Normal(var=['z'], cond_var=['x'], loc='x', scale=1)
     p = Normal(var=['y'], cond_var=['z'], loc='z', scale=1)
     e = Expectation(q, p.log_prob())
     assert set(e.input_var) == set(('x', 'y'))
     assert e.eval({'y': torch.zeros(1), 'x': torch.zeros(1)}).shape == torch.Size([1])
コード例 #3
0
 def test_sample_mean(self):
     p = Normal(loc=0, scale=1)
     f = p.log_prob()
     e = Expectation(p, f)
     e.eval({}, sample_mean=True)