def test_spn_sampling(self):
    """Build a small four-level SPN, run one forward pass, then sample top-down.

    The test makes no assertions. It passes as long as every forward call and
    every ``sample`` call runs without raising.
    """
    # Leaf layer over 2**3 = 8 features, followed by alternating Sum/Product
    # layers. The in_features values passed below shrink 8 -> 4 -> 2 -> 1
    # after each Product(cardinality=2).
    leaf = distributions.Normal(in_features=2 ** 3, out_channels=5, num_repetitions=1)
    sum_1 = layers.Sum(in_channels=5, in_features=2 ** 3, out_channels=20, num_repetitions=1)
    prd_1 = layers.Product(in_features=2 ** 3, cardinality=2, num_repetitions=1)
    sum_2 = layers.Sum(in_channels=20, in_features=2 ** 2, out_channels=20, num_repetitions=1)
    prd_2 = layers.Product(in_features=2 ** 2, cardinality=2, num_repetitions=1)
    sum_3 = layers.Sum(in_channels=20, in_features=2 ** 1, out_channels=20, num_repetitions=1)
    prd_3 = layers.Product(in_features=2 ** 1, cardinality=2, num_repetitions=1)
    sum_4 = layers.Sum(in_channels=20, in_features=2 ** 0, out_channels=1, num_repetitions=1)

    # Layers between the leaf and the root, in bottom-up order.
    inner_layers = [sum_1, prd_1, sum_2, prd_2, sum_3, prd_3]

    # Forward pass: leaf -> inner layers -> root, on a single random input row.
    x_test = leaf(torch.randn(1, 2 ** 3))
    for layer in inner_layers:
        x_test = layer(x_test)
    sum_4(x_test)

    # Sampling pass: the root creates the context for 1000 samples. Each
    # lower layer is then called top-down with that same context object.
    # Their return values are discarded, so the context is presumably updated
    # in place.
    ctx = sum_4.sample(n=1000)
    for layer in reversed(inner_layers):
        layer.sample(context=ctx)
    leaf.sample(context=ctx)
def test_sum_as_intermediate_node(self):
    """Check that a sum layer used as an intermediate node yields the right sample indices.

    For every (feature, repetition) pair, exactly one input channel gets
    weight 1.0 (log-weight 0). All other channels get weight 0 (log-weight
    -inf). Sampling is therefore expected to select that channel.
    """
    in_features = 10
    in_channels = 3
    out_channels = 5
    num_repetitions = 7
    n = 2

    # Indices handed down from a (fictitious) parent layer, one per sample and feature.
    parent_indices = torch.randint(out_channels, size=(n, in_features))

    sum_layer = layers.Sum(
        in_features=in_features,
        in_channels=in_channels,
        out_channels=out_channels,
        num_repetitions=num_repetitions,
    )

    # Column r holds, for each feature, the single input channel that receives
    # all probability mass in repetition r.
    rand_indxs = torch.randint(in_channels, size=(in_features, num_repetitions))
    # Repetition used by each of the n samples.
    rep_idxs = torch.randint(num_repetitions, size=(n,))

    weights = torch.zeros(in_features, in_channels, out_channels, num_repetitions)
    feature_ids = torch.arange(in_features)
    # Rows of the transpose are the columns of rand_indxs, one per repetition.
    for rep, chosen_channels in enumerate(rand_indxs.T):
        weights[feature_ids, chosen_channels, :, rep] = 1.0
    sum_layer.weights = nn.Parameter(torch.log(weights))

    ctx = SamplingContext(n=n, parent_indices=parent_indices, repetition_indices=rep_idxs)
    sum_layer.sample(context=ctx)

    # Sample i used repetition rep_idxs[i], so its resulting indices must equal
    # that repetition's column of rand_indxs.
    for i, rep in enumerate(rep_idxs.tolist()):
        self.assertTrue((rand_indxs[:, rep] == ctx.parent_indices[i, :]).all())
def create_pytorch_spn(n_feats, device="cuda:0"):
    """Build a small SPN (Normal -> Product -> Sum -> Product) as an nn.Sequential.

    Args:
        n_feats: Number of input features. The first product layer has
            cardinality 2 and the later layers are sized for ``n_feats // 2``
            features, so this presumably should be even -- TODO confirm
            against the layer implementations.
        device: Device (``str`` or ``torch.device``) to move the model to.
            Defaults to ``"cuda:0"``, the value that used to be hard-coded.

    Returns:
        ``nn.Sequential(gauss, prod1, sum1, prod2)`` moved to ``device``.
    """
    # Integer division: the original `n_feats / 2` handed a float feature
    # count to the Sum/Product layers.
    half_feats = n_feats // 2

    gauss = Normal(multiplicity=2, in_features=n_feats, in_channels=1)
    prod1 = layers.Product(in_features=n_feats, cardinality=2)
    sum1 = layers.Sum(in_features=half_feats, in_channels=2, out_channels=1)
    # Cardinality equal to the remaining feature count merges all features into one.
    prod2 = layers.Product(in_features=half_feats, cardinality=half_feats)

    pytorch_spn = nn.Sequential(gauss, prod1, sum1, prod2).to(torch.device(device))
    return pytorch_spn
