def test_pair_ind(self):
        """Check that PairBatch offsets ``pair_ind`` per collated pair.

        Column 0 of each pair's indices is expected to be shifted by the total
        number of source points of the pairs collated before it, and column 1
        by the total number of target points.
        """
        # (source point count, target point count, pair_ind rows)
        specs = [
            (100, 114, [[0, 1], [99, 36], [98, 113], [54, 29],
                        [10, 110], [1, 0]]),
            (102, 104, [[0, 1], [45, 28], [101, 36], [98, 1],
                        [14, 99], [34, 52], [1, 0]]),
            (128, 2102, [[0, 1], [100, 1000], [1, 0]]),
        ]

        pairs = []
        for n_source, n_target, rows in specs:
            pair = Pair.make_pair(Data(pos=torch.randn(n_source, 3)),
                                  Data(pos=torch.randn(n_target, 3)))
            pair.pair_ind = torch.tensor(rows)
            pairs.append(pair)
        batch = PairBatch.from_data_list(pairs)

        # Derive the expected tensor from running offsets rather than
        # spelling out every shifted index by hand.
        shifted = []
        offset_source, offset_target = 0, 0
        for n_source, n_target, rows in specs:
            shifted.append(torch.tensor(rows)
                           + torch.tensor([offset_source, offset_target]))
            offset_source += n_source
            offset_target += n_target
        expected_pair_ind = torch.cat(shifted).to(torch.long)

        npt.assert_almost_equal(batch.pair_ind.numpy(),
                                expected_pair_ind.numpy())
 def test_pair_batch(self):
     """Check that PairBatch stacks source and target fields separately.

     Two single-point pairs are collated; source fields (``x``, ``pos``,
     ``batch``) and target fields (``*_target``) are each expected to be
     concatenated in pair order with their own batch vector.
     """
     # ((source x, source pos), (target x, target pos)); pos == x ** 2.
     point_values = [((1, 1), (2, 4)), ((3, 9), (4, 16))]
     pairs = [
         Pair.make_pair(Data(x=torch.tensor([sx]), pos=torch.tensor([sp])),
                        Data(x=torch.tensor([tx]), pos=torch.tensor([tp])))
         for (sx, sp), (tx, tp) in point_values
     ]
     batch = PairBatch.from_data_list(pairs)

     expected = [
         ("x", [1, 3]),
         ("pos", [1, 9]),
         ("batch", [0, 1]),
         ("x_target", [2, 4]),
         ("pos_target", [4, 16]),
         ("batch_target", [0, 1]),
     ]
     for attr_name, values in expected:
         tt.assert_allclose(getattr(batch, attr_name), torch.tensor(values))
    def _get_collate_function(conv_type, is_multiscale):
        """Select the function that collates a list of samples into a batch.

        Args:
            conv_type: convolution format name. In the multiscale case it is
                compared case-insensitively with
                ``ConvolutionFormat.PARTIAL_DENSE.value``.
            is_multiscale: whether a multiscale transform is active.

        Returns:
            A callable taking a list of samples. It is
            ``PairMultiScaleBatch.from_data_list`` in the multiscale case,
            ``SimpleBatch.from_data_list`` for dense formats, and
            ``PairBatch.from_data_list`` otherwise.

        Raises:
            NotImplementedError: if ``is_multiscale`` is set and ``conv_type``
                is not the partial-dense format.
        """
        # Kept ahead of the multiscale branch so the factory is consulted for
        # every conv_type, exactly as before.
        is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type)

        # Return the ``from_data_list`` attributes themselves rather than
        # wrapping them in lambdas. A lambda cannot be pickled, which breaks
        # DataLoader worker processes started with the "spawn" method. The
        # class attribute can be pickled by reference, assuming the batch
        # classes are importable at module level.
        if is_multiscale:
            if conv_type.lower() == ConvolutionFormat.PARTIAL_DENSE.value.lower():
                return PairMultiScaleBatch.from_data_list
            raise NotImplementedError(
                "MultiscaleTransform is activated and supported only for partial_dense format"
            )

        if is_dense:
            return SimpleBatch.from_data_list
        return PairBatch.from_data_list
    def __getitem__(self, index):
        """Return all of ``self.datalist`` collated into a single batch.

        ``index`` is ignored; every call collates the full list.
        NOTE(review): presumably intentional for a fixture/mock dataset -
        confirm.

        The result is built with ``PairMultiScaleBatch`` when
        ``self._ms_transform`` is truthy, and with ``PairBatch`` otherwise.
        """
        batch_cls = PairMultiScaleBatch if self._ms_transform else PairBatch
        return batch_cls.from_data_list(self.datalist)