コード例 #1
0
 def from_data_list(data_list):
     """Collate a list of registration pair objects into one paired batch.

     Each element of ``data_list`` (expected to be a
     ``torch_points3d.datasets.registration.pair.Pair``) is split with
     ``to_data()`` into a (source, target) couple. All sources and all targets
     are batched separately with ``MultiScaleBatch.from_data_list``, then joined
     with ``PairMultiScaleBatch.make_pair``.

     Warning: ``follow_batch`` is not supported yet.

     Args:
         data_list: sequence of objects whose ``to_data()`` returns exactly two
             items (source, target).

     Returns:
         The ``.contiguous()`` result of ``PairMultiScaleBatch.make_pair``.

     Raises:
         ValueError: if ``data_list`` is empty, since there is then nothing to
             unpack into sources and targets.
     """
     # zip(*...) transposes [(s0, t0), (s1, t1), ...] into (s0, s1, ...), (t0, t1, ...).
     sources, targets = map(list, zip(*(pair.to_data() for pair in data_list)))
     source_batch = MultiScaleBatch.from_data_list(sources)
     target_batch = MultiScaleBatch.from_data_list(targets)
     paired = PairMultiScaleBatch.make_pair(source_batch, target_batch)
     return paired.contiguous()
コード例 #2
0
 def from_data_list(data_list):
     """Collate registration pair objects into one paired batch, keeping ``pair_ind``.

     Each element of ``data_list`` (expected to be a
     ``torch_points3d.datasets.registration.pair.Pair``) is split with
     ``to_data()`` into a (source, target) couple. Sources and targets are
     batched separately with ``MultiScaleBatch.from_data_list`` and joined with
     ``PairMultiScaleBatch.make_pair``. If the source samples carry a
     ``pair_ind`` attribute, the per-sample indices are merged with
     ``concatenate_pair_ind``, cast to ``torch.long``, and stored on the result.
     Otherwise the result's ``pair_ind`` is set to ``None``.

     Warning: ``follow_batch`` is not supported yet.

     Args:
         data_list: sequence of objects whose ``to_data()`` returns exactly two
             items (source, target).

     Returns:
         The ``.contiguous()`` paired batch with ``pair_ind`` attached.

     Raises:
         ValueError: if ``data_list`` is empty, since there is then nothing to
             unpack into sources and targets.
     """
     sources, targets = map(list, zip(*(pair.to_data() for pair in data_list)))
     # Only the first source sample is inspected; the list is assumed to be
     # homogeneous with respect to ``pair_ind``.
     has_pair_ind = hasattr(sources[0], 'pair_ind')
     pair_ind = (
         concatenate_pair_ind(sources, targets).to(torch.long) if has_pair_ind else None
     )
     source_batch = MultiScaleBatch.from_data_list(sources)
     target_batch = MultiScaleBatch.from_data_list(targets)
     paired = PairMultiScaleBatch.make_pair(source_batch, target_batch)
     paired.pair_ind = pair_ind
     return paired.contiguous()
コード例 #3
0
ファイル: base_dataset.py プロジェクト: RJ2019/torch-points3d
    def _get_collate_function(conv_type, is_multiscale):
        """Return the collate function used to merge samples into a batch.

        Args:
            conv_type: convolution format name (e.g. ``"partial_dense"``). In
                the multiscale case it is compared case-insensitively against
                ``ConvolutionFormat.PARTIAL_DENSE``; otherwise it is classified
                by ``ConvolutionFormatFactory.check_is_dense_format``.
            is_multiscale: whether samples carry multiscale data. Only the
                partial-dense format supports multiscale batching.

        Returns:
            A callable that takes a list of samples and returns a batch:
            ``MultiScaleBatch.from_data_list`` (multiscale),
            ``SimpleBatch.from_data_list`` (dense formats) or
            ``torch_geometric.data.batch.Batch.from_data_list`` (otherwise).
            These callables are returned directly rather than wrapped in local
            lambdas. Local lambdas cannot be pickled, which breaks
            ``DataLoader`` worker processes started with the ``spawn`` method.

        Raises:
            NotImplementedError: if ``is_multiscale`` is set and ``conv_type``
                is not the partial-dense format.
        """
        if is_multiscale:
            if conv_type.lower() != ConvolutionFormat.PARTIAL_DENSE.value.lower():
                raise NotImplementedError(
                    "MultiscaleTransform is activated and supported only for partial_dense format"
                )
            return MultiScaleBatch.from_data_list

        if ConvolutionFormatFactory.check_is_dense_format(conv_type):
            return SimpleBatch.from_data_list
        return torch_geometric.data.batch.Batch.from_data_list
コード例 #4
0
    def test_batch(self):
        """Check that MultiScaleBatch batches top-level and per-scale data."""

        def build_sample(value):
            # pos aliases x at the top level and first scale; the second scale
            # holds x and pos multiplied by 4.
            feat = torch.tensor([value])
            scales = [Data(x=feat, pos=feat), Data(x=4 * feat, pos=4 * feat)]
            return MultiScaleData(x=feat, pos=feat, multiscale=scales)

        batch = MultiScaleBatch.from_data_list([build_sample(1), build_sample(2)])

        # Top-level features are concatenated; each sample gets its own batch id.
        tt.assert_allclose(batch.x, torch.tensor([1, 2]))
        tt.assert_allclose(batch.batch, torch.tensor([0, 1]))

        # Each scale is batched on its own, with its own batch vector.
        scales = batch.multiscale
        tt.assert_allclose(scales[0].batch, torch.tensor([0, 1]))
        tt.assert_allclose(scales[1].batch, torch.tensor([0, 1]))
        tt.assert_allclose(scales[1].x, torch.tensor([4, 8]))
コード例 #5
0
 def __getitem__(self, index):
     """Collate the whole ``self.datalist`` into a single batch.

     ``index`` is ignored: every lookup collates the same ``self.datalist``.
     ``MultiScaleBatch`` is used when ``self._ms_transform`` is truthy, and
     ``Batch`` otherwise.
     """
     collator = MultiScaleBatch if self._ms_transform else Batch
     return collator.from_data_list(self.datalist)