Example #1
    def test_sum_as_intermediate_node(self):
        """Check that sum node returns the correct sample indices when used as indermediate node."""
        # Some values for the sum layer
        in_features = 10
        in_channels = 3
        out_channels = 5
        num_repetitions = 7
        n = 2
        parent_indices = torch.randint(out_channels, size=(n, in_features))

        # Create sum layer
        sum_layer = layers.Sum(
            in_features=in_features, in_channels=in_channels, out_channels=out_channels, num_repetitions=num_repetitions
        )

        # For each feature and repetition, pick a random in-channel index in [0, in_channels)
        # that will receive probability 1.0 in the sum layer weight tensor
        rand_indxs = torch.randint(in_channels, size=(in_features, num_repetitions))
        rep_idxs = torch.randint(num_repetitions, size=(n,))

        # Artificially set the chosen sum weights to probability 1.0 (log(0) = -inf masks all other entries)
        weights = torch.zeros(in_features, in_channels, out_channels, num_repetitions)
        for r in range(num_repetitions):
            weights[range(in_features), rand_indxs[:, r], :, r] = 1.0
        sum_layer.weights = nn.Parameter(torch.log(weights))

        # Perform sampling
        ctx = SamplingContext(n=n, parent_indices=parent_indices, repetition_indices=rep_idxs)
        sum_layer.sample(context=ctx)

        # Assert that the sampled indices are those where the weights were set to 1.0
        for i in range(n):
            self.assertTrue((rand_indxs[:, rep_idxs[i]] == ctx.parent_indices[i, :]).all())
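
A standalone sketch of why the weight trick above works: the log of a one-hot weight vector is -inf everywhere except the chosen entry, so a categorical distribution over these logits is deterministic.

import torch

# log of a one-hot vector is -inf everywhere except the chosen entry,
# so sampling from these logits always returns that entry.
weights = torch.tensor([0.0, 1.0, 0.0])
dist = torch.distributions.Categorical(logits=torch.log(weights))
assert (dist.sample((100,)) == 1).all()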
Example #2
    def test_normal_leaf(self):
        # Setup leaf layer
        out_channels = 10
        in_features = 10
        num_repetitions = 5
        leaf = distributions.Normal(out_channels=out_channels, in_features=in_features, num_repetitions=num_repetitions)

        # Set the leaf layer means to random integers
        leaf.means.data = torch.randint(
            low=-100, high=100, size=(1, in_features, out_channels, num_repetitions)
        ).float()
        # Set the leaf layer stds to 0 so that every sample equals its mean (making the assertions below deterministic)
        leaf.stds.data = torch.zeros(size=(1, in_features, out_channels, num_repetitions)).float()

        # Create some random indices into the out_channels axis
        parent_indices = torch.randint(high=out_channels, size=(1, in_features,))
        repetition_indices = torch.randint(high=num_repetitions, size=(1,))

        # Perform sampling
        ctx = SamplingContext(n=1, parent_indices=parent_indices, repetition_indices=repetition_indices)
        result = leaf.sample(context=ctx)

        # Expected sampling
        expected_result = leaf.means.data[:, range(in_features), parent_indices, repetition_indices[0]]

        # Run assertions
        self.assertTrue(((result - expected_result).abs() < 1e-6).all())
Example #3
 def test_product_shape_as_root_node(self):
     """Check that the product node has the correct sampling shape when used as root."""
     prod_layer = layers.Product(in_features=10,
                                 cardinality=2,
                                 num_repetitions=1)
     ctx = SamplingContext(n=5)
     ctx = prod_layer.sample(context=ctx)
     self.assertTrue(ctx.parent_indices.shape[0] == 5)
     self.assertTrue(ctx.parent_indices.shape[1] == 1)
Example #4
    def sample(self,
               n: int = None,
               context: SamplingContext = None) -> SamplingContext:
        """Method to sample from this layer, based on the parents output.

        Args:
            n (int): Number of instances to sample.
            indices (torch.Tensor): Parent sampling output.
        Returns:
            torch.Tensor: Index into tensor which paths should be followed.
                          Output should be of size: in_features, out_channels.
        """

        # If this is a root node
        if context is None:
            if self.num_repetitions == 1:
                # If there is only a single repetition, create new sampling context
                return SamplingContext(
                    n=n,
                    parent_indices=torch.zeros(n,
                                               1,
                                               dtype=int,
                                               device=self.__device),
                    repetition_indices=torch.zeros(n,
                                                   dtype=int,
                                                   device=self.__device),
                )
            else:
                raise Exception(
                    "Cannot start sampling from Product layer with num_repetitions > 1 and no context given."
                )
        else:
            # Repeat the parent indices, e.g. [0, 2, 3] -> [0, 0, 2, 2, 3, 3] depending on the cardinality
            indices = torch.repeat_interleave(context.parent_indices,
                                              repeats=self.cardinality,
                                              dim=1)

            # Remove padding
            if self._pad:
                indices = indices[:, :-self._pad]

            context.parent_indices = indices
            return context
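
A tiny standalone check of the index expansion described in the comment above, for cardinality 2:

import torch

# Each parent index is repeated `cardinality` times along the feature axis.
parent = torch.tensor([[0, 2, 3]])
expanded = torch.repeat_interleave(parent, repeats=2, dim=1)
assert (expanded == torch.tensor([[0, 0, 2, 2, 3, 3]])).all()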
Example #5
    def sample(self,
               n: int = None,
               context: SamplingContext = None) -> SamplingContext:
        """Method to sample from this layer, based on the parents output.

        Args:
            n: Number of samples.
            indices (torch.Tensor): Parent sampling output
        Returns:
            torch.Tensor: Index into tensor which paths should be followed.
                          Output should be of size: in_features, out_channels.
        """

        # If this is a root node
        if context is None:
            if self.num_repetitions == 1:
                # If there is only a single repetition, create new sampling context
                return SamplingContext(
                    n=n,
                    parent_indices=torch.zeros(n,
                                               1,
                                               dtype=int,
                                               device=self.__device),
                    repetition_indices=torch.zeros(n,
                                                   dtype=int,
                                                   device=self.__device),
                )
            else:
                raise Exception(
                    "Cannot start sampling from CrossProduct layer with num_repetitions > 1 and no context given."
                )

        # Map flattened indices back to coordinates to obtain the chosen input_channel for each feature
        indices = self.unraveled_channel_indices[context.parent_indices]
        indices = indices.view(indices.shape[0], -1)

        # Remove padding
        if self._pad:
            indices = indices[:, :-self._pad]

        context.parent_indices = indices
        return context
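
A hypothetical sketch of what `unraveled_channel_indices` could encode, assuming a cross product over pairs of children with IC input channels each, so that flattened index k in [0, IC^2) maps to the channel pair (k // IC, k % IC):

import torch

IC = 3  # assumed number of input channels per child
unraveled = torch.tensor([(k // IC, k % IC) for k in range(IC * IC)])
parent_indices = torch.tensor([[4, 7]])       # one flattened choice per feature pair
indices = unraveled[parent_indices]           # [N, D/2, 2]: a channel pair per feature pair
indices = indices.view(indices.shape[0], -1)  # [N, D]: one channel choice per feature
assert (indices == torch.tensor([[1, 1, 2, 1]])).all()  # 4 -> (1, 1), 7 -> (2, 1)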
Example #6
 def _check_repetition_indices(self, context: SamplingContext):
     # Handle a missing repetition_indices before it is accessed below
     if context.repetition_indices is None:
         if self.num_repetitions > 1:
             raise Exception(
                 f"Sum layer has multiple repetitions (num_repetitions=={self.num_repetitions}) "
                 f"but repetition_indices was None; expected a Long tensor of size #samples."
             )
         # Single repetition: default all repetition indices to zero
         context.repetition_indices = torch.zeros(context.n,
                                                  dtype=int,
                                                  device=self.__device)
     assert context.repetition_indices.shape[0] == context.parent_indices.shape[0]
Example #7
    def test_spn_sampling(self):

        # Define SPN
        leaf = distributions.Normal(in_features=2**3,
                                    out_channels=5,
                                    num_repetitions=1)
        sum_1 = layers.Sum(in_channels=5,
                           in_features=2**3,
                           out_channels=20,
                           num_repetitions=1)
        prd_1 = layers.Product(in_features=2**3,
                               cardinality=2,
                               num_repetitions=1)
        sum_2 = layers.Sum(in_channels=20,
                           in_features=2**2,
                           out_channels=20,
                           num_repetitions=1)
        prd_2 = layers.Product(in_features=2**2,
                               cardinality=2,
                               num_repetitions=1)
        sum_3 = layers.Sum(in_channels=20,
                           in_features=2**1,
                           out_channels=20,
                           num_repetitions=1)
        prd_3 = layers.Product(in_features=2**1,
                               cardinality=2,
                               num_repetitions=1)
        sum_4 = layers.Sum(in_channels=20,
                           in_features=2**0,
                           out_channels=1,
                           num_repetitions=1)

        # Test forward pass
        x_test = torch.randn(1, 2**3)

        x_test = leaf(x_test)
        x_test = sum_1(x_test)
        x_test = prd_1(x_test)
        x_test = sum_2(x_test)
        x_test = prd_2(x_test)
        x_test = sum_3(x_test)
        x_test = prd_3(x_test)
        res = sum_4(x_test)

        # Sampling pass
        ctx = SamplingContext(n=1000)
        sum_4.sample(context=ctx)
        prd_3.sample(context=ctx)
        sum_3.sample(context=ctx)
        prd_2.sample(context=ctx)
        sum_2.sample(context=ctx)
        prd_1.sample(context=ctx)
        sum_1.sample(context=ctx)
        samples = leaf.sample(context=ctx)
Example #8
    def sample(self,
               n: int = None,
               context: SamplingContext = None) -> torch.Tensor:
        """Sample the base leaf after routing the sampling context through the internal product layer."""
        context = self.prod.sample(context=context)

        # Remove padding
        if self._pad:
            context.parent_indices = context.parent_indices[:, :-self._pad]

        samples = self.base_leaf.sample(context=context)
        return samples
Example #9
 def test_sum_shape_as_root_node(self):
     """Check that the sum node has the correct sampling shape when used as root."""
     n = 5
     num_repetitions = 1
     for in_channels in [1, 5, 10]:
         for in_features in [1, 5, 10]:
             sum_layer = layers.Sum(in_channels=in_channels,
                                    out_channels=1,
                                    in_features=in_features,
                                    num_repetitions=num_repetitions)
             ctx = SamplingContext(n=n)
             ctx = sum_layer.sample(context=ctx)
             self.assertTrue(ctx.parent_indices.shape[0] == n)
             self.assertTrue(ctx.parent_indices.shape[1] == in_features)
Example #10
    def test_prod_as_intermediate_node(self):
        # Product layer values
        in_features = 10
        num_samples = 5
        num_repetitions = 5
        for cardinality in range(2, in_features):
            prod_layer = layers.Product(in_features=in_features,
                                        cardinality=cardinality,
                                        num_repetitions=num_repetitions)

            # Example parent indices
            parent_indices = torch.randint(high=5,
                                           size=(num_samples, in_features))

            # Create expected indices: each index is repeated #cardinality times
            pad = (cardinality - in_features % cardinality) % cardinality
            expected_sample_indices = []
            for j in range(num_samples):

                sample_i_indices = []
                for i in parent_indices[j, :]:
                    sample_i_indices += [i] * cardinality

                # Remove padding
                if pad > 0:
                    sample_i_indices = sample_i_indices[:-pad]

                # Add current sample
                expected_sample_indices.append(sample_i_indices)

            # As tensor
            expected_sample_indices = torch.tensor(expected_sample_indices)

            # Sample
            ctx = SamplingContext(n=num_samples, parent_indices=parent_indices)
            prod_layer.sample(context=ctx)
            self.assertTrue(
                (expected_sample_indices == ctx.parent_indices).all())
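
A worked instance of the padding formula used above: with in_features = 10 and cardinality = 3, the features are padded up to the next multiple of the cardinality, and the same amount is stripped from the sampled indices again.

in_features, cardinality = 10, 3
pad = (cardinality - in_features % cardinality) % cardinality
assert pad == 2  # 10 features are padded to 12, i.e. 4 blocks of size 3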
Example #11
    def sample(self,
               n: int = None,
               class_index=None,
               evidence: torch.Tensor = None):
        """
        Sample from the distribution represented by this SPN.

        Possible valid inputs:

        - `n`: Generates `n` samples.
        - `n` and `class_index (int)`: Generates `n` samples from P(X | C = class_index).
        - `class_index (List[int])`: Generates `len(class_index)` samples. Each index `c_i` in `class_index` is mapped
            to a sample from P(X | C = c_i)
        - `evidence`: If evidence is given, samples conditionally and fills in the NaN values.

        Args:
            n: Number of samples to generate.
            class_index: Class index. Can be either an int in combination with a value for `n` which will result in `n`
                samples from P(X | C = class_index). Or can be a list of ints which will map each index `c_i` in the
                list to a sample from P(X | C = c_i).
            evidence: Evidence that can be provided to condition the samples. If evidence is given, `n` and
                `class_index` must be `None`. Evidence must contain NaN values which will be imputed according to the
                distribution represented by the SPN. The result will contain the evidence and replace all NaNs with the
                sampled values.

        Returns:
            torch.Tensor: Samples generated according to the distribution specified by the SPN.

        """
        assert class_index is None or evidence is None, "Cannot provide both evidence and class indices."
        assert n is None or evidence is None, "Cannot provide both the number of samples (n) and evidence."

        # Check if evidence contains nans
        if evidence is not None:
            assert torch.isnan(evidence).any(), "Evidence has no NaN values."

            # Set n to the number of samples in the evidence
            n = evidence.shape[0]

        with provide_evidence(self, evidence):  # May be None but that's ok
            # If class is given, use it as base index
            if class_index is not None:
                if isinstance(class_index, list):
                    indices = torch.tensor(class_index,
                                           device=self.__device).view(-1, 1)
                    n = indices.shape[0]
                else:
                    indices = torch.empty(size=(n, 1),
                                          dtype=torch.long,
                                          device=self.__device)
                    indices.fill_(class_index)

                # Create new sampling context
                ctx = SamplingContext(n=n,
                                      parent_indices=indices,
                                      repetition_indices=None)
            else:
                # Start sampling one of the C root nodes TODO: check what happens if C=1
                ctx = self._sampling_root.sample(n=n)

            # Sample from RatSpn root layer: Results are indices into the stacked output channels of all repetitions
            ctx.repetition_indices = torch.zeros(n,
                                                 dtype=int,
                                                 device=self.__device)
            ctx = self.root.sample(context=ctx)

            # Indices now point into the stacked channels of all repetitions
            # (R * S^2 if D > 1, else R * I^2).
            root_in_channels = self.root.in_channels // self.config.R
            # Obtain repetition indices
            ctx.repetition_indices = (ctx.parent_indices //
                                      root_in_channels).squeeze(1)
            # Shift indices
            ctx.parent_indices = ctx.parent_indices % root_in_channels

            # Each sample in `parent_indices` now belongs to one repetition, given by `repetition_indices`

            # Continue at layers
            # Sample inner modules
            for layer in reversed(self._inner_layers):
                if isinstance(layer, Sum):
                    ctx = layer.sample(context=ctx)
                elif isinstance(layer, CrossProduct):
                    ctx = layer.sample(context=ctx)
                else:
                    raise Exception(
                        "Only Sum or CrossProduct is allowed as intermediate layer."
                    )

            # Sample leaf
            samples = self._leaf.sample(context=ctx)

            # Invert permutation
            for i in range(n):
                rep_index = ctx.repetition_indices[i]
                inv_rand_indices = invert_permutation(
                    self.rand_indices[:, rep_index])
                samples[i, :] = samples[i, inv_rand_indices]

            if evidence is not None:
                # Update NaN entries in evidence with the sampled values
                nan_indices = torch.isnan(evidence)

                # First make a copy such that the original object is not changed
                evidence = evidence.clone()
                evidence[nan_indices] = samples[nan_indices]
                return evidence
            else:
                return samples
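
A minimal usage sketch of the calling conventions documented in the docstring above; `model` and `num_features` are assumed names, not part of the source.

import torch

# `model` is assumed to be a trained SPN exposing the sample() method above;
# `num_features` is its assumed input dimensionality.
num_features = 8
samples = model.sample(n=10)                       # 10 unconditional samples
class_samples = model.sample(n=10, class_index=2)  # 10 samples from P(X | C = 2)
mapped = model.sample(class_index=[0, 1, 2])       # one sample per listed class

# Conditional sampling: NaN entries are imputed, observed values are kept.
evidence = torch.randn(4, num_features)
evidence[:, ::2] = float("nan")
completed = model.sample(evidence=evidence)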
Example #12
    def sample(self,
               n: int = None,
               context: SamplingContext = None) -> SamplingContext:
        """Method to sample from this layer, based on the parents output.

        Output is always a vector of indices into the channels.

        Args:
            repetition_indices (List[int]): An index into the repetition axis for each sample.
                Can be None if `num_repetitions==1`.
            indices (torch.Tensor): Parent sampling output.
            n (int): Number of samples.
        Returns:
            torch.Tensor: Index into tensor which paths should be followed.
        """

        # Sum weights are of shape: [D, IC, OC, R]
        # We now use the parent indices to select one in-channel for each in_feature x out_channel
        # block; `parent_indices` holds one out-channel index per in_feature
        weights = self.weights.data
        d, ic, oc, r = weights.shape
        n = context.n

        # If this is a root layer, initialize the repetition indices and broadcast the root weights
        if context.is_root:
            assert oc == 1 and r == 1, "Cannot start sampling from non-root layer."

            # Initialize rep indices
            context.repetition_indices = torch.zeros(n,
                                                     dtype=int,
                                                     device=self.__device)

            # Select weights, repeat n times along the last dimension
            weights = weights[:, :, [0] * n, 0]  # Shape: [D, IC, N]

            # Move sample dimension to the first axis: [feat, channels, batch] -> [batch, feat, channels]
            weights = weights.permute(2, 0, 1)  # Shape: [N, D, IC]
        else:
            # If this is not the root node, use the paths (out channels), specified by the parent layer
            self._check_repetition_indices(context)

            tmp = torch.zeros(n, d, ic, device=self.__device)
            for i in range(n):
                tmp[i, :, :] = weights[range(self.in_features), :,
                                       context.parent_indices[i],
                                       context.repetition_indices[i]]
            weights = tmp

        # Check dimensions
        assert weights.shape == (n, d, ic)

        # Apply softmax to ensure they are proper probabilities
        log_weights = F.log_softmax(weights, dim=2)

        # If evidence is given, adjust the weights with the likelihoods of the observed paths
        if self._is_input_cache_enabled and self._input_cache is not None:
            for i in range(n):
                # Reweight the i-th samples weights by its likelihood values at the correct repetition
                log_weights[i, :, :] += self._input_cache[
                    i, :, :, context.repetition_indices[i]]

        # If the sampling context is MPE, take the argmax (most probable in-channel) instead of sampling
        if context.is_mpe:
            # Get index of largest weight along in-channel dimension
            indices = log_weights.argmax(dim=2)
        else:
            # Create categorical distribution and use weights as logits.
            #
            # Use the Gumbel-Softmax trick to obtain one-hot indices of the categorical distribution
            # represented by the given logits. (Use Gumbel-Softmax instead of Categorical
            # to allow for gradients).
            #
            # The code below is an approximation of:
            #
            # >> dist = torch.distributions.Categorical(logits=log_weights)
            # >> indices = dist.sample()

            cats = torch.arange(ic, device=log_weights.device)
            one_hot = F.gumbel_softmax(logits=log_weights, hard=True, dim=-1)
            indices = (one_hot * cats).sum(-1).long()

        # Update parent indices
        context.parent_indices = indices

        return context
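
A self-contained sketch of the Gumbel-Softmax index trick from the else branch above, using the same [N, D, IC] shapes:

import torch
import torch.nn.functional as F

# Draw hard one-hot samples from the categorical distribution defined by the
# logits, then collapse the one-hot vectors to integer in-channel indices.
log_weights = torch.log_softmax(torch.randn(4, 10, 3), dim=-1)  # [N, D, IC]
cats = torch.arange(log_weights.shape[-1])
one_hot = F.gumbel_softmax(logits=log_weights, hard=True, dim=-1)
indices = (one_hot * cats).sum(-1).long()  # [N, D]: one in-channel per feature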
Example #13
    def test_spn_mpe(self):

        # Define SPN
        leaf = distributions.Normal(in_features=2**3,
                                    out_channels=5,
                                    num_repetitions=1)
        sum_1 = layers.Sum(in_channels=5,
                           in_features=2**3,
                           out_channels=20,
                           num_repetitions=1)
        prd_1 = layers.Product(in_features=2**3,
                               cardinality=2,
                               num_repetitions=1)
        sum_2 = layers.Sum(in_channels=20,
                           in_features=2**2,
                           out_channels=20,
                           num_repetitions=1)
        prd_2 = layers.Product(in_features=2**2,
                               cardinality=2,
                               num_repetitions=1)
        sum_3 = layers.Sum(in_channels=20,
                           in_features=2**1,
                           out_channels=20,
                           num_repetitions=1)
        prd_3 = layers.Product(in_features=2**1,
                               cardinality=2,
                               num_repetitions=1)
        sum_4 = layers.Sum(in_channels=20,
                           in_features=2**0,
                           out_channels=1,
                           num_repetitions=1)

        sum_1._enable_input_cache()
        sum_2._enable_input_cache()
        sum_3._enable_input_cache()
        sum_4._enable_input_cache()

        # Test forward pass
        x_test = torch.randn(1, 2**3)

        x_test = leaf(x_test)
        x_test = sum_1(x_test)
        x_test = prd_1(x_test)
        x_test = sum_2(x_test)
        x_test = prd_2(x_test)
        x_test = sum_3(x_test)
        x_test = prd_3(x_test)
        res = sum_4(x_test)

        ctx = SamplingContext(n=x_test.shape[0], is_mpe=True)
        sum_4.sample(context=ctx)
        prd_3.sample(context=ctx)
        sum_3.sample(context=ctx)
        prd_2.sample(context=ctx)
        sum_2.sample(context=ctx)
        prd_1.sample(context=ctx)
        sum_1.sample(context=ctx)

        # MPE sampling is deterministic, so repeated samples should be identical
        mpe_1 = leaf.sample(context=ctx)
        mpe_2 = leaf.sample(context=ctx)
        mpe_3 = leaf.sample(context=ctx)
        self.assertTrue(((mpe_1 - mpe_2).abs() < 1e-6).all())
        self.assertTrue(((mpe_2 - mpe_3).abs() < 1e-6).all())
Example #14
    def sample(self,
               n: int = None,
               context: SamplingContext = None) -> SamplingContext:
        """Method to sample from this layer, based on the parents output.

        Output is always a vector of indices into the channels.

        Args:
            repetition_indices (List[int]): An index into the repetition axis for each sample.
                Can be None if `num_repetitions==1`.
            indices (torch.Tensor): Parent sampling output.
            n (int): Number of samples.
        Returns:
            torch.Tensor: Index into tensor which paths should be followed.
        """

        # Sum weights are of shape: [D, IC, OC, R]
        # We now use the parent indices to select one in-channel for each in_feature x out_channel
        # block; `parent_indices` holds one out-channel index per in_feature
        weights = self.weights.data
        d, ic, oc, r = weights.shape

        # Create sampling context if this is a root layer
        if context is None:
            assert oc == 1 and r == 1, "Cannot start sampling from non-root layer."
            context = SamplingContext(n=n,
                                      parent_indices=None,
                                      repetition_indices=torch.zeros(
                                          n, dtype=int, device=self.__device))

            # Select weights, repeat n times along the last dimension
            weights = weights[:, :, [0] * n, 0]  # Shape: [D, IC, N]

            # Move sample dimension to the first axis: [feat, channels, batch] -> [batch, feat, channels]
            weights = weights.permute(2, 0, 1)  # Shape: [N, D, IC]
        else:
            # If this is not the root node, use the paths (out channels), specified by the parent layer
            self._check_repetition_indices(context)

            n = context.n
            tmp = torch.zeros(n, d, ic, device=self.__device)
            for i in range(n):
                tmp[i, :, :] = weights[range(self.in_features), :,
                                       context.parent_indices[i],
                                       context.repetition_indices[i]]
            weights = tmp

        # Check dimensions
        assert weights.shape == (n, d, ic)

        # Apply softmax to ensure they are proper probabilities
        log_weights = F.log_softmax(weights, dim=2)

        # If evidence is given, adjust the weights with the likelihoods of the observed paths
        if self._is_sampling_input_cache_enabled and self._sampling_input_cache is not None:
            for i in range(n):
                # Reweight the i-th samples weights by its likelihood values at the correct repetition
                log_weights[i, :, :] += self._sampling_input_cache[
                    i, :, :, context.repetition_indices[i]]

        # Create categorical distribution and use weights as logits
        dist = torch.distributions.Categorical(logits=log_weights)
        indices = dist.sample()

        # Update parent indices
        context.parent_indices = indices

        return context
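
This variant draws indices with torch.distributions.Categorical, which returns integers but detaches the computation graph; Example #12 instead uses the Gumbel-Softmax trick so gradients can flow to the logits. A minimal contrast of the two:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, requires_grad=True)
hard = torch.distributions.Categorical(logits=logits).sample()  # no gradient path
one_hot = F.gumbel_softmax(logits=logits, hard=True, dim=-1)    # straight-through gradients
soft = (one_hot * torch.arange(5)).sum(-1)  # float "indices"; casting with .long() would detach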