Beispiel #1
0
    def test_spn_sampling(self):
        """Build a small SPN, push one random input through it, then draw 1000 samples top-down.

        The network is leaf -> (sum, product) x 3 -> root sum over 2 ** 3 input features.
        The test has no assertions; it only checks that forward and sampling passes run.
        """
        # Define SPN (construction order kept fixed so parameter initialisation is reproducible)
        leaf = distributions.Normal(in_features=2 ** 3, out_channels=5, num_repetitions=1)
        sum_1 = layers.Sum(in_channels=5, in_features=2 ** 3, out_channels=20, num_repetitions=1)
        prd_1 = layers.Product(in_features=2 ** 3, cardinality=2, num_repetitions=1)
        sum_2 = layers.Sum(in_channels=20, in_features=2 ** 2, out_channels=20, num_repetitions=1)
        prd_2 = layers.Product(in_features=2 ** 2, cardinality=2, num_repetitions=1)
        sum_3 = layers.Sum(in_channels=20, in_features=2 ** 1, out_channels=20, num_repetitions=1)
        prd_3 = layers.Product(in_features=2 ** 1, cardinality=2, num_repetitions=1)
        sum_4 = layers.Sum(in_channels=20, in_features=2 ** 0, out_channels=1, num_repetitions=1)

        # Bottom-up forward pass on a single random input
        out = torch.randn(1, 2 ** 3)
        for layer in (leaf, sum_1, prd_1, sum_2, prd_2, sum_3, prd_3, sum_4):
            out = layer(out)

        # Top-down sampling: the root creates the context and the same object is handed to each
        # lower layer in turn (return values are discarded, so presumably it is updated in place)
        ctx = sum_4.sample(n=1000)
        for layer in (prd_3, sum_3, prd_2, sum_2, prd_1, sum_1):
            layer.sample(context=ctx)
        samples = leaf.sample(context=ctx)
Beispiel #2
0
    def test_sum_as_intermediate_node(self):
        """Check that an intermediate sum node samples exactly the input channels holding all the weight."""
        # Layer dimensions and number of samples
        in_features = 10
        in_channels = 3
        out_channels = 5
        num_repetitions = 7
        n = 2
        parent_indices = torch.randint(out_channels, size=(n, in_features))

        sum_layer = layers.Sum(
            in_features=in_features, in_channels=in_channels, out_channels=out_channels, num_repetitions=num_repetitions
        )

        # For every (feature, repetition) pair pick the single input channel that gets probability 1.0,
        # and one repetition per sample
        rand_indxs = torch.randint(in_channels, size=(in_features, num_repetitions))
        rep_idxs = torch.randint(num_repetitions, size=(n,))

        # One-hot weights along the input-channel axis (dim 1); after torch.log the chosen channel
        # holds log(1) = 0 and every other channel holds log(0) = -inf
        weights = torch.zeros(in_features, in_channels, out_channels, num_repetitions)
        feature_idx = range(in_features)
        for rep in range(num_repetitions):
            weights[feature_idx, rand_indxs[:, rep], :, rep] = 1.0
        sum_layer.weights = nn.Parameter(torch.log(weights))

        # Sample through the layer with the given parent and repetition indices
        ctx = SamplingContext(n=n, parent_indices=parent_indices, repetition_indices=rep_idxs)
        sum_layer.sample(context=ctx)

        # Every sample must select, per feature, the one-hot channel of its own repetition
        for i, rep in enumerate(rep_idxs.tolist()):
            self.assertTrue((rand_indxs[:, rep] == ctx.parent_indices[i, :]).all())
def create_pytorch_spn(n_feats, device="cuda:0"):
    """Build a small Gaussian -> product -> sum -> product SPN as an ``nn.Sequential``.

    Args:
        n_feats: Number of input features. Presumably expected to be even, since the first
            product layer uses ``cardinality=2`` — TODO confirm callers only pass even values.
        device: Device (string or ``torch.device``) the model is moved to. Defaults to
            ``"cuda:0"``, matching the previously hard-coded device.

    Returns:
        The four layers stacked in an ``nn.Sequential`` and moved to ``device``.
    """
    # Integer feature count after the first product layer; `n_feats / 2` would produce a float,
    # which is not a valid feature count for the layers below.
    half_feats = n_feats // 2

    # Create SPN layers
    gauss = Normal(multiplicity=2, in_features=n_feats, in_channels=1)
    prod1 = layers.Product(in_features=n_feats, cardinality=2)
    sum1 = layers.Sum(in_features=half_feats, in_channels=2, out_channels=1)
    prod2 = layers.Product(in_features=half_feats, cardinality=half_feats)

    # Stack SPN layers and place them on the requested device
    pytorch_spn = nn.Sequential(gauss, prod1, sum1, prod2).to(torch.device(device))
    return pytorch_spn
Beispiel #4
0
 def test_sum_shape_as_root_node(self):
     """Check that the sum node has the correct sampling shape when used as root.

     For a root sum layer (``out_channels=1``) sampled with ``n`` samples, the sampled
     ``parent_indices`` must have ``n`` rows and ``in_features`` columns, for every
     combination of ``in_channels`` and ``in_features`` in the grid below.
     """
     n = 5
     num_repetitions = 1
     for in_channels in [1, 5, 10]:
         for in_features in [1, 5, 10]:
             # subTest reports which (in_channels, in_features) combination failed
             with self.subTest(in_channels=in_channels, in_features=in_features):
                 sum_layer = layers.Sum(
                     in_channels=in_channels, out_channels=1, in_features=in_features, num_repetitions=num_repetitions
                 )
                 ctx = sum_layer.sample(n=n)
                 # assertEqual shows both values on failure, unlike assertTrue(a == b)
                 self.assertEqual(ctx.parent_indices.shape[0], n)
                 self.assertEqual(ctx.parent_indices.shape[1], in_features)
