def test_ms_pair_batch(self):
    """Batch two single-point MultiScalePair samples and check the collated fields.

    Each sample holds one source point and one target point, plus a
    two-level multiscale list for each side where level 1 is the level-0
    data multiplied by 4. After ``PairMultiScaleBatch.from_data_list`` the
    source/target features, the batch vectors and every multiscale level
    are expected to be concatenated in sample order.
    """

    def _make_pair(x, x_target):
        # Features double as positions so the expected values are trivial.
        pos = x
        # Target positions must follow the target features, not the source ``x``.
        pos_target = x_target
        ms = [Data(x=x, pos=pos), Data(x=4 * x, pos=4 * pos)]
        ms_target = [
            Data(x=x_target, pos=pos_target),
            Data(x=4 * x_target, pos=4 * pos_target),
        ]
        return MultiScalePair(
            x=x,
            pos=pos,
            multiscale=ms,
            x_target=x_target,
            pos_target=pos_target,
            multiscale_target=ms_target,
        )

    data1 = _make_pair(torch.tensor([1]), torch.tensor([2]))
    data2 = _make_pair(torch.tensor([3]), torch.tensor([4]))
    batch = PairMultiScaleBatch.from_data_list([data1, data2])

    tt.assert_allclose(batch.x, torch.tensor([1, 3]))
    tt.assert_allclose(batch.x_target, torch.tensor([2, 4]))
    tt.assert_allclose(batch.batch, torch.tensor([0, 1]))
    tt.assert_allclose(batch.batch_target, torch.tensor([0, 1]))

    ms_batches = batch.multiscale
    tt.assert_allclose(ms_batches[0].batch, torch.tensor([0, 1]))
    tt.assert_allclose(ms_batches[1].batch, torch.tensor([0, 1]))
    tt.assert_allclose(ms_batches[0].x, torch.tensor([1, 3]))
    tt.assert_allclose(ms_batches[1].x, torch.tensor([4, 12]))

    ms_batches = batch.multiscale_target
    tt.assert_allclose(ms_batches[0].batch, torch.tensor([0, 1]))
    tt.assert_allclose(ms_batches[1].batch, torch.tensor([0, 1]))
    tt.assert_allclose(ms_batches[0].x, torch.tensor([2, 4]))
    tt.assert_allclose(ms_batches[1].x, torch.tensor([8, 16]))
    # Pins the target-position fix: level 1 target pos is 4 * x_target.
    tt.assert_allclose(ms_batches[1].pos, torch.tensor([8, 16]))
def _get_collate_function(conv_type, is_multiscale):
    """Pick the callable that collates a list of data objects into a batch.

    Args:
        conv_type: convolution format name. It is passed to
            ``ConvolutionFormatFactory.check_is_dense_format`` and, in the
            multiscale case, compared case-insensitively with
            ``ConvolutionFormat.PARTIAL_DENSE.value``.
        is_multiscale: whether the data carries multiscale pairs.

    Returns:
        A callable that takes a list of data objects and returns a batch:
        ``PairMultiScaleBatch.from_data_list`` for multiscale partial-dense
        data, ``SimpleBatch.from_data_list`` for dense formats and
        ``PairBatch.from_data_list`` otherwise.

    Raises:
        NotImplementedError: if ``is_multiscale`` is set and ``conv_type`` is
            not the partial-dense format.
    """
    is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type)
    # Return the ``from_data_list`` callables themselves instead of wrapping
    # them in lambdas. Lambdas cannot be pickled, and this callable is
    # presumably handed to a DataLoader as ``collate_fn``. That breaks worker
    # processes started with the "spawn" method. The call signature is unchanged.
    if is_multiscale:
        if conv_type.lower() == ConvolutionFormat.PARTIAL_DENSE.value.lower():
            return PairMultiScaleBatch.from_data_list
        raise NotImplementedError(
            "MultiscaleTransform is activated and supported only for partial_dense format"
        )
    if is_dense:
        return SimpleBatch.from_data_list
    return PairBatch.from_data_list
def test_ms_pair_ind(self):
    """Check that ``pair_ind`` is offset per sample when batching pairs.

    Column 0 of ``pair_ind`` indexes source points and column 1 indexes
    target points. When the second sample is appended, its source indices
    must be shifted by the first sample's source point count (1001). Its
    target indices must be shifted by the first sample's target point
    count (1452).
    """

    def _make_pair(num_points, num_points_target):
        # Random clouds: only the point counts matter for the index offsets.
        x = torch.randn(num_points, 3)
        pos = x
        x_target = torch.randn(num_points_target, 3)
        pos_target = x_target
        ms = [Data(x=x, pos=pos), Data(x=4 * x, pos=4 * pos)]
        ms_target = [
            Data(x=x_target, pos=pos_target),
            Data(x=4 * x_target, pos=4 * pos_target),
        ]
        return MultiScalePair(
            x=x,
            pos=pos,
            multiscale=ms,
            x_target=x_target,
            pos_target=pos_target,
            multiscale_target=ms_target,
        )

    data1 = _make_pair(1001, 1452)
    data1.pair_ind = torch.tensor(
        [[0, 1], [99, 36], [98, 113], [54, 29], [10, 110], [1, 0]]
    )

    data2 = _make_pair(300, 154)
    # Every index must be valid for its own cloud: data2 has 154 target
    # points, so 153 is the last valid target index (a boundary case).
    data2.pair_ind = torch.tensor([[0, 1], [100, 153], [1, 0]])

    batch = PairMultiScaleBatch.from_data_list([data1, data2])

    expected_pair_ind = torch.tensor(
        [
            [0, 1],
            [99, 36],
            [98, 113],
            [54, 29],
            [10, 110],
            [1, 0],
            [0 + 1001, 1 + 1452],
            [100 + 1001, 153 + 1452],
            [1 + 1001, 0 + 1452],
        ]
    )
    # Indices are integers: compare exactly rather than "almost".
    npt.assert_array_equal(batch.pair_ind.numpy(), expected_pair_ind.numpy())
def __getitem__(self, index):
    """Return one batch built from the whole ``self.datalist``.

    ``index`` is ignored: every access collates the full datalist. A truthy
    ``self._ms_transform`` selects ``PairMultiScaleBatch``; otherwise
    ``PairBatch`` is used.
    """
    batch_cls = PairMultiScaleBatch if self._ms_transform else PairBatch
    return batch_cls.from_data_list(self.datalist)