예제 #1
0
 def get_sample(batch, key, index, conv_type):
     """Return the part of ``batch[key]`` selected for sample ``index``.

     When ``conv_type`` is reported as a dense format by
     ``ConvolutionFormatFactory.check_is_dense_format``, the sample is picked
     by direct indexing (``batch[key][index]``). Otherwise it is picked with
     the mask ``batch.batch == index``.

     ``batch`` must expose ``key`` as an attribute; this is checked with an
     ``assert`` before anything else happens.
     """
     assert hasattr(batch, key)
     # Handle the non-dense layout first so the dense case is the fall-through.
     if not ConvolutionFormatFactory.check_is_dense_format(conv_type):
         return batch[key][batch.batch == index]
     return batch[key][index]
예제 #2
0
    def __init__(self, opt, type, dataset, modules_lib):
        """Build the model from ``opt``.

        ``opt`` is forwarded to the parent constructor and, after being
        flattened with ``flatten_dict``, expanded as keyword arguments for
        ``PointNetSeg``. The ``type``, ``dataset`` and ``modules_lib``
        arguments are accepted but not used in this constructor body.
        """
        super().__init__(opt)

        flat_options = flatten_dict(opt)
        self.pointnet_seg = PointNetSeg(**flat_options)

        # Remember whether this model's convolution format is a dense one.
        conv_type = self.conv_type
        self._is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type)

        self.visual_names = ["data_visual"]
예제 #3
0
    def __init__(self, opt, model_type=None, dataset=None, modules=None):
        """Build the model from ``opt``.

        ``opt`` is forwarded to the parent constructor and a container copy
        produced by ``OmegaConf.to_container`` is kept on ``self._opt``. The
        ``model_type``, ``dataset`` and ``modules`` arguments are accepted
        but not used in this constructor body.
        """
        super().__init__(opt)

        self._opt = OmegaConf.to_container(opt)

        # Remember whether this model's convolution format is a dense one.
        conv_type = self.conv_type
        self._is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type)

        self.visual_names = ["data_visual"]
예제 #4
0
 def _get_collate_function(conv_type, is_multiscale):
     """Choose the ``from_data_list`` constructor used to collate samples.

     * ``is_multiscale`` truthy: returns ``MultiScaleBatch.from_data_list``,
       but only when ``conv_type`` equals the partial-dense format name
       (compared case-insensitively).
     * otherwise: returns ``SimpleBatch.from_data_list`` for a dense
       ``conv_type`` and ``torch_geometric``'s ``Batch.from_data_list`` for
       any other format.

     Raises:
         NotImplementedError: if ``is_multiscale`` is truthy and ``conv_type``
             is not the partial-dense format.
     """
     if is_multiscale:
         requested = conv_type.lower()
         supported = ConvolutionFormat.PARTIAL_DENSE.value.lower()
         if requested != supported:
             raise NotImplementedError(
                 "MultiscaleTransform is activated and supported only for partial_dense format"
             )
         return MultiScaleBatch.from_data_list

     # The dense check is only reached on the single-scale path.
     if ConvolutionFormatFactory.check_is_dense_format(conv_type):
         return SimpleBatch.from_data_list
     return torch_geometric.data.batch.Batch.from_data_list
 def _get_collate_function(conv_type, is_multiscale, pre_collate_transform=None):
     """Build the collate callable for the given convolution format.

     A ``from_data_list`` constructor is selected from ``conv_type`` and
     ``is_multiscale``. It is then bound, together with
     ``pre_collate_transform``, onto ``BaseDataset._collate_fn`` through
     ``functools.partial``.

     Raises:
         NotImplementedError: if ``is_multiscale`` is truthy and ``conv_type``
             is not the partial-dense format (compared case-insensitively).
     """
     # Evaluated up front, before the multiscale branch is considered.
     dense = ConvolutionFormatFactory.check_is_dense_format(conv_type)

     if not is_multiscale:
         batch_builder = (
             SimpleBatch.from_data_list
             if dense
             else torch_geometric.data.batch.Batch.from_data_list
         )
     elif conv_type.lower() == ConvolutionFormat.PARTIAL_DENSE.value.lower():
         batch_builder = MultiScaleBatch.from_data_list
     else:
         raise NotImplementedError(
             "MultiscaleTransform is activated and supported only for partial_dense format"
         )

     return partial(
         BaseDataset._collate_fn,
         collate_fn=batch_builder,
         pre_collate_transform=pre_collate_transform,
     )
예제 #6
0
    def _get_collate_function(conv_type, is_multiscale):
        """Choose the ``from_data_list`` constructor used to collate pairs.

        * ``is_multiscale`` truthy: returns
          ``PairMultiScaleBatch.from_data_list``, but only when ``conv_type``
          equals the partial-dense format name (compared case-insensitively).
        * otherwise: returns ``DensePairBatch.from_data_list`` for a dense
          ``conv_type`` and ``PairBatch.from_data_list`` for any other format.

        The constructors are returned directly instead of being wrapped in
        single-argument lambdas. A lambda adds nothing here and cannot be
        pickled, which matters if the returned callable is ever handed to
        another process (e.g. as a data-loader collate function with spawned
        workers -- NOTE(review): the call site is not visible here).

        Raises:
            NotImplementedError: if ``is_multiscale`` is truthy and
                ``conv_type`` is not the partial-dense format.
        """
        # Evaluated up front, before the multiscale branch is considered.
        is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type)

        if is_multiscale:
            if conv_type.lower() != ConvolutionFormat.PARTIAL_DENSE.value.lower():
                raise NotImplementedError(
                    "MultiscaleTransform is activated and supported only for partial_dense format"
                )
            return PairMultiScaleBatch.from_data_list

        if is_dense:
            return DensePairBatch.from_data_list
        return PairBatch.from_data_list
예제 #7
0
 def get_num_samples(batch, conv_type):
     """Return the sample count of ``batch`` for the given convolution format.

     For a dense ``conv_type`` this is the size of the first dimension of
     ``batch.pos``. For any other format it is the largest value found in
     ``batch.batch`` plus one.
     """
     if not ConvolutionFormatFactory.check_is_dense_format(conv_type):
         # Non-dense layout: derive the count from the largest entry.
         return batch.batch.max() + 1
     return batch.pos.shape[0]