Beispiel #5
0
    def test_sum_layer(self):
        """Test the forward pass of a sum layer against an explicit reference computation.

        Raw random weights are installed in the layer. The reference normalises them with
        ``F.softmax`` over ``dim=1`` (the input-channel axis) and computes the weighted channel sums
        directly. The layer runs in log space, so it is fed ``x.log()`` and its output is
        exponentiated before comparison.
        """
        # Setup layer
        in_channels = 8
        out_channels = 7
        in_features = 3
        num_repetitions = 5
        sum_layer = layers.Sum(
            in_channels=in_channels,
            out_channels=out_channels,
            in_features=in_features,
            num_repetitions=num_repetitions,
        )

        w = torch.rand(in_features, in_channels, out_channels, num_repetitions)

        # Set the sum layer parameters
        sum_layer.weights = nn.Parameter(w)

        # Apply softmax once again since Sum forward pass uses F.log_softmax internally to project random weights
        # back into valid ranges
        w = F.softmax(w, dim=1)

        # Setup test input
        batch_size = 16
        x = torch.rand(size=(batch_size, in_features, in_channels, num_repetitions))

        # Expected outcome: expected[n, d, oc, r] = sum over ic of x[n, d, ic, r] * w[d, ic, oc, r]
        expected_result = torch.zeros(batch_size, in_features, out_channels, num_repetitions)
        for n in range(batch_size):
            for d in range(in_features):
                for oc in range(out_channels):
                    for r in range(num_repetitions):
                        expected_result[n, d, oc, r] = x[n, d, :, r] @ w[d, :, oc, r]

        # Do forward pass: apply log as sum layer operates in log space. Exp() afterwards to make it comparable to the
        # expected result
        result = sum_layer(x.log()).exp()

        # Run assertions (assertEqual reports the mismatching values on failure)
        self.assertEqual(result.shape[0], batch_size)
        self.assertEqual(result.shape[1], in_features)
        self.assertEqual(result.shape[2], out_channels)
        self.assertEqual(result.shape[3], num_repetitions)
        # Equivalent to "all |diff| < 1e-6" (a NaN difference still fails), but reports the actual error
        max_abs_err = (result - expected_result).abs().max().item()
        self.assertLess(max_abs_err, 1e-6)
Beispiel #6
0
    def test_spn_mpe(self):
        """Check that MPE sampling from the leaf layer is repeatable.

        Builds leaf -> 3 x (sum, product) -> root sum over 2**3 features and enables the input
        cache on every sum layer. It then runs one forward pass and propagates an MPE
        ``SamplingContext`` from the root down to ``sum_1``. Finally it draws from the leaf three
        times with that same context; all three draws must agree to within 1e-6.
        """
        # Define SPN (construction order kept fixed so parameter initialisation is reproducible)
        leaf = distributions.Normal(in_features=2**3, out_channels=5, num_repetitions=1)
        sum_1 = layers.Sum(in_channels=5, in_features=2**3, out_channels=20, num_repetitions=1)
        prd_1 = layers.Product(in_features=2**3, cardinality=2, num_repetitions=1)
        sum_2 = layers.Sum(in_channels=20, in_features=2**2, out_channels=20, num_repetitions=1)
        prd_2 = layers.Product(in_features=2**2, cardinality=2, num_repetitions=1)
        sum_3 = layers.Sum(in_channels=20, in_features=2**1, out_channels=20, num_repetitions=1)
        prd_3 = layers.Product(in_features=2**1, cardinality=2, num_repetitions=1)
        sum_4 = layers.Sum(in_channels=20, in_features=2**0, out_channels=1, num_repetitions=1)

        # Enable the input cache on all sum layers before the forward pass
        # (presumably MPE sampling reads the cached inputs — TODO confirm)
        for sum_layer in (sum_1, sum_2, sum_3, sum_4):
            sum_layer._enable_input_cache()

        # Forward pass on a single random input. Record the batch size from the input itself rather
        # than from an intermediate layer output, since `x` is overwritten below.
        x = torch.randn(1, 2**3)
        batch_size = x.shape[0]
        for layer in (leaf, sum_1, prd_1, sum_2, prd_2, sum_3, prd_3, sum_4):
            x = layer(x)

        # Top-down MPE pass; the same context object is passed to each layer in turn
        ctx = SamplingContext(n=batch_size, is_mpe=True)
        for layer in (sum_4, prd_3, sum_3, prd_2, sum_2, prd_1, sum_1):
            layer.sample(context=ctx)

        # Should be the same
        mpe_1 = leaf.sample(context=ctx)
        mpe_2 = leaf.sample(context=ctx)
        mpe_3 = leaf.sample(context=ctx)
        self.assertTrue(((mpe_1 - mpe_2).abs() < 1e-6).all())
        self.assertTrue(((mpe_2 - mpe_3).abs() < 1e-6).all())