def test_nduplicates(self):
    """Check ``util.nduplicates`` against hand-computed duplicate masks.

    The expected values below encode the contract these tests rely on:
    ``util.nduplicates`` takes an integer tensor whose last dimension holds
    the elements of a tuple and whose second-to-last dimension lists the
    tuples to compare (independently per row), with any number of leading
    batch dimensions. It returns a mask over the tuples, usable for boolean
    indexing, that is set for every repeat of a tuple already seen earlier
    in the same row; the first occurrence is left unset.
    """

    def with_batch_dims(t, *leading):
        # Prepend extra leading dimensions so nduplicates is exercised on
        # arbitrary batch shapes. The trailing shape is taken from ``t``
        # itself so it always matches the literal: a hand-written expand
        # size could silently broadcast a size-1 tuple dimension.
        # contiguous() materialises the stride-0 expand() view so it can be
        # written to in place.
        return t.expand(*leading, *t.shape).contiguous()

    def zero_duplicates(t):
        # Zero out, in place, every tuple flagged as a duplicate.
        dup = util.nduplicates(t)
        t[dup] = 0
        return t

    # 1) 2-tuples: masking out the duplicates should yield ``dedup``.
    tuples = torch.tensor([[[5, 5], [1, 1], [2, 3], [1, 1]],
                           [[3, 2], [3, 2], [5, 5], [5, 5]]])
    dedup = torch.tensor([[[5, 5], [1, 1], [2, 3], [0, 0]],
                          [[3, 2], [0, 0], [5, 5], [0, 0]]])
    tuples = with_batch_dims(tuples, 3, 5, 7)
    dedup = with_batch_dims(dedup, 3, 5, 7)
    self.assertEqual((zero_duplicates(tuples) != dedup).sum().item(), 0)

    # 2) Check the returned mask itself, not just its effect.
    tuples = torch.tensor([[[3, 1], [3, 2], [3, 1], [0, 3],
                            [0, 2], [3, 0], [0, 3], [0, 0]]])
    tuples = with_batch_dims(tuples, 8, 1, 7)
    mask = util.nduplicates(tuples)[0, 0, 0].reshape(-1)
    self.assertEqual([0, 0, 1, 0, 0, 0, 1, 0], mask.int().tolist())

    # 3) Single-element tuples. The trailing dimension must stay at size 1;
    #    expanding it to 2 would silently turn these into 2-tuples.
    tuples = torch.tensor([[[5], [1], [2], [1]], [[3], [3], [5], [5]]])
    dedup = torch.tensor([[[5], [1], [2], [0]], [[3], [0], [5], [0]]])
    tuples = with_batch_dims(tuples, 3, 5, 7)
    dedup = with_batch_dims(dedup, 3, 5, 7)
    self.assertEqual((zero_duplicates(tuples) != dedup).sum().item(), 0)
def test_nduplicates_recursion(self):
    """Regression test reproducing a recursion error seen in ``util.nduplicates``.

    The input is six 1-tuples in a (1, 1, 6, 1) tensor. Only the last tuple
    repeats an earlier value (72), so after zeroing the flagged tuples only
    that final entry should differ from the input.
    """
    values = [74, 75, 175, 246, 72, 72]
    tuples = torch.tensor([[[[v] for v in values]]])
    expected = torch.tensor([[[[v] for v in values[:-1]] + [[0]]]])

    duplicates = util.nduplicates(tuples)
    tuples[duplicates] = 0

    mismatches = (tuples != expected).sum()
    self.assertEqual(mismatches, 0)