def test_ms_pair_batch(self):
        """Check that two ``MultiScalePair`` samples are batched correctly.

        Each sample holds a single source point and a single target point,
        plus two multiscale levels per side (the second level is the first
        one multiplied by 4).  After ``PairMultiScaleBatch.from_data_list``
        the test verifies the concatenated features, the ``batch`` /
        ``batch_target`` assignment vectors and the per-level multiscale
        batches on both the source and the target side.
        """
        # --- first sample: source value 1, target value 2 ---
        x = torch.tensor([1])
        pos = x
        x_target = torch.tensor([2])
        # Target positions follow the target features (the source tensor was
        # used here before, which mixed source data into the target side).
        pos_target = x_target
        ms = [Data(x=x, pos=pos), Data(x=4 * x, pos=4 * pos)]
        ms_target = [
            Data(x=x_target, pos=pos_target),
            Data(x=4 * x_target, pos=4 * pos_target)
        ]
        data1 = MultiScalePair(x=x,
                               pos=pos,
                               multiscale=ms,
                               x_target=x_target,
                               pos_target=pos_target,
                               multiscale_target=ms_target)

        # --- second sample: source value 3, target value 4 ---
        x = torch.tensor([3])
        pos = x
        x_target = torch.tensor([4])
        pos_target = x_target
        ms = [Data(x=x, pos=pos), Data(x=4 * x, pos=4 * pos)]
        ms_target = [
            Data(x=x_target, pos=pos_target),
            Data(x=4 * x_target, pos=4 * pos_target)
        ]
        data2 = MultiScalePair(x=x,
                               pos=pos,
                               multiscale=ms,
                               x_target=x_target,
                               pos_target=pos_target,
                               multiscale_target=ms_target)

        batch = PairMultiScaleBatch.from_data_list([data1, data2])

        # Top level: features are concatenated in sample order and every
        # sample contributes exactly one point on each side.
        tt.assert_allclose(batch.x, torch.tensor([1, 3]))
        tt.assert_allclose(batch.x_target, torch.tensor([2, 4]))
        tt.assert_allclose(batch.batch, torch.tensor([0, 1]))
        tt.assert_allclose(batch.batch_target, torch.tensor([0, 1]))

        # Source multiscale levels; level 1 carries 4 * [1, 3].
        ms_batches = batch.multiscale
        tt.assert_allclose(ms_batches[0].batch, torch.tensor([0, 1]))
        tt.assert_allclose(ms_batches[1].batch, torch.tensor([0, 1]))
        tt.assert_allclose(ms_batches[1].x, torch.tensor([4, 12]))

        # Target multiscale levels; level 1 carries 4 * [2, 4].
        ms_batches = batch.multiscale_target
        tt.assert_allclose(ms_batches[0].batch, torch.tensor([0, 1]))
        tt.assert_allclose(ms_batches[1].batch, torch.tensor([0, 1]))
        tt.assert_allclose(ms_batches[1].x, torch.tensor([8, 16]))
    def _get_collate_function(conv_type, is_multiscale):
        """Pick the callable that turns a list of samples into one batch.

        Args:
            conv_type: name of the convolution format; it is compared
                case-insensitively against ``ConvolutionFormat.PARTIAL_DENSE``
                and handed to ``ConvolutionFormatFactory.check_is_dense_format``.
            is_multiscale: truthy when the multiscale batching path is wanted.

        Returns:
            A one-argument callable forwarding the sample list to
            ``PairMultiScaleBatch.from_data_list`` (multiscale, partial dense),
            ``SimpleBatch.from_data_list`` (dense format) or
            ``PairBatch.from_data_list`` (everything else).

        Raises:
            NotImplementedError: if ``is_multiscale`` is truthy and
                ``conv_type`` is not the partial dense format.
        """
        # Resolved before any branching, so whatever this check raises
        # surfaces on the multiscale path as well.
        dense_format = ConvolutionFormatFactory.check_is_dense_format(conv_type)

        if not is_multiscale:
            if dense_format:
                return lambda datalist: SimpleBatch.from_data_list(datalist)
            return lambda datalist: PairBatch.from_data_list(datalist)

        # Multiscale batching is only available for the partial dense format.
        requested_format = conv_type.lower()
        supported_format = ConvolutionFormat.PARTIAL_DENSE.value.lower()
        if requested_format != supported_format:
            raise NotImplementedError(
                "MultiscaleTransform is activated and supported only for partial_dense format"
            )
        return lambda datalist: PairMultiScaleBatch.from_data_list(datalist)
    def test_ms_pair_ind(self):
        """Check that ``pair_ind`` is shifted per sample when batching.

        Two ``MultiScalePair`` samples with random features are batched with
        ``PairMultiScaleBatch.from_data_list``.  The first sample has 1001
        source and 1452 target points, so the second sample's pair indices
        are expected to be offset by 1001 (source column) and 1452 (target
        column) while the first sample's indices stay untouched.
        """

        def make_sample(num_source, num_target, pairs):
            # Source features are drawn before target features, as the
            # samples were originally built, to keep the RNG call order.
            source = torch.randn(num_source, 3)
            target = torch.randn(num_target, 3)
            sample = MultiScalePair(
                x=source,
                pos=source,
                multiscale=[
                    Data(x=source, pos=source),
                    Data(x=4 * source, pos=4 * source),
                ],
                x_target=target,
                pos_target=target,
                multiscale_target=[
                    Data(x=target, pos=target),
                    Data(x=4 * target, pos=4 * target),
                ],
            )
            sample.pair_ind = pairs
            return sample

        first_sizes = (1001, 1452)
        first_pairs = torch.tensor([[0, 1], [99, 36], [98, 113], [54, 29],
                                    [10, 110], [1, 0]])
        second_pairs = torch.tensor([[0, 1], [100, 1000], [1, 0]])

        # Built before batching, from fresh tensors, so the expectation does
        # not depend on what batching does to the samples' own tensors.
        shifted_second = second_pairs + torch.tensor(first_sizes)
        expected_pair_ind = torch.cat([first_pairs.clone(), shifted_second])

        samples = [
            make_sample(first_sizes[0], first_sizes[1], first_pairs),
            make_sample(300, 154, second_pairs),
        ]
        batch = PairMultiScaleBatch.from_data_list(samples)

        npt.assert_almost_equal(batch.pair_ind.numpy(),
                                expected_pair_ind.numpy())
    def __getitem__(self, index):
        """Return one batch built from the whole of ``self.datalist``.

        ``index`` is accepted but not used: every lookup batches the same
        list.  ``PairMultiScaleBatch`` is used when ``self._ms_transform`` is
        truthy, ``PairBatch`` otherwise.
        """
        batch_cls = PairMultiScaleBatch if self._ms_transform else PairBatch
        return batch_cls.from_data_list(self.datalist)