def test_spn_sampling(self):
    """Smoke-test a full forward and sampling pass through a small hand-built SPN.

    Builds an 8-feature SPN (leaf -> alternating sum/product layers -> root
    sum), pushes one random input through it, then draws 1000 samples by
    propagating a sampling context from the root back down to the leaf.
    """
    # Layers are constructed in the same bottom-up order as the forward pass
    # so that random parameter initialization consumes the RNG identically.
    leaf = distributions.Normal(in_features=2 ** 3, out_channels=5, num_repetitions=1)
    sum_1 = layers.Sum(in_channels=5, in_features=2 ** 3, out_channels=20, num_repetitions=1)
    prd_1 = layers.Product(in_features=2 ** 3, cardinality=2, num_repetitions=1)
    sum_2 = layers.Sum(in_channels=20, in_features=2 ** 2, out_channels=20, num_repetitions=1)
    prd_2 = layers.Product(in_features=2 ** 2, cardinality=2, num_repetitions=1)
    sum_3 = layers.Sum(in_channels=20, in_features=2 ** 1, out_channels=20, num_repetitions=1)
    prd_3 = layers.Product(in_features=2 ** 1, cardinality=2, num_repetitions=1)
    sum_4 = layers.Sum(in_channels=20, in_features=2 ** 0, out_channels=1, num_repetitions=1)

    # Forward pass: one random input through every inner layer, root last.
    inner_layers = [sum_1, prd_1, sum_2, prd_2, sum_3, prd_3]
    x_test = leaf(torch.randn(1, 2 ** 3))
    for layer in inner_layers:
        x_test = layer(x_test)
    res = sum_4(x_test)

    # Sampling pass: the root creates the context, each layer below refines it
    # in reverse (top-down) order, and the leaf finally materializes samples.
    ctx = sum_4.sample(n=1000)
    for layer in reversed(inner_layers):
        layer.sample(context=ctx)
    samples = leaf.sample(context=ctx)
