def test_pair_ind(self):
    """Check that PairBatch shifts each pair's ``pair_ind`` by the point counts
    of all preceding pairs (source offsets in column 0, target offsets in column 1).
    """
    # (number of source points, number of target points, local pair indices)
    specs = [
        (100, 114, [[0, 1], [99, 36], [98, 113], [54, 29], [10, 110], [1, 0]]),
        (102, 104, [[0, 1], [45, 28], [101, 36], [98, 1], [14, 99], [34, 52], [1, 0]]),
        (128, 2102, [[0, 1], [100, 1000], [1, 0]]),
    ]

    pairs = []
    expected_rows = []
    src_offset, tgt_offset = 0, 0
    for n_src, n_tgt, local_ind in specs:
        pair = Pair.make_pair(Data(pos=torch.randn(n_src, 3)), Data(pos=torch.randn(n_tgt, 3)))
        pair.pair_ind = torch.tensor(local_ind)
        pairs.append(pair)
        # Expected global indices: local indices shifted by the sizes seen so far.
        expected_rows.extend([i + src_offset, j + tgt_offset] for i, j in local_ind)
        src_offset += n_src
        tgt_offset += n_tgt

    batch = PairBatch.from_data_list(pairs)
    expected_pair_ind = torch.tensor(expected_rows).to(torch.long)
    npt.assert_almost_equal(batch.pair_ind.numpy(), expected_pair_ind.numpy())
def test_pair_batch(self):
    """Check that PairBatch concatenates source attributes and ``*_target``
    attributes separately, building ``batch`` and ``batch_target`` vectors.
    """
    # d_k carries x = k and pos = k**2 for k = 1..4.
    d1, d2, d3, d4 = (Data(x=torch.tensor([k]), pos=torch.tensor([k * k])) for k in range(1, 5))
    batch = PairBatch.from_data_list([Pair.make_pair(d1, d2), Pair.make_pair(d3, d4)])

    expected = {
        "x": [1, 3],
        "pos": [1, 9],
        "batch": [0, 1],
        "x_target": [2, 4],
        "pos_target": [4, 16],
        "batch_target": [0, 1],
    }
    for attr, values in expected.items():
        tt.assert_allclose(getattr(batch, attr), torch.tensor(values))
def _get_collate_function(conv_type, is_multiscale):
    """Return the callable that collates a list of samples into a batch.

    Args:
        conv_type: convolution format name; compared case-insensitively
            against ``ConvolutionFormat.PARTIAL_DENSE.value``.
        is_multiscale: whether multiscale batching is requested.

    Returns:
        ``PairMultiScaleBatch.from_data_list`` when ``is_multiscale`` is set
        (partial_dense only), ``SimpleBatch.from_data_list`` for dense formats,
        otherwise ``PairBatch.from_data_list``.

    Raises:
        NotImplementedError: if ``is_multiscale`` is set and ``conv_type`` is
            not partial_dense.
    """
    # Evaluated first (as before) so any error it raises for an unknown
    # conv_type surfaces in the same order as in the original code.
    is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type)

    # Return the from_data_list attributes directly instead of wrapping them in
    # lambdas: lambdas cannot be pickled, which breaks consumers that pickle the
    # collate function (e.g. multiprocessing DataLoader workers using "spawn").
    if is_multiscale:
        if conv_type.lower() == ConvolutionFormat.PARTIAL_DENSE.value.lower():
            return PairMultiScaleBatch.from_data_list
        raise NotImplementedError(
            "MultiscaleTransform is activated and supported only for partial_dense format"
        )
    if is_dense:
        return SimpleBatch.from_data_list
    return PairBatch.from_data_list
def __getitem__(self, index):
    """Collate all of ``self.datalist`` into one batch.

    ``index`` is not used: every index returns the same batch. The batch is
    built with ``PairMultiScaleBatch`` when ``self._ms_transform`` is truthy,
    otherwise with ``PairBatch``.
    """
    batch_cls = PairMultiScaleBatch if self._ms_transform else PairBatch
    return batch_cls.from_data_list(self.datalist)