Ejemplo n.º 1
0
    def test_spn_sampling(self):
        """Build a small SPN, run one forward pass, then sample it top-down from the root."""

        # Leaf layer, then three (Sum, Product) levels over 8 -> 4 -> 2 features,
        # constructed in bottom-up order.
        leaf = distributions.Normal(in_features=2 ** 3, out_channels=5, num_repetitions=1)
        hidden = []
        in_ch = 5
        for level in (3, 2, 1):
            hidden.append(layers.Sum(in_channels=in_ch, in_features=2 ** level, out_channels=20, num_repetitions=1))
            hidden.append(layers.Product(in_features=2 ** level, cardinality=2, num_repetitions=1))
            in_ch = 20
        root = layers.Sum(in_channels=20, in_features=2 ** 0, out_channels=1, num_repetitions=1)

        # Forward pass, bottom-up (the root's output itself is not inspected).
        h = torch.randn(1, 2 ** 3)
        h = leaf(h)
        for layer in hidden:
            h = layer(h)
        root(h)

        # Sampling pass, top-down: the root creates the context, every other layer consumes it.
        ctx = root.sample(n=1000)
        for layer in reversed(hidden):
            layer.sample(context=ctx)
        leaf.sample(context=ctx)
Ejemplo n.º 2
0
def create_pytorch_spn(n_feats, device="cuda:0"):
    """Build a small Gaussian-leaf SPN as an ``nn.Sequential`` and move it to ``device``.

    Layer stack (bottom-up): Normal leaves -> Product (cardinality 2) -> Sum -> Product.

    Args:
        n_feats: Number of input features. The sizes below assume ``n_feats`` is even
            (``n_feats // 2`` features after the first product) -- TODO confirm the
            intended behaviour for odd values.
        device: Device the module is moved to. Defaults to ``"cuda:0"`` (the previous
            hard-coded value); pass e.g. ``"cpu"`` on machines without CUDA.

    Returns:
        The stacked ``nn.Sequential`` SPN on ``device``.
    """
    # Feature count after the cardinality-2 product. Integer division: ``n_feats / 2``
    # would hand a float (e.g. 4.0) to the layer constructors as a feature count.
    half_feats = n_feats // 2

    # Create SPN layers
    gauss = Normal(multiplicity=2, in_features=n_feats, in_channels=1)
    prod1 = layers.Product(in_features=n_feats, cardinality=2)
    sum1 = layers.Sum(in_features=half_feats, in_channels=2, out_channels=1)
    # Cardinality equals the remaining feature count, presumably to merge all
    # features into a single root product.
    prod2 = layers.Product(in_features=half_feats, cardinality=half_feats)

    # Stack SPN layers
    pytorch_spn = nn.Sequential(gauss, prod1, sum1, prod2).to(torch.device(device))
    return pytorch_spn
