def from_data_list(data_list):
    r"""Create a pair batch from a list of
    ``torch_points3d.datasets.registration.pair.Pair`` objects.

    Each element is split with ``to_data()`` into a source and a target
    sample. Sources and targets are batched separately with
    ``MultiScaleBatch.from_data_list`` and then combined with
    ``PairMultiScaleBatch.make_pair``.

    Warning: ``follow_batch`` is not supported yet.

    Args:
        data_list: Non-empty iterable of pair objects exposing ``to_data()``,
            which is expected to return a ``(source, target)`` pair.

    Returns:
        The combined pair batch, made contiguous.

    Raises:
        ValueError: If ``data_list`` is empty.
    """
    pairs = [data.to_data() for data in data_list]
    if not pairs:
        # Without this check the two-way unpacking below fails with an
        # opaque "not enough values to unpack" message.
        raise ValueError("from_data_list requires a non-empty data_list")
    # Transpose [(s0, t0), (s1, t1), ...] into [s0, s1, ...] and [t0, t1, ...].
    data_list_s, data_list_t = (list(side) for side in zip(*pairs))
    batch_s = MultiScaleBatch.from_data_list(data_list_s)
    batch_t = MultiScaleBatch.from_data_list(data_list_t)
    return PairMultiScaleBatch.make_pair(batch_s, batch_t).contiguous()
def from_data_list(data_list):
    r"""Create a pair batch from a list of
    ``torch_points3d.datasets.registration.pair.Pair`` objects.

    Each element is split with ``to_data()`` into a source and a target
    sample. Sources and targets are batched separately with
    ``MultiScaleBatch.from_data_list`` and then combined with
    ``PairMultiScaleBatch.make_pair``. If the first source sample has a
    ``pair_ind`` attribute, the per-sample indices are merged with
    ``concatenate_pair_ind``, cast to ``torch.long``, and stored on the
    result. Otherwise ``pair_ind`` is set to ``None``.

    Warning: ``follow_batch`` is not supported yet.

    Args:
        data_list: Non-empty iterable of pair objects exposing ``to_data()``,
            which is expected to return a ``(source, target)`` pair.

    Returns:
        The combined pair batch (with ``pair_ind`` assigned), made contiguous.

    Raises:
        ValueError: If ``data_list`` is empty.
    """
    pairs = [data.to_data() for data in data_list]
    if not pairs:
        # Without this check the unpacking below fails with an opaque
        # "not enough values to unpack" message.
        raise ValueError("from_data_list requires a non-empty data_list")
    # Transpose [(s0, t0), (s1, t1), ...] into [s0, s1, ...] and [t0, t1, ...].
    data_list_s, data_list_t = (list(side) for side in zip(*pairs))
    # Presence of correspondences is decided from the first source sample
    # only; the samples are assumed to be homogeneous.
    if hasattr(data_list_s[0], "pair_ind"):
        pair_ind = concatenate_pair_ind(data_list_s, data_list_t).to(torch.long)
    else:
        pair_ind = None
    batch_s = MultiScaleBatch.from_data_list(data_list_s)
    batch_t = MultiScaleBatch.from_data_list(data_list_t)
    pair = PairMultiScaleBatch.make_pair(batch_s, batch_t)
    pair.pair_ind = pair_ind
    return pair.contiguous()
def _get_collate_function(conv_type, is_multiscale):
    """Select the callable that merges a list of samples into one batch.

    Args:
        conv_type: Name of the convolution format. On the multiscale path it
            is compared case-insensitively with
            ``ConvolutionFormat.PARTIAL_DENSE.value``; otherwise it is passed
            to ``ConvolutionFormatFactory.check_is_dense_format``.
        is_multiscale: Whether the multiscale collate path is requested.

    Returns:
        A callable taking a list of data objects and returning a batch:
        ``MultiScaleBatch.from_data_list`` (multiscale, partial dense),
        ``SimpleBatch.from_data_list`` (dense formats) or
        ``torch_geometric.data.batch.Batch.from_data_list`` (otherwise).

    Raises:
        NotImplementedError: If multiscale is requested for a format other
            than partial dense.
    """
    # Return the ``from_data_list`` callables themselves instead of wrapping
    # them in lambdas: lambdas cannot be pickled, which breaks DataLoader
    # workers started with the "spawn" method. Functions reachable by
    # qualified name from an importable class pickle by reference.
    if is_multiscale:
        if conv_type.lower() == ConvolutionFormat.PARTIAL_DENSE.value.lower():
            return MultiScaleBatch.from_data_list
        raise NotImplementedError(
            "MultiscaleTransform is activated and supported only for partial_dense format"
        )
    if ConvolutionFormatFactory.check_is_dense_format(conv_type):
        return SimpleBatch.from_data_list
    return torch_geometric.data.batch.Batch.from_data_list
def test_batch(self):
    """Batch two MultiScaleData samples and check both the top-level batch
    and every multiscale level."""

    def build_sample(value):
        # x and pos intentionally share one tensor; the second level holds
        # values scaled by 4.
        base = torch.tensor([value])
        level0 = Data(x=base, pos=base)
        level1 = Data(x=4 * base, pos=4 * base)
        return MultiScaleData(x=base, pos=base, multiscale=[level0, level1])

    batch = MultiScaleBatch.from_data_list([build_sample(1), build_sample(2)])

    tt.assert_allclose(batch.x, torch.tensor([1, 2]))
    tt.assert_allclose(batch.batch, torch.tensor([0, 1]))

    levels = batch.multiscale
    for depth in (0, 1):
        tt.assert_allclose(levels[depth].batch, torch.tensor([0, 1]))
    tt.assert_allclose(levels[1].x, torch.tensor([4, 8]))
def __getitem__(self, index):
    """Collate the entire stored ``self.datalist`` into one batch.

    ``index`` is ignored, so every lookup yields an equivalent batch. The
    multiscale batch class is used when ``self._ms_transform`` is truthy,
    and the plain ``Batch`` class otherwise.
    """
    collate = MultiScaleBatch.from_data_list if self._ms_transform else Batch.from_data_list
    return collate(self.datalist)