def test_input_extra_var(self):
    """Sample / evaluate a Normal when the input dict carries extra variables.

    A variable that is not part of the distribution ('y') is passed to
    ``sample`` and ``get_log_prob``. The assertions check which keys come back
    and the resulting log-prob shape.
    """
    normal = Normal(loc=0, scale=1)
    # The extra input 'y' appears in the sample dict next to the sampled 'x'.
    assert set(normal.sample({'y': torch.zeros(1)})) == {'x', 'y'}
    assert normal.get_log_prob({
        'y': torch.zeros(1),
        'x': torch.zeros(1),
    }).shape == torch.Size([1])
    # Use a set literal: ``set(('x'))`` would iterate the *string* 'x'
    # (parentheses alone do not make a tuple) and only matches by accident.
    assert set(normal.sample({'x': torch.zeros(1)})) == {'x'}
def test_get_log_prob_feature_dims2(self):
    """Check how ``feature_dims`` selects the axes summed by ``get_log_prob``.

    The joint p(x|y) p(y) is sampled with ``batch_n=4`` and
    ``sample_shape=(2, 3)`` set for 'y', which yields samples of shape
    (2, 3, 4). The options are then cleared, and the log-prob shape is checked
    for several ``feature_dims`` choices.
    """
    dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1) * Normal(
        var=['y'], loc=0, scale=1)
    dist.graph.set_option(dict(batch_n=4, sample_shape=(2, 3)), ['y'])
    sample = dist.sample()
    assert sample['y'].shape == torch.Size([2, 3, 4])

    # Clear the sampling options through the public graph API rather than
    # mutating the private ``_factors_from_variable`` result by index.
    # Presumably this keeps get_log_prob from re-applying batch_n/sample_shape
    # -- TODO confirm.
    dist.graph.set_option({}, ['y'])

    def log_prob_shape(feature_dims):
        # Shape of the feature-summed log-prob for the given feature_dims.
        return dist.get_log_prob(
            sample, sum_features=True, feature_dims=feature_dims).shape

    assert log_prob_shape(None) == torch.Size([2])
    # Summing over axis -2 (size 3) leaves (2, 4).
    assert log_prob_shape([-2]) == torch.Size([2, 4])
    # Summing over axes 0 and 1 leaves (4,).
    assert log_prob_shape([0, 1]) == torch.Size([4])
    # An empty list sums nothing, so the full sample shape remains.
    assert log_prob_shape([]) == torch.Size([2, 3, 4])
def test_set_option(self):
    """Verify that ``graph.set_option`` shapes joint sampling and log-probs.

    Two joints are exercised: a conditional Normal over 'x' given 'y' combined
    with (1) a Normal over 'y', and (2) a FactorizedBernoulli over 'y'.
    """
    # --- Joint 1: Normal over 'y', batch_n=4, sample_shape=(2, 3) ---
    p_x_given_y = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
    p_y = Normal(var=['y'], loc=0, scale=1)
    joint = p_x_given_y * p_y
    joint.graph.set_option(dict(batch_n=4, sample_shape=(2, 3)), ['y'])
    drawn = joint.sample()
    expected_shape = torch.Size([2, 3, 4])
    for name in ('y', 'x'):
        assert drawn[name].shape == expected_shape

    # Drop the options before scoring the existing sample.
    joint.graph.set_option({}, ['y'])
    summed = joint.get_log_prob(drawn, sum_features=True, feature_dims=None)
    assert summed.shape == torch.Size([2])
    unsummed = joint.get_log_prob(drawn, sum_features=False)
    assert unsummed.shape == torch.Size([2, 3, 4])

    # --- Joint 2: FactorizedBernoulli over 'y', batch_n=3, sample_shape=(4,) ---
    p_x_given_y = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
    p_y = FactorizedBernoulli(var=['y'], probs=torch.tensor([0.3, 0.8]))
    joint = p_x_given_y * p_y
    joint.graph.set_option(dict(batch_n=3, sample_shape=(4, )), ['y'])
    drawn = joint.sample()
    expected_shape = torch.Size([4, 3, 2])
    for name in ('y', 'x'):
        assert drawn[name].shape == expected_shape

    joint.graph.set_option({}, ['y'])
    last_axis_summed = joint.get_log_prob(
        drawn, sum_features=True, feature_dims=[-1])
    assert last_axis_summed.shape == torch.Size([4, 3])