def test_sum_shape_as_root_node(self):
    """A root sum layer (single output channel) must sample indices shaped (n, in_features, ...).

    Only the first two dimensions of ``ctx.parent_indices`` are checked. This
    is repeated for every combination of input-channel and feature counts.
    """
    n = 5
    num_repetitions = 1
    sizes = [1, 5, 10]

    for in_channels in sizes:
        for in_features in sizes:
            root = layers.Sum(
                in_channels=in_channels,
                out_channels=1,
                in_features=in_features,
                num_repetitions=num_repetitions,
            )
            ctx = root.sample(n=n)
            shape = ctx.parent_indices.shape
            self.assertEqual(shape[0], n)
            self.assertEqual(shape[1], in_features)
def test_sum_layer(self):
    """Test the forward pass of a sum layer against a dense reference computation.

    For every sample n, feature d, output channel oc and repetition r, the
    expected output is ``sum_ic x[n, d, ic, r] * softmax(w, dim=1)[d, ic, oc, r]``.
    The layer works in log space. Inputs are log-transformed before the call
    and the output is exponentiated afterwards, so it can be compared directly.
    """
    # Setup layer
    in_channels = 8
    out_channels = 7
    in_features = 3
    num_repetitions = 5
    sum_layer = layers.Sum(
        in_channels=in_channels,
        out_channels=out_channels,
        in_features=in_features,
        num_repetitions=num_repetitions,
    )
    w = torch.rand(in_features, in_channels, out_channels, num_repetitions)

    # Set the sum layer parameters
    sum_layer.weights = nn.Parameter(w)

    # Apply softmax once again since Sum forward pass uses F.log_softmax internally to project random weights
    # back into valid ranges
    w = F.softmax(w, dim=1)

    # Setup test input
    batch_size = 16
    x = torch.rand(size=(batch_size, in_features, in_channels, num_repetitions))

    # Reference result: contract the input-channel axis (c) of x with that of w.
    # This is the same value the previous per-element loop computed as
    # x[n, d, :, r] @ w[d, :, oc, r], without 1680 Python-level dot products.
    expected_result = torch.einsum("ndcr,dcor->ndor", x, w)

    # Do forward pass: apply log as sum layer operates in log space. Exp() afterwards to make it comparable to the
    # expected result
    result = sum_layer(x.log()).exp()

    # Run assertions
    self.assertEqual(
        tuple(result.shape), (batch_size, in_features, out_channels, num_repetitions)
    )
    # rtol=0 keeps the purely absolute 1e-6 tolerance of the original check.
    self.assertTrue(
        torch.allclose(result, expected_result, rtol=0.0, atol=1e-6),
        msg=f"max abs diff: {(result - expected_result).abs().max().item()}",
    )
def test_spn_mpe(self):
    """Check that repeated MPE sampling from the leaf layer returns identical values.

    The input cache is enabled on every sum layer and a forward pass is run.
    An MPE sampling context is then propagated from the root down to the leaf.
    Three subsequent leaf samples must agree element-wise within 1e-6.
    """
    # Leaf over 8 features, then alternating Sum/Product layers down to one feature.
    leaf = distributions.Normal(in_features=2**3, out_channels=5, num_repetitions=1)
    sum_1 = layers.Sum(in_channels=5, in_features=2**3, out_channels=20, num_repetitions=1)
    prd_1 = layers.Product(in_features=2**3, cardinality=2, num_repetitions=1)
    sum_2 = layers.Sum(in_channels=20, in_features=2**2, out_channels=20, num_repetitions=1)
    prd_2 = layers.Product(in_features=2**2, cardinality=2, num_repetitions=1)
    sum_3 = layers.Sum(in_channels=20, in_features=2**1, out_channels=20, num_repetitions=1)
    prd_3 = layers.Product(in_features=2**1, cardinality=2, num_repetitions=1)
    sum_4 = layers.Sum(in_channels=20, in_features=2**0, out_channels=1, num_repetitions=1)

    # Layers between the leaf and the root, in bottom-up order.
    inner_layers = [sum_1, prd_1, sum_2, prd_2, sum_3, prd_3]

    for sum_layer in (sum_1, sum_2, sum_3, sum_4):
        sum_layer._enable_input_cache()

    # Forward pass on one random input row: leaf -> inner layers -> root.
    x_test = leaf(torch.randn(1, 2**3))
    for layer in inner_layers:
        x_test = layer(x_test)
    sum_4(x_test)

    # MPE context sized to the batch of the last inner-layer output; propagate
    # it top-down from the root through every inner layer.
    ctx = SamplingContext(n=x_test.shape[0], is_mpe=True)
    for layer in [sum_4, *reversed(inner_layers)]:
        layer.sample(context=ctx)

    # Three leaf draws with the same MPE context should be the same.
    mpe_1, mpe_2, mpe_3 = (leaf.sample(context=ctx) for _ in range(3))
    self.assertTrue(((mpe_1 - mpe_2).abs() < 1e-6).all())
    self.assertTrue(((mpe_2 - mpe_3).abs() < 1e-6).all())