Example no. 1
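The test below assumes the following imports; `TorchNormal` is taken to be an alias for `torch.distributions.Normal`, while the leaf-layer `distributions` module is project-specific, so its import path is only a placeholder:

import torch
from torch.distributions import Normal as TorchNormal
# from <project> import distributions  # placeholder: project's leaf-layer module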
    def test_normal_leaf_layer(self):
        """Test the normal leaf layer."""
        # Setup leaf layer
        out_channels = 7
        in_features = 8
        num_repetitions = 5
        leaf = distributions.Normal(out_channels=out_channels,
                                    in_features=in_features,
                                    num_repetitions=num_repetitions)

        # Setup test input
        batch_size = 3
        x = torch.rand(size=(batch_size, in_features))

        # Setup artificial means and scale matrices
        means = torch.randn(1, in_features, out_channels, num_repetitions)
        scale = torch.rand(1, in_features, out_channels, num_repetitions)

        # Expected result: per-element log-probabilities computed with
        # torch's Normal distribution
        expected_result = torch.zeros(batch_size, in_features, out_channels,
                                      num_repetitions)

        # Loop over batch elements, features, channels, and repetitions
        for n in range(batch_size):
            for d in range(in_features):
                for c in range(out_channels):
                    for r in range(num_repetitions):
                        expected_result[n, d, c, r] = TorchNormal(
                            loc=means[0, d, c, r],
                            scale=scale[0, d, c, r]).log_prob(x[n, d])

        # Perform forward pass in leaf
        leaf.means.data = means
        leaf.stds.data = scale
        result = leaf(x)

        # Make assertions on output shape and values
        self.assertEqual(result.shape[0], batch_size)
        self.assertEqual(result.shape[1], in_features)
        self.assertEqual(result.shape[2], out_channels)
        self.assertEqual(result.shape[3], num_repetitions)
        self.assertTrue(((result - expected_result).abs() < 1e-6).all())
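The nested loops above can be written equivalently with broadcasting; this is only a sketch of the same computation, relying on `torch.distributions.Normal` broadcasting its `loc`/`scale` of shape (1, in_features, out_channels, num_repetitions) against the expanded input:

# Equivalent vectorized computation of expected_result (sketch)
expected_vectorized = TorchNormal(loc=means, scale=scale).log_prob(
    x.view(batch_size, in_features, 1, 1))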
Example no. 2
    def __init__(self, loc, scale_diag, reinterpreted_batch_ndims=1):
        # Wrap a diagonal-covariance Gaussian: Independent reinterprets the
        # last `reinterpreted_batch_ndims` batch dimensions of the base
        # Normal as event dimensions.
        dist = Independent(TorchNormal(loc, scale_diag),
                           reinterpreted_batch_ndims=reinterpreted_batch_ndims)
        super().__init__(dist)
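A brief, hypothetical usage sketch of the same Independent/Normal construction; the enclosing class and its base class are not shown in the snippet, so the example builds the distribution directly with the public torch.distributions API:

import torch
from torch.distributions import Independent, Normal as TorchNormal

loc = torch.zeros(4)
scale_diag = torch.ones(4)
# Same construction as in __init__ above: a factorized Gaussian whose last
# dimension is reinterpreted as the event dimension.
dist = Independent(TorchNormal(loc, scale_diag), reinterpreted_batch_ndims=1)
sample = dist.sample()         # shape: (4,)
log_p = dist.log_prob(sample)  # scalar: joint log-density over all 4 dims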