def test_normal_leaf_layer(self):
    """Check the Normal leaf layer against a reference log-density.

    The leaf's ``means`` and ``stds`` are overwritten with known tensors and
    its forward output is compared elementwise with ``TorchNormal.log_prob``
    evaluated for the same parameters and inputs.
    """
    # Setup leaf layer
    out_channels = 7
    in_features = 8
    num_repetitions = 5
    leaf = distributions.Normal(
        out_channels=out_channels,
        in_features=in_features,
        num_repetitions=num_repetitions,
    )

    # Setup test input
    batch_size = 3
    x = torch.rand(size=(batch_size, in_features))

    # Setup artificial means and scale tensors. The leading singleton axis
    # lets them broadcast against the batch axis of the input.
    means = torch.randn(1, in_features, out_channels, num_repetitions)
    scale = torch.rand(1, in_features, out_channels, num_repetitions)

    # Expected result: x is expanded to (batch, features, 1, 1) so a single
    # broadcast log_prob call yields the full
    # (batch, features, channels, repetitions) reference tensor instead of
    # building one scalar distribution per element.
    expected_result = TorchNormal(loc=means, scale=scale).log_prob(
        x[:, :, None, None]
    )

    # Perform forward pass in leaf
    leaf.means.data = means
    leaf.stds.data = scale
    result = leaf(x)

    # Make assertions. The whole shape is compared at once so that the
    # repetition axis is verified explicitly; otherwise a wrong trailing
    # axis could only show up as an opaque value mismatch below.
    self.assertEqual(
        tuple(result.shape),
        (batch_size, in_features, out_channels, num_repetitions),
    )
    self.assertTrue(((result - expected_result).abs() < 1e-6).all())
def __init__(self, loc, scale_diag, reinterpreted_batch_ndims=1):
    """Build an ``Independent``-wrapped normal and pass it to the parent.

    Args:
        loc: First positional argument given to ``TorchNormal``.
        scale_diag: Second positional argument given to ``TorchNormal``.
        reinterpreted_batch_ndims: Forwarded unchanged to ``Independent``.
    """
    # NOTE(review): TorchNormal is presumably torch.distributions.Normal,
    # making (loc, scale_diag) its loc/scale -- confirm against the imports.
    base_normal = TorchNormal(loc, scale_diag)
    independent_normal = Independent(
        base_normal,
        reinterpreted_batch_ndims=reinterpreted_batch_ndims,
    )
    super().__init__(independent_normal)