예제 #1
0
    def test_nduplicates(self):
        """
        Test ``util.nduplicates`` on batched tensors of tuples.

        The expected values below encode the contract under test: within each
        row of tuples, every tuple equal to an earlier tuple in that row is
        flagged, while the first occurrence is not. Extra leading batch
        dimensions must not change the per-row result.
        """

        def add_batch_dims(t, *batch):
            # Prepend the ``batch`` dimensions while keeping every existing
            # dimension (including the tuple length) unchanged. ``expand`` only
            # creates a stride-0 view, so ``contiguous()`` makes a real copy
            # that can safely be written to in place.
            return t[(None,) * len(batch)].expand(*batch, *t.shape).contiguous()

        def assert_dedup(tuples, expected):
            # Zero every tuple flagged as a duplicate, then compare with the
            # expected tensor. The shape check stops broadcasting in ``!=``
            # from hiding a shape mismatch; ``.item()`` reports the mismatch
            # count as a plain int on failure.
            dup = util.nduplicates(tuples)
            tuples[dup] = 0
            self.assertEqual(tuples.shape, expected.shape)
            self.assertEqual((tuples != expected).sum().item(), 0)

        # first test: 2-tuples; the first occurrence of each tuple survives,
        # later repeats in the same row are zeroed.
        tuples = torch.tensor([[[5, 5], [1, 1], [2, 3], [1, 1]],
                               [[3, 2], [3, 2], [5, 5], [5, 5]]])
        dedup = torch.tensor([[[5, 5], [1, 1], [2, 3], [0, 0]],
                              [[3, 2], [0, 0], [5, 5], [0, 0]]])
        assert_dedup(add_batch_dims(tuples, 3, 5, 7),
                     add_batch_dims(dedup, 3, 5, 7))

        # second test: check the returned bitmask directly. (3, 1) repeats at
        # index 2 and (0, 3) repeats at index 6.
        tuples = torch.tensor([[[3, 1], [3, 2], [3, 1], [0, 3], [0, 2], [3, 0],
                                [0, 3], [0, 0]]])
        tuples = add_batch_dims(tuples, 8, 1, 7)
        expected_mask = [0, 0, 1, 0, 0, 0, 1, 0]
        dup = util.nduplicates(tuples)
        # Every batch entry holds the same row of tuples, so every entry must
        # produce the same mask (8 * 1 * 7 * 1 rows of 8 flags).
        self.assertEqual(dup.reshape(-1, 8).long().tolist(),
                         [expected_mask] * (8 * 1 * 7 * 1))

        # third test: single element tuples. The trailing (tuple) dimension
        # must stay at size 1; expanding it would broadcast these into 2-tuples.
        tuples = torch.tensor([[[5], [1], [2], [1]], [[3], [3], [5], [5]]])
        dedup = torch.tensor([[[5], [1], [2], [0]], [[3], [0], [5], [0]]])
        assert_dedup(add_batch_dims(tuples, 3, 5, 7),
                     add_batch_dims(dedup, 3, 5, 7))
예제 #2
0
    def test_nduplicates_recursion(self):
        """
        Regression test reproducing a recursion error observed in
        ``util.nduplicates``.

        The input is a (1, 1, 6, 1) tensor of 1-tuples in which only the last
        value (72) repeats an earlier one, so only that entry should be zeroed.
        """
        expected = torch.tensor([[[[74], [75], [175], [246], [72], [0]]]])
        singles = torch.tensor([[[[74], [75], [175], [246], [72], [72]]]])

        repeated = util.nduplicates(singles)
        singles[repeated] = 0

        mismatches = (singles != expected).sum()
        self.assertEqual(mismatches, 0)