Ejemplo n.º 3
0
    def test_product_layer(self):
        """Test the product layer forward pass against an explicit reference product."""

        # Setup product layer
        in_features = 9
        cardinality = 3
        num_repetitions = 5
        prod_layer = layers.Product(in_features=in_features, cardinality=cardinality, num_repetitions=num_repetitions)

        # Setup test input
        batch_size = 16
        in_channels = 3
        x = torch.rand(size=(batch_size, in_features, in_channels, num_repetitions))

        # Expected result: product over each consecutive group of `cardinality` features.
        # in_features is a multiple of cardinality here, so no padding is involved.
        out_features = in_features // cardinality
        expected_result = torch.ones(batch_size, out_features, in_channels, num_repetitions)
        for n in range(batch_size):
            for d in range(0, in_features, cardinality):
                for c in range(in_channels):
                    for r in range(num_repetitions):
                        for i in range(cardinality):
                            expected_result[n, d // cardinality, c, r] *= x[n, d + i, c, r]

        # Actual result: the layer is fed log-inputs and its output is exponentiated,
        # so it is exercised in log-space.
        result = prod_layer(x.log()).exp()

        # Run assertions (assertEqual/assertLess report the offending values on failure)
        self.assertEqual(tuple(result.shape), (batch_size, out_features, in_channels, num_repetitions))
        max_abs_err = (result - expected_result).abs().max().item()
        self.assertLess(max_abs_err, 1e-6)
Ejemplo n.º 4
0
 def test_product_shape_as_root_node(self):
     """Check that the product node has the correct sampling shape when used as root."""
     num_samples = 5
     prod_layer = layers.Product(in_features=10,
                                 cardinality=2,
                                 num_repetitions=1)
     ctx = SamplingContext(n=num_samples)
     ctx = prod_layer.sample(context=ctx)
     # As root, the first two dims of the parent indices should be (num_samples, 1).
     self.assertEqual(ctx.parent_indices.shape[0], num_samples)
     self.assertEqual(ctx.parent_indices.shape[1], 1)
Ejemplo n.º 5
0
    def test_prod_as_intermediate_node(self):
        """Check that a product layer expands each parent index over its `cardinality` children.

        For every cardinality in ``[2, in_features)`` the expected indices are the parent
        indices with each entry repeated ``cardinality`` times along the feature axis and
        the trailing ``pad`` entries removed.
        """
        # Product layer values
        in_features = 10
        num_samples = 5
        num_repetitions = 5
        for cardinality in range(2, in_features):
            # subTest reports which cardinality failed instead of stopping at the first one.
            with self.subTest(cardinality=cardinality):
                prod_layer = layers.Product(in_features=in_features,
                                            cardinality=cardinality,
                                            num_repetitions=num_repetitions)

                # Example parent indexes
                parent_indices = torch.randint(high=5,
                                               size=(num_samples, in_features))

                # Expected indexes: each index is repeated #cardinality times, then the
                # trailing padding entries are removed.
                pad = (cardinality - in_features % cardinality) % cardinality
                expected_sample_indices = parent_indices.repeat_interleave(cardinality, dim=1)
                if pad > 0:
                    expected_sample_indices = expected_sample_indices[:, :-pad]

                # Sample
                ctx = SamplingContext(n=num_samples, parent_indices=parent_indices)
                ctx = prod_layer.sample(context=ctx)
                # Check the shape explicitly so broadcasting in `==` cannot hide a mismatch.
                self.assertEqual(ctx.parent_indices.shape, expected_sample_indices.shape)
                self.assertTrue(
                    (expected_sample_indices == ctx.parent_indices).all())
Ejemplo n.º 6
0
    def test_spn_mpe(self):
        """MPE sampling through the SPN must give the same leaf values on repeated draws."""

        # Leaf layer, then three (Sum, Product) levels over 8 -> 4 -> 2 features,
        # constructed in bottom-up order.
        leaf = distributions.Normal(in_features=2**3,
                                    out_channels=5,
                                    num_repetitions=1)
        hidden = []
        in_ch = 5
        for level in (3, 2, 1):
            hidden.append(layers.Sum(in_channels=in_ch,
                                     in_features=2**level,
                                     out_channels=20,
                                     num_repetitions=1))
            hidden.append(layers.Product(in_features=2**level,
                                         cardinality=2,
                                         num_repetitions=1))
            in_ch = 20
        root = layers.Sum(in_channels=20,
                          in_features=2**0,
                          out_channels=1,
                          num_repetitions=1)

        # Every Sum layer (even positions of `hidden`, then the root) keeps its inputs;
        # presumably the MPE pass relies on these cached inputs -- confirm in layers.Sum.
        for sum_layer in hidden[0::2] + [root]:
            sum_layer._enable_input_cache()

        # Forward pass, bottom-up (the root's output itself is not inspected).
        h = torch.randn(1, 2**3)
        h = leaf(h)
        for layer in hidden:
            h = layer(h)
        root(h)

        # Top-down MPE pass; n is the leading dimension of the root's input.
        ctx = SamplingContext(n=h.shape[0], is_mpe=True)
        root.sample(context=ctx)
        for layer in reversed(hidden):
            layer.sample(context=ctx)

        # Three MPE draws from the leaves; consecutive draws should be the same
        draws = [leaf.sample(context=ctx) for _ in range(3)]
        for earlier, later in zip(draws, draws[1:]):
            self.assertTrue(((earlier - later).abs() < 1e-6).all())