def create_pytorch_spn(n_feats):
    """Build a small sequential SPN over ``n_feats`` input features.

    Args:
        n_feats (int): Number of input features. Presumably even, since the
            first product layer pairs features with cardinality 2 — TODO confirm.

    Returns:
        nn.Sequential: Gaussian leaf -> product -> sum -> product stack, moved
        to GPU when available, otherwise CPU.
    """
    gauss = Normal(multiplicity=2, in_features=n_feats, in_channels=1)
    prod1 = layers.Product(in_features=n_feats, cardinality=2)
    # Use integer division: "/" yields a float, but in_features must be an
    # integer feature count (all other call sites pass ints).
    sum1 = layers.Sum(in_features=n_feats // 2, in_channels=2, out_channels=1)
    prod2 = layers.Product(in_features=n_feats // 2, cardinality=n_feats // 2)

    # Fall back to CPU when no CUDA device is present so the function does not
    # crash on GPU-less machines.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    pytorch_spn = nn.Sequential(gauss, prod1, sum1, prod2).to(device)
    return pytorch_spn
def test_product_layer(self):
    """Test the product layer forward pass against a brute-force reference."""
    # Layer under test.
    in_features, cardinality, num_repetitions = 9, 3, 5
    prod_layer = layers.Product(in_features=in_features, cardinality=cardinality, num_repetitions=num_repetitions)

    # Random positive input (logged before the forward pass below).
    batch_size, in_channels = 16, 3
    x = torch.rand(size=(batch_size, in_features, in_channels, num_repetitions))

    # Brute-force reference: multiply each consecutive group of `cardinality`
    # features into one output feature. Feature index `d` ascends, so the
    # per-group multiplication order matches the layer's grouping exactly.
    out_features = in_features // cardinality
    expected_result = torch.ones(batch_size, out_features, in_channels, num_repetitions)
    for n in range(batch_size):
        for c in range(in_channels):
            for r in range(num_repetitions):
                for d in range(in_features):
                    expected_result[n, d // cardinality, c, r] *= x[n, d, c, r]

    # Actual result: the layer works in log-space, so log in / exp out.
    result = prod_layer(x.log()).exp()

    # Shape and value assertions.
    self.assertTrue(tuple(result.shape) == (batch_size, out_features, in_channels, num_repetitions))
    self.assertTrue(((result - expected_result).abs() < 1e-6).all())
def test_product_shape_as_root_node(self):
    """Check that the product node has the correct sampling shape when used as root."""
    num_samples = 5
    root = layers.Product(in_features=10, cardinality=2, num_repetitions=1)
    # Sampling from the root populates parent_indices with one column.
    ctx = root.sample(context=SamplingContext(n=num_samples))
    self.assertTrue(tuple(ctx.parent_indices.shape[:2]) == (num_samples, 1))
def test_prod_as_intermediate_node(self):
    """Product sampling as an intermediate node: each parent index must be
    repeated `cardinality` times, with trailing padding removed."""
    in_features = 10
    num_samples = 5
    num_repetitions = 5

    for cardinality in range(2, in_features):
        prod_layer = layers.Product(in_features=in_features, cardinality=cardinality, num_repetitions=num_repetitions)

        # Example parent indexes coming from a (hypothetical) layer above.
        parent_indices = torch.randint(high=5, size=(num_samples, in_features))

        # Reference expansion: repeat every index `cardinality` times, then
        # drop the padding that a non-divisible feature count introduces.
        pad = (cardinality - in_features % cardinality) % cardinality
        expanded = [
            [idx for idx in row for _ in range(cardinality)]
            for row in parent_indices
        ]
        if pad > 0:
            expanded = [row[:-pad] for row in expanded]
        expected_sample_indices = torch.tensor(expanded)

        # Sample through the product layer and compare in-place-updated context.
        ctx = SamplingContext(n=num_samples, parent_indices=parent_indices)
        prod_layer.sample(context=ctx)
        self.assertTrue((expected_sample_indices == ctx.parent_indices).all())
def test_spn_mpe(self):
    """MPE sampling must be deterministic: repeated leaf MPE draws from the
    same context must yield identical samples."""
    # Layers constructed in the same bottom-up order as the original so that
    # random parameter initialization consumes the RNG identically.
    leaf = distributions.Normal(in_features=2**3, out_channels=5, num_repetitions=1)
    sum_1 = layers.Sum(in_channels=5, in_features=2**3, out_channels=20, num_repetitions=1)
    prd_1 = layers.Product(in_features=2**3, cardinality=2, num_repetitions=1)
    sum_2 = layers.Sum(in_channels=20, in_features=2**2, out_channels=20, num_repetitions=1)
    prd_2 = layers.Product(in_features=2**2, cardinality=2, num_repetitions=1)
    sum_3 = layers.Sum(in_channels=20, in_features=2**1, out_channels=20, num_repetitions=1)
    prd_3 = layers.Product(in_features=2**1, cardinality=2, num_repetitions=1)
    sum_4 = layers.Sum(in_channels=20, in_features=2**0, out_channels=1, num_repetitions=1)

    # MPE needs the sum layers to remember their inputs from the forward pass.
    for sum_layer in (sum_1, sum_2, sum_3, sum_4):
        sum_layer._enable_input_cache()

    # Forward pass through every inner layer, root last.
    inner_layers = [sum_1, prd_1, sum_2, prd_2, sum_3, prd_3]
    x_test = leaf(torch.randn(1, 2**3))
    for layer in inner_layers:
        x_test = layer(x_test)
    res = sum_4(x_test)

    # MPE sampling pass, top-down through the same layers in reverse.
    ctx = SamplingContext(n=x_test.shape[0], is_mpe=True)
    sum_4.sample(context=ctx)
    for layer in reversed(inner_layers):
        layer.sample(context=ctx)

    # Three MPE draws from the identical context should coincide.
    mpe_1 = leaf.sample(context=ctx)
    mpe_2 = leaf.sample(context=ctx)
    mpe_3 = leaf.sample(context=ctx)
    self.assertTrue(((mpe_1 - mpe_2).abs() < 1e-6).all())
    self.assertTrue(((mpe_2 - mpe_3).abs() < 1e-6).all())