Example No. 1
def initialize_data_loader(DatasetClass,
                           config,
                           phase,
                           threads,
                           shuffle,
                           repeat,
                           augment_data,
                           batch_size,
                           limit_numpoints,
                           elastic_distortion=False,
                           input_transform=None,
                           target_transform=None):
  if isinstance(phase, str):
    phase = str2datasetphase_type(phase)

  if config.return_transformation:
    collate_fn = t.cflt_collate_fn_factory(limit_numpoints)
  else:
    collate_fn = t.cfl_collate_fn_factory(limit_numpoints)

  input_transforms = []
  if input_transform is not None:
    input_transforms += input_transform

  if augment_data:
    input_transforms += [
        t.RandomDropout(0.2),
        t.RandomHorizontalFlip(DatasetClass.ROTATION_AXIS, DatasetClass.IS_TEMPORAL),
        t.ChromaticAutoContrast(),
        t.ChromaticTranslation(config.data_aug_color_trans_ratio),
        t.ChromaticJitter(config.data_aug_color_jitter_std),
        # t.HueSaturationTranslation(config.data_aug_hue_max, config.data_aug_saturation_max),
    ]

  if len(input_transforms) > 0:
    input_transforms = t.Compose(input_transforms)
  else:
    input_transforms = None

  dataset = DatasetClass(
      config,
      input_transform=input_transforms,
      target_transform=target_transform,
      cache=config.cache_data,
      augment_data=augment_data,
      elastic_distortion=elastic_distortion,
      phase=phase)

  if repeat:
    # Use the inf random sampler
    data_loader = DataLoader(
        dataset=dataset,
        num_workers=threads,
        batch_size=batch_size,
        collate_fn=collate_fn,
        worker_init_fn=_init_fn,
        pin_memory=True,
        sampler=InfSampler(dataset, shuffle))
  else:
    # Default shuffle=False
    data_loader = DataLoader(
        dataset=dataset,
        num_workers=threads,
        batch_size=batch_size,
        collate_fn=collate_fn,
        worker_init_fn=_init_fn,
        pin_memory=True,
        shuffle=shuffle)

  return data_loader
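A minimal usage sketch for the helper above follows; it mirrors the training-loader call made in Example No. 4, and the config attributes (train_phase, threads, batch_size, train_limit_numpoints) are assumed to exist on the passed-in config object.

# Hedged usage sketch: building a training loader with initialize_data_loader.
# DatasetClass and the config attributes are assumptions borrowed from Example No. 4.
train_data_loader = initialize_data_loader(
    DatasetClass,
    config,
    phase=config.train_phase,
    threads=config.threads,
    shuffle=True,
    repeat=True,
    augment_data=True,
    batch_size=config.batch_size,
    limit_numpoints=config.train_limit_numpoints)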
Example No. 2
def initialize_data_loader(DatasetClass,
                           config,
                           phase,
                           threads,
                           shuffle,
                           repeat,
                           augment_data,
                           batch_size,
                           limit_numpoints,
                           input_transform=None,
                           target_transform=None):
    if isinstance(phase, str):
        phase = str2datasetphase_type(phase)

    if config.return_transformation:
        collate_fn = cflt_collate_fn_factory(DatasetClass.IS_ROTATION_BBOX,
                                             limit_numpoints, config)
    else:
        collate_fn = cfl_collate_fn_factory(DatasetClass.IS_ROTATION_BBOX,
                                            limit_numpoints, config)

    input_transforms = []
    if input_transform is not None:
        input_transforms += input_transform

    if augment_data:
        input_transforms += [
            t.RandomHorizontalFlip(DatasetClass.ROTATION_AXIS,
                                   DatasetClass.IS_TEMPORAL),
            t.HeightTranslation(config.data_aug_height_trans_std),
            t.HeightJitter(config.data_aug_height_jitter_std),
        ]
        if DatasetClass.USE_RGB:
            input_transforms += [
                t.ChromaticTranslation(config.data_aug_color_trans_ratio),
                t.ChromaticJitter(config.data_aug_color_jitter_std),
            ]

    if len(input_transforms) > 0:
        input_transforms = t.Compose(input_transforms)
    else:
        input_transforms = None

    dataset = DatasetClass(config,
                           input_transform=input_transforms,
                           target_transform=target_transform,
                           cache=config.cache_data,
                           augment_data=augment_data,
                           phase=phase)

    if repeat:
        # Use the inf random sampler
        data_loader = DataLoader(dataset=dataset,
                                 num_workers=threads,
                                 batch_size=batch_size,
                                 collate_fn=collate_fn,
                                 sampler=InfSampler(dataset, shuffle))
    else:
        # Default shuffle=False
        data_loader = DataLoader(dataset=dataset,
                                 num_workers=threads,
                                 batch_size=batch_size,
                                 collate_fn=collate_fn,
                                 shuffle=shuffle)

    return data_loader
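Both examples hand an InfSampler to the DataLoader when repeat is requested, so iteration never stops at the end of an epoch. InfSampler's definition is not reproduced in these examples; the snippet below is only a hedged sketch of its likely shape, assuming it endlessly yields dataset indices and reshuffles them each pass when shuffle=True.

# Hedged sketch of an infinite sampler in the spirit of InfSampler; the real
# implementation lives elsewhere in the repository and may differ.
import torch
from torch.utils.data import Sampler


class InfiniteRandomSampler(Sampler):
    """Yields dataset indices forever, reshuffling each pass when shuffle=True."""

    def __init__(self, dataset, shuffle=False):
        self.num_samples = len(dataset)
        self.shuffle = shuffle

    def __iter__(self):
        while True:
            # One full (optionally shuffled) pass over the dataset, then start over.
            order = (torch.randperm(self.num_samples) if self.shuffle
                     else torch.arange(self.num_samples))
            for idx in order.tolist():
                yield idx

    def __len__(self):
        return self.num_samples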
Example No. 3
def initialize_data_loader(DatasetClass,
                           config,
                           phase,
                           num_workers,
                           shuffle,
                           repeat,
                           augment_data,
                           batch_size,
                           limit_numpoints,
                           input_transform=None,
                           target_transform=None):
  if isinstance(phase, str):
    phase = str2datasetphase_type(phase)

  if config.return_transformation:
    collate_fn = t.cflt_collate_fn_factory(limit_numpoints)
  else:
    collate_fn = t.cfl_collate_fn_factory(limit_numpoints)

  prevoxel_transform_train = []
  if augment_data:
    prevoxel_transform_train.append(t.ElasticDistortion(DatasetClass.ELASTIC_DISTORT_PARAMS))

  if len(prevoxel_transform_train) > 0:
    prevoxel_transforms = t.Compose(prevoxel_transform_train)
  else:
    prevoxel_transforms = None

  input_transforms = []
  if input_transform is not None:
    input_transforms += input_transform

  if augment_data:
    input_transforms += [
        t.RandomDropout(0.2),
        t.RandomHorizontalFlip(DatasetClass.ROTATION_AXIS, DatasetClass.IS_TEMPORAL),
        t.ChromaticAutoContrast(),
        t.ChromaticTranslation(config.data_aug_color_trans_ratio),
        t.ChromaticJitter(config.data_aug_color_jitter_std),
        # t.HueSaturationTranslation(config.data_aug_hue_max, config.data_aug_saturation_max),
    ]

  if len(input_transforms) > 0:
    input_transforms = t.Compose(input_transforms)
  else:
    input_transforms = None

  dataset = DatasetClass(
      config,
      prevoxel_transform=prevoxel_transforms,
      input_transform=input_transforms,
      target_transform=target_transform,
      cache=config.cache_data,
      augment_data=augment_data,
      phase=phase)

  data_args = {
      'dataset': dataset,
      'num_workers': num_workers,
      'batch_size': batch_size,
      'collate_fn': collate_fn,
  }

  if repeat:
    if get_world_size() > 1:
      data_args['sampler'] = DistributedInfSampler(dataset, shuffle=shuffle)  # torch.utils.data.distributed.DistributedSampler(dataset)
    else:
      data_args['sampler'] = InfSampler(dataset, shuffle)
  
  else:
    data_args['shuffle'] = shuffle

  data_loader = DataLoader(**data_args)

  return data_loader
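Each example obtains its collate_fn from a factory keyed by limit_numpoints. The factory body is defined in the transforms module and not shown here; the following is only a hedged sketch of its general shape, assuming each sample is a (coords, feats, labels) tuple of tensors and that trailing samples are dropped once a point budget is exceeded.

# Hedged sketch of a collate-fn factory in the spirit of t.cfl_collate_fn_factory;
# the tuple layout and clipping rule are assumptions, not the repository's code.
import torch


def collate_fn_factory(limit_numpoints):

  def collate_fn(list_data):
    coords_batch, feats_batch, labels_batch = [], [], []
    accm_points = 0
    for batch_id, (coords, feats, labels) in enumerate(list_data):
      num_points = coords.shape[0]
      # When limit_numpoints is set, stop once the cumulative count would exceed it.
      if limit_numpoints and accm_points + num_points > limit_numpoints:
        break
      accm_points += num_points
      # Prepend the batch index so downstream sparse ops can tell samples apart.
      batch_idx = torch.full((num_points, 1), batch_id, dtype=torch.int32)
      coords_batch.append(torch.cat([batch_idx, coords.int()], dim=1))
      feats_batch.append(feats.float())
      labels_batch.append(labels.long())
    return (torch.cat(coords_batch, 0),
            torch.cat(feats_batch, 0),
            torch.cat(labels_batch, 0))

  return collate_fn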
Example No. 4
def main():
    config = get_config()
    ch = logging.StreamHandler(sys.stdout)
    logging.getLogger().setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(
        os.path.join(config.log_dir, 'model.log'))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logging.basicConfig(format=os.uname()[1].split('.')[0] +
                        ' %(asctime)s %(message)s',
                        datefmt='%m/%d %H:%M:%S',
                        handlers=[ch, file_handler])

    if config.test_config:
        # test_config reloads and overwrites the config below, so stash the values that must survive
        val_bs = config.val_batch_size
        is_export = config.is_export

        json_config = json.load(open(config.test_config, 'r'))
        json_config['is_train'] = False
        json_config['weights'] = config.weights
        json_config['multiprocess'] = False
        json_config['log_dir'] = config.log_dir
        json_config['val_threads'] = config.val_threads
        json_config['submit'] = config.submit
        config = edict(json_config)

        config.val_batch_size = val_bs
        config.is_export = is_export
        config.is_train = False
        sys.path.append(config.log_dir)
        # from local_models import load_model
    else:
        '''backup files'''
        if not os.path.exists(os.path.join(config.log_dir, 'models')):
            os.mkdir(os.path.join(config.log_dir, 'models'))
        for filename in os.listdir('./models'):
            if ".py" in filename:  # donnot cp the init file since it will raise import error
                shutil.copy(os.path.join("./models", filename),
                            os.path.join(config.log_dir, 'models'))
            elif 'modules' in filename:
                # also copy the modules folder
                if os.path.exists(
                        os.path.join(config.log_dir, 'models/modules')):
                    shutil.rmtree(
                        os.path.join(config.log_dir, 'models/modules'))
                shutil.copytree(os.path.join('./models', filename),
                                os.path.join(config.log_dir, 'models/modules'))

        shutil.copy('./main.py', config.log_dir)
        shutil.copy('./config.py', config.log_dir)
        shutil.copy('./lib/train.py', config.log_dir)
        shutil.copy('./lib/test.py', config.log_dir)

    if config.resume == 'True':
        new_iter_size = config.max_iter
        new_bs = config.batch_size
        config.resume = config.log_dir
        json_config = json.load(open(config.resume + '/config.json', 'r'))
        json_config['resume'] = config.resume
        config = edict(json_config)
        config.weights = os.path.join(
            config.log_dir, 'weights.pth')  # use the pre-trained weights
        logging.info('==== Resuming: previous max_iter {}, new max_iter {} ======'.format(
            config.max_iter, new_iter_size))
        config.max_iter = new_iter_size
        config.batch_size = new_bs
    else:
        config.resume = None

    if config.is_cuda and not torch.cuda.is_available():
        raise Exception("No GPU found")
    gpu_list = range(config.num_gpu)
    device = get_torch_device(config.is_cuda)

    # torch.set_num_threads(config.threads)
    # torch.manual_seed(config.seed)
    # if config.is_cuda:
    #       torch.cuda.manual_seed(config.seed)

    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('      {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    logging.info('===> Initializing dataloader')

    setup_seed(2021)
    """
    ---- Setting up train, val, test dataloaders ----
    Supported datasets:
    - ScannetSparseVoxelizationDataset
    - ScannetDataset
    - SemanticKITTI
    """

    point_scannet = False
    if config.is_train:

        if config.dataset == 'ScannetSparseVoxelizationDataset':
            point_scannet = False
            train_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                phase=config.train_phase,
                threads=config.threads,
                augment_data=True,
                elastic_distortion=config.train_elastic_distortion,
                shuffle=True,
                # shuffle=False,   # DEBUG ONLY!!!
                repeat=True,
                # repeat=False,
                batch_size=config.batch_size,
                limit_numpoints=config.train_limit_numpoints)

            val_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                threads=config.val_threads,
                phase=config.val_phase,
                augment_data=False,
                elastic_distortion=config.test_elastic_distortion,
                shuffle=False,
                repeat=False,
                batch_size=config.val_batch_size,
                limit_numpoints=False)

        elif config.dataset == 'ScannetDataset':
            val_DatasetClass = load_dataset(
                'ScannetDatasetWholeScene_evaluation')
            point_scannet = True

            # collate_fn = t.cfl_collate_fn_factory(False) # no limit num-points
            trainset = DatasetClass(
                root='/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                npoints=config.num_points,
                # split='debug',
                split='train',
                with_norm=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                dataset=trainset,
                num_workers=config.threads,
                # num_workers=0,  # for loading big pth file, should use single-thread
                batch_size=config.batch_size,
                # collate_fn=collate_fn, # input points, should not have collate-fn
                worker_init_fn=_init_fn,
                sampler=InfSampler(trainset, True))  # shuffle=True

            valset = val_DatasetClass(
                root='/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                scene_list_dir='/data/eva_share_users/zhaotianchen/scannet/raw/metadata',
                # split='debug',
                split='eval',
                block_points=config.num_points,
                with_norm=False,
                delta=1.0,
            )
            val_data_loader = torch.utils.data.DataLoader(
                dataset=valset,
                # num_workers=config.threads,
                num_workers=0,  # loading the big pth file works best single-threaded
                batch_size=config.val_batch_size,
                # collate_fn=collate_fn, # input points, should not have collate-fn
                worker_init_fn=_init_fn)

        elif config.dataset == "SemanticKITTI":
            point_scannet = False
            dataset = SemanticKITTI(root=config.semantic_kitti_path,
                                    num_points=None,
                                    voxel_size=config.voxel_size,
                                    sample_stride=config.sample_stride,
                                    submit=False)
            collate_fn_factory = t.cfl_collate_fn_factory
            train_data_loader = torch.utils.data.DataLoader(
                dataset['train'],
                batch_size=config.batch_size,
                sampler=InfSampler(dataset['train'],
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=collate_fn_factory(config.train_limit_numpoints))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                dataset['test'],
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
        elif config.dataset == "S3DIS":
            trainset = S3DIS(
                config,
                train=True,
            )
            valset = S3DIS(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset,
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(
                    config.train_limit_numpoints))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
        elif config.dataset == 'Nuscenes':
            config.xyz_input = False
            # todo:
            trainset = Nuscenes(
                config,
                train=True,
            )
            valset = Nuscenes(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset,
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,    # used when cylinder voxelize
                collate_fn=t.cfl_collate_fn_factory(False))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,
                collate_fn=t.cfl_collate_fn_factory(False))
        else:
            print('Dataset {} not supported'.format(config.dataset))
            raise NotImplementedError

        # Setting up num_in_channel and num_labels
        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3

        num_labels = train_data_loader.dataset.NUM_LABELS

        # it = iter(train_data_loader)
        # for _ in range(100):
        # data = it.__next__()
        # print(data)

    else:  # not config.is_train

        val_DatasetClass = load_dataset('ScannetDatasetWholeScene_evaluation')

        if config.dataset == 'ScannetSparseVoxelizationDataset':

            if config.is_export:  # when export, we need to export the train results too
                train_data_loader = initialize_data_loader(
                    DatasetClass,
                    config,
                    phase=config.train_phase,
                    threads=config.threads,
                    augment_data=True,
                    elastic_distortion=config.train_elastic_distortion,  # DEBUG: not sure about this
                    shuffle=False,
                    repeat=False,
                    batch_size=config.batch_size,
                    limit_numpoints=config.train_limit_numpoints)

                # validation-style loader, no data augmentation
                # train_data_loader = initialize_data_loader(
                # DatasetClass,
                # config,
                # threads=config.val_threads,
                # phase=config.train_phase,
                # augment_data=False,
                # elastic_distortion=config.test_elastic_distortion,
                # shuffle=False,
                # repeat=False,
                # batch_size=config.val_batch_size,
                # limit_numpoints=False)

            val_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                threads=config.val_threads,
                phase=config.val_phase,
                augment_data=False,
                elastic_distortion=config.test_elastic_distortion,
                shuffle=False,
                repeat=False,
                batch_size=config.val_batch_size,
                limit_numpoints=False)

            if val_data_loader.dataset.NUM_IN_CHANNEL is not None:
                num_in_channel = val_data_loader.dataset.NUM_IN_CHANNEL
            else:
                num_in_channel = 3

            num_labels = val_data_loader.dataset.NUM_LABELS

        elif config.dataset == 'ScannetDataset':
            '''when using scannet-point, use val instead of test'''

            point_scannet = True
            valset = val_DatasetClass(
                root='/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                scene_list_dir='/data/eva_share_users/zhaotianchen/scannet/raw/metadata',
                split='eval',
                block_points=config.num_points,
                delta=1.0,
                with_norm=False,
            )
            val_data_loader = torch.utils.data.DataLoader(
                dataset=valset,
                # num_workers=config.threads,
                num_workers=0,  # loading the big pth file works best single-threaded
                batch_size=config.val_batch_size,
                # collate_fn=collate_fn, # input points, should not have collate-fn
                worker_init_fn=_init_fn,
            )

            num_labels = val_data_loader.dataset.NUM_LABELS
            num_in_channel = 3

        elif config.dataset == "SemanticKITTI":
            dataset = SemanticKITTI(root=config.semantic_kitti_path,
                                    num_points=None,
                                    voxel_size=config.voxel_size,
                                    submit=config.submit)
            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                dataset['test'],
                batch_size=config.val_batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 4
            num_labels = 19

        elif config.dataset == 'S3DIS':
            config.xyz_input = False

            trainset = S3DIS(
                config,
                train=True,
            )
            valset = S3DIS(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset,
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(
                    config.train_limit_numpoints))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 9
            num_labels = 13
        elif config.dataset == 'Nuscenes':
            config.xyz_input = False
            trainset = Nuscenes(
                config,
                train=True,
            )
            valset = Nuscenes(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset,
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,
                collate_fn=t.cfl_collate_fn_factory(False))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 5
            num_labels = 16
        else:
            print('Dataset {} not supported'.format(config.dataset))
            raise NotImplementedError

    logging.info('===> Building model')

    # if config.model == 'PointTransformer' or config.model == 'MixedTransformer':
    if config.model == 'PointTransformer':
        config.pure_point = True

    NetClass = load_model(config.model)
    if config.pure_point:
        model = NetClass(config,
                         num_class=num_labels,
                         N=config.num_points,
                         normal_channel=num_in_channel)
    else:
        if config.model == 'MixedTransformer':
            model = NetClass(config,
                             num_class=num_labels,
                             N=config.num_points,
                             normal_channel=num_in_channel)
        elif config.model == 'MinkowskiVoxelTransformer':
            model = NetClass(config, num_in_channel, num_labels)
        elif config.model == 'MinkowskiTransformerNet':
            model = NetClass(config, num_in_channel, num_labels)
        elif "Res" in config.model:
            model = NetClass(num_in_channel, num_labels, config)
        else:
            model = NetClass(num_in_channel, num_labels, config)

    logging.info('===> Number of trainable parameters: {}: {}M'.format(
        NetClass.__name__,
        count_parameters(model) / 1e6))
    if hasattr(model, "block1"):
        if hasattr(model.block1[0], 'h'):
            h = model.block1[0].h
            vec_dim = model.block1[0].vec_dim
        else:
            h = None
            vec_dim = None
    else:
        h = None
        vec_dim = None
    # logging.info('===> Model Args:\n PLANES: {} \n LAYERS: {}\n HEADS: {}\n Vec-dim: {}\n'.format(model.PLANES, model.LAYERS, h, vec_dim))
    logging.info(model)

    # Set the number of threads
    # ME.initialize_nthreads(12, D=3)

    model = model.to(device)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        # delete the keys containing '_map' since they raise a size mismatch on load
        d_ = {
            k: v
            for k, v in state['state_dict'].items() if '_map' not in k
        }  # debug: sometimes the model contains 'map_qk', which is a poor name for a module since '_map' tensors are always buffers
        d = {}
        for k in d_.keys():
            if 'module.' in k:
                d[k.replace('module.', '')] = d_[k]
            else:
                d[k] = d_[k]
        # del d_

        if config.weights_for_inner_model:
            model.model.load_state_dict(d)
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(d, strict=True)

    if config.is_debug:
        check_data(model, train_data_loader, val_data_loader, config)
        return None
    elif config.is_train:
        if hasattr(config, 'distill') and config.distill:
            assert not point_scannet  # distillation only supports the whole-scene pipeline for now
            train_distill(model, train_data_loader, val_data_loader, config)
        if config.multiprocess:
            if point_scannet:
                raise NotImplementedError
            else:
                train_mp(NetClass, train_data_loader, val_data_loader, config)
        else:
            if point_scannet:
                train_point(model, train_data_loader, val_data_loader, config)
            else:
                train(model, train_data_loader, val_data_loader, config)
    elif config.is_export:
        if point_scannet:
            raise NotImplementedError
        else:  # only support the whole-scene-style for now
            test(model,
                 train_data_loader,
                 config,
                 save_pred=True,
                 split='train')
            test(model, val_data_loader, config, save_pred=True, split='val')
    else:
        assert not config.multiprocess
        # if testing for submission, make a submit directory under the current directory
        submit_dir = os.path.join(os.getcwd(), 'submit', 'sequences')
        if config.submit and not os.path.exists(submit_dir):
            os.makedirs(submit_dir)
            print("Made submission directory: " + submit_dir)
        if point_scannet:
            test_points(model, val_data_loader, config)
        else:
            test(model, val_data_loader, config, submit_dir=submit_dir)
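Example No. 4 calls setup_seed and passes _init_fn as the DataLoader worker_init_fn without showing either helper. The snippet below is a minimal, hedged sketch of how such reproducibility helpers are commonly written; the real definitions in this codebase may differ.

# Hedged sketch of the seeding helpers referenced in main(); bodies are assumptions.
import random

import numpy as np
import torch


def setup_seed(seed):
    # Seed the Python, NumPy and PyTorch (CPU + CUDA) RNGs for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _init_fn(worker_id):
    # Give every DataLoader worker a distinct but deterministic seed.
    np.random.seed(2021 + worker_id)
    random.seed(2021 + worker_id)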