def main(): config = get_config() if config.resume: json_config = json.load(open(config.resume + '/config.json', 'r')) json_config['resume'] = config.resume config = edict(json_config) if config.is_cuda and not torch.cuda.is_available(): raise Exception("No GPU found") device = get_torch_device(config.is_cuda) logging.info('===> Configurations') dconfig = vars(config) for k in dconfig: logging.info(' {}: {}'.format(k, dconfig[k])) DatasetClass = load_dataset(config.dataset) if config.test_original_pointcloud: if not DatasetClass.IS_FULL_POINTCLOUD_EVAL: raise ValueError( 'This dataset does not support full pointcloud evaluation.') if config.evaluate_original_pointcloud: if not config.return_transformation: raise ValueError( 'Pointcloud evaluation requires config.return_transformation=true.' ) if (config.return_transformation ^ config.evaluate_original_pointcloud): raise ValueError( 'Rotation evaluation requires config.evaluate_original_pointcloud=true and ' 'config.return_transformation=true.') logging.info('===> Initializing dataloader') if config.is_train: train_data_loader = initialize_data_loader( DatasetClass, config, phase=config.train_phase, threads=config.threads, augment_data=True, shuffle=True, repeat=True, batch_size=config.batch_size, limit_numpoints=config.train_limit_numpoints) val_data_loader = initialize_data_loader( DatasetClass, config, threads=config.val_threads, phase=config.val_phase, augment_data=False, shuffle=True, repeat=False, batch_size=config.val_batch_size, limit_numpoints=False) if train_data_loader.dataset.NUM_IN_CHANNEL is not None: num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL else: num_in_channel = 3 # RGB color num_labels = train_data_loader.dataset.NUM_LABELS else: test_data_loader = initialize_data_loader( DatasetClass, config, threads=config.threads, phase=config.test_phase, augment_data=False, shuffle=False, repeat=False, batch_size=config.test_batch_size, limit_numpoints=False) if test_data_loader.dataset.NUM_IN_CHANNEL is not None: num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL else: num_in_channel = 3 # RGB color num_labels = test_data_loader.dataset.NUM_LABELS logging.info('===> Building model') NetClass = load_model(config.model) if config.wrapper_type == 'None': model = NetClass(num_in_channel, num_labels, config) logging.info('===> Number of trainable parameters: {}: {}'.format( NetClass.__name__, count_parameters(model))) else: wrapper = load_wrapper(config.wrapper_type) model = wrapper(NetClass, num_in_channel, num_labels, config) logging.info('===> Number of trainable parameters: {}: {}'.format( wrapper.__name__ + NetClass.__name__, count_parameters(model))) logging.info(model) model = model.to(device) if config.weights == 'modelzoo': # Load modelzoo weights if possible. logging.info('===> Loading modelzoo weights') model.preload_modelzoo() # Load weights if specified by the parameter. elif config.weights.lower() != 'none': logging.info('===> Loading weights: ' + config.weights) state = torch.load(config.weights) if config.weights_for_inner_model: model.model.load_state_dict(state['state_dict']) else: if config.lenient_weight_loading: matched_weights = load_state_with_same_shape( model, state['state_dict']) model_dict = model.state_dict() model_dict.update(matched_weights) model.load_state_dict(model_dict) else: model.load_state_dict(state['state_dict']) if config.is_train: train(model, train_data_loader, val_data_loader, config) else: test(model, test_data_loader, config)
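Several of these entry points call a lenient-loading helper, load_state_with_same_shape, that is not defined in this excerpt. The sketch below is a plausible implementation under the usual convention (keep only checkpoint tensors whose names and shapes match the model); it is an assumption, and the repo's actual helper may differ.

import logging


def load_state_with_same_shape(model, weights):
    # hypothetical sketch of the lenient-loading helper used with
    # config.lenient_weight_loading; not taken from the original code
    model_state = model.state_dict()
    # keep only checkpoint entries whose name exists in the model and whose shape matches
    filtered = {
        k: v for k, v in weights.items()
        if k in model_state and v.size() == model_state[k].size()
    }
    logging.info('Matched weights: {}'.format(list(filtered.keys())))
    return filtered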
def main(config, init_distributed=False): if not torch.cuda.is_available(): raise Exception('No GPUs FOUND.') # setup initial seed torch.cuda.set_device(config.device_id) torch.manual_seed(config.seed) torch.cuda.manual_seed(config.seed) device = config.device_id distributed = config.distributed_world_size > 1 if init_distributed: config.distributed_rank = distributed_utils.distributed_init(config) setup_logging(config) logging.info('===> Configurations') dconfig = vars(config) for k in dconfig: logging.info(' {}: {}'.format(k, dconfig[k])) DatasetClass = load_dataset(config.dataset) if config.test_original_pointcloud: if not DatasetClass.IS_FULL_POINTCLOUD_EVAL: raise ValueError( 'This dataset does not support full pointcloud evaluation.') if config.evaluate_original_pointcloud: if not config.return_transformation: raise ValueError( 'Pointcloud evaluation requires config.return_transformation=true.' ) if (config.return_transformation ^ config.evaluate_original_pointcloud): raise ValueError( 'Rotation evaluation requires config.evaluate_original_pointcloud=true and ' 'config.return_transformation=true.') logging.info('===> Initializing dataloader') if config.is_train: train_data_loader = initialize_data_loader( DatasetClass, config, phase=config.train_phase, num_workers=config.num_workers, augment_data=True, shuffle=True, repeat=True, batch_size=config.batch_size, limit_numpoints=config.train_limit_numpoints) val_data_loader = initialize_data_loader( DatasetClass, config, num_workers=config.num_val_workers, phase=config.val_phase, augment_data=False, shuffle=True, repeat=False, batch_size=config.val_batch_size, limit_numpoints=False) if train_data_loader.dataset.NUM_IN_CHANNEL is not None: num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL else: num_in_channel = 3 # RGB color num_labels = train_data_loader.dataset.NUM_LABELS else: test_data_loader = initialize_data_loader( DatasetClass, config, num_workers=config.num_workers, phase=config.test_phase, augment_data=False, shuffle=False, repeat=False, batch_size=config.test_batch_size, limit_numpoints=False) if test_data_loader.dataset.NUM_IN_CHANNEL is not None: num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL else: num_in_channel = 3 # RGB color num_labels = test_data_loader.dataset.NUM_LABELS logging.info('===> Building model') NetClass = load_model(config.model) if config.wrapper_type == 'None': model = NetClass(num_in_channel, num_labels, config) logging.info('===> Number of trainable parameters: {}: {}'.format( NetClass.__name__, count_parameters(model))) else: wrapper = load_wrapper(config.wrapper_type) model = wrapper(NetClass, num_in_channel, num_labels, config) logging.info('===> Number of trainable parameters: {}: {}'.format( wrapper.__name__ + NetClass.__name__, count_parameters(model))) logging.info(model) if config.weights == 'modelzoo': # Load modelzoo weights if possible. logging.info('===> Loading modelzoo weights') model.preload_modelzoo() # Load weights if specified by the parameter. 
elif config.weights.lower() != 'none': logging.info('===> Loading weights: ' + config.weights) # state = torch.load(config.weights) state = torch.load( config.weights, map_location=lambda s, l: default_restore_location(s, 'cpu')) if config.weights_for_inner_model: model.model.load_state_dict(state['state_dict']) else: if config.lenient_weight_loading: matched_weights = load_state_with_same_shape( model, state['state_dict']) model_dict = model.state_dict() model_dict.update(matched_weights) model.load_state_dict(model_dict) else: model.load_state_dict(state['state_dict']) model = model.cuda() if distributed: model = torch.nn.parallel.DistributedDataParallel( module=model, device_ids=[device], output_device=device, broadcast_buffers=False, bucket_cap_mb=config.bucket_cap_mb) if config.is_train: train(model, train_data_loader, val_data_loader, config) else: test(model, test_data_loader, config)
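This variant expects to be invoked once per process, with config.device_id already set and init_distributed=True so that distributed_utils.distributed_init can assign the rank. The launcher below is only an illustrative sketch of how that could be wired up with torch.multiprocessing; the helper names (run_distributed, _distributed_worker) and the per-rank fields they fill in are assumptions, since the real launch code is not part of this excerpt.

import copy

import torch.multiprocessing as mp


def _distributed_worker(i, config):
    # one worker process per GPU; device_id / rank assignment here is an
    # illustrative assumption, not the original launch code
    config = copy.deepcopy(config)
    config.device_id = i
    config.distributed_rank = i
    main(config, init_distributed=True)


def run_distributed(config):
    # hypothetical launcher: spawn one process per device in the world
    mp.spawn(_distributed_worker,
             args=(config,),
             nprocs=config.distributed_world_size,
             join=True)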
def main(): config = get_config() if config.test_config: json_config = json.load(open(config.test_config, 'r')) json_config['is_train'] = False json_config['weights'] = config.weights config = edict(json_config) elif config.resume: json_config = json.load(open(config.resume + '/config.json', 'r')) json_config['resume'] = config.resume config = edict(json_config) if config.is_cuda and not torch.cuda.is_available(): raise Exception("No GPU found") device = get_torch_device(config.is_cuda) # torch.set_num_threads(config.threads) # torch.manual_seed(config.seed) # if config.is_cuda: # torch.cuda.manual_seed(config.seed) logging.info('===> Configurations') dconfig = vars(config) for k in dconfig: logging.info(' {}: {}'.format(k, dconfig[k])) DatasetClass = load_dataset(config.dataset) logging.info('===> Initializing dataloader') if config.is_train: setup_seed(2021) train_data_loader = initialize_data_loader( DatasetClass, config, phase=config.train_phase, # threads=config.threads, threads=4, augment_data=True, elastic_distortion=config.train_elastic_distortion, # elastic_distortion=False, # shuffle=True, shuffle=False, # repeat=True, repeat=False, batch_size=config.batch_size, # batch_size=8, limit_numpoints=config.train_limit_numpoints) # dat = iter(train_data_loader).__next__() # import ipdb; ipdb.set_trace() val_data_loader = initialize_data_loader( DatasetClass, config, # threads=0, threads=config.val_threads, phase=config.val_phase, augment_data=False, elastic_distortion=config.test_elastic_distortion, shuffle=False, repeat=False, # batch_size=config.val_batch_size, batch_size=8, limit_numpoints=False) # dat = iter(val_data_loader).__next__() # import ipdb; ipdb.set_trace() if train_data_loader.dataset.NUM_IN_CHANNEL is not None: num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL else: num_in_channel = 3 num_labels = train_data_loader.dataset.NUM_LABELS else: test_data_loader = initialize_data_loader( DatasetClass, config, threads=config.threads, phase=config.test_phase, augment_data=False, elastic_distortion=config.test_elastic_distortion, shuffle=False, repeat=False, batch_size=config.test_batch_size, limit_numpoints=False) if test_data_loader.dataset.NUM_IN_CHANNEL is not None: num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL else: num_in_channel = 3 num_labels = test_data_loader.dataset.NUM_LABELS logging.info('===> Building model') NetClass = load_model(config.model) model = NetClass(num_in_channel, num_labels, config) logging.info('===> Number of trainable parameters: {}: {}'.format( NetClass.__name__, count_parameters(model))) logging.info(model) # Set the number of threads # ME.initialize_nthreads(12, D=3) model = model.to(device) if config.weights == 'modelzoo': # Load modelzoo weights if possible. logging.info('===> Loading modelzoo weights') model.preload_modelzoo() # Load weights if specified by the parameter. elif config.weights.lower() != 'none': logging.info('===> Loading weights: ' + config.weights) state = torch.load(config.weights) if config.weights_for_inner_model: model.model.load_state_dict(state['state_dict']) else: if config.lenient_weight_loading: matched_weights = load_state_with_same_shape( model, state['state_dict']) model_dict = model.state_dict() model_dict.update(matched_weights) model.load_state_dict(model_dict) else: model.load_state_dict(state['state_dict']) if config.is_train: train(model, train_data_loader, val_data_loader, config) else: test(model, test_data_loader, config)
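The variant above calls setup_seed(2021) before building the training loader, but the helper is not defined in this excerpt. A typical implementation, shown here as an assumption (the project's own setup_seed may set additional flags), seeds Python, NumPy and PyTorch and makes cuDNN deterministic:

import random

import numpy as np
import torch


def setup_seed(seed):
    # plausible implementation of the seeding helper called above
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False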
def main_worker(gpu, ngpus_per_node, config): config.gpu = gpu #if config.is_cuda and not torch.cuda.is_available(): # raise Exception("No GPU found") if config.gpu is not None: print("Use GPU: {} for training".format(config.gpu)) device = get_torch_device(config.is_cuda) if config.distributed: if config.dist_url == "env://" and config.rank == -1: config.rank = int(os.environ["RANK"]) if config.multiprocessing_distributed: # For multiprocessing distributed training, rank needs to be the # global rank among all the processes config.rank = config.rank * ngpus_per_node + gpu dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url, world_size=config.world_size, rank=config.rank) logging.info('===> Configurations') dconfig = vars(config) for k in dconfig: logging.info(' {}: {}'.format(k, dconfig[k])) DatasetClass = load_dataset(config.dataset) if config.test_original_pointcloud: if not DatasetClass.IS_FULL_POINTCLOUD_EVAL: raise ValueError( 'This dataset does not support full pointcloud evaluation.') if config.evaluate_original_pointcloud: if not config.return_transformation: raise ValueError( 'Pointcloud evaluation requires config.return_transformation=true.' ) if (config.return_transformation ^ config.evaluate_original_pointcloud): raise ValueError( 'Rotation evaluation requires config.evaluate_original_pointcloud=true and ' 'config.return_transformation=true.') logging.info('===> Initializing dataloader') if config.is_train: train_data_loader, train_sampler = initialize_data_loader( DatasetClass, config, phase=config.train_phase, num_workers=config.num_workers, augment_data=True, shuffle=True, repeat=True, batch_size=config.batch_size, limit_numpoints=config.train_limit_numpoints) val_data_loader, val_sampler = initialize_data_loader( DatasetClass, config, num_workers=config.num_val_workers, phase=config.val_phase, augment_data=False, shuffle=True, repeat=False, batch_size=config.val_batch_size, limit_numpoints=False) if train_data_loader.dataset.NUM_IN_CHANNEL is not None: num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL else: num_in_channel = 3 # RGB color num_labels = train_data_loader.dataset.NUM_LABELS else: test_data_loader, val_sampler = initialize_data_loader( DatasetClass, config, num_workers=config.num_workers, phase=config.test_phase, augment_data=False, shuffle=False, repeat=False, batch_size=config.test_batch_size, limit_numpoints=False) if test_data_loader.dataset.NUM_IN_CHANNEL is not None: num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL else: num_in_channel = 3 # RGB color num_labels = test_data_loader.dataset.NUM_LABELS logging.info('===> Building model') NetClass = load_model(config.model) if config.wrapper_type == 'None': model = NetClass(num_in_channel, num_labels, config) logging.info('===> Number of trainable parameters: {}: {}'.format( NetClass.__name__, count_parameters(model))) else: wrapper = load_wrapper(config.wrapper_type) model = wrapper(NetClass, num_in_channel, num_labels, config) logging.info('===> Number of trainable parameters: {}: {}'.format( wrapper.__name__ + NetClass.__name__, count_parameters(model))) logging.info(model) if config.weights == 'modelzoo': # Load modelzoo weights if possible. logging.info('===> Loading modelzoo weights') model.preload_modelzoo() # Load weights if specified by the parameter. 
elif config.weights.lower() != 'none': logging.info('===> Loading weights: ' + config.weights) state = torch.load(config.weights) if config.weights_for_inner_model: model.model.load_state_dict(state['state_dict']) else: if config.lenient_weight_loading: matched_weights = load_state_with_same_shape( model, state['state_dict']) model_dict = model.state_dict() model_dict.update(matched_weights) model.load_state_dict(model_dict) else: init_model_from_weights(model, state, freeze_bb=False) if config.distributed: # For multiprocessing distributed, DistributedDataParallel constructor # should always set the single device scope, otherwise, # DistributedDataParallel will use all available devices. if config.gpu is not None: torch.cuda.set_device(config.gpu) model.cuda(config.gpu) # When using a single GPU per process and per # DistributedDataParallel, we need to divide the batch size # ourselves based on the total number of GPUs we have config.batch_size = int(config.batch_size / ngpus_per_node) config.num_workers = int( (config.num_workers + ngpus_per_node - 1) / ngpus_per_node) model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[config.gpu]) else: model.cuda() # DistributedDataParallel will divide and allocate batch_size to all # available GPUs if device_ids are not set model = torch.nn.parallel.DistributedDataParallel(model) if config.is_train: train(model, train_data_loader, val_data_loader, config, train_sampler=train_sampler, ngpus_per_node=ngpus_per_node) else: test(model, test_data_loader, config)
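main_worker(gpu, ngpus_per_node, config) follows the standard PyTorch multiprocessing-distributed pattern, so it is normally launched with one process per GPU. The launcher below is an illustrative sketch of that pattern and is not part of the original code; fields such as config.multiprocessing_distributed, config.world_size and config.gpu are taken from the worker above.

import torch
import torch.multiprocessing as mp


def launch(config):
    # illustrative sketch of the standard per-GPU launch; not the original code
    ngpus_per_node = torch.cuda.device_count()
    if config.multiprocessing_distributed:
        # one process per GPU; the total world size spans all nodes
        config.world_size = ngpus_per_node * config.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config))
    else:
        main_worker(config.gpu, ngpus_per_node, config)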
def main(): config = get_config() ch = logging.StreamHandler(sys.stdout) logging.getLogger().setLevel(logging.INFO) formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s') file_handler = logging.FileHandler( os.path.join(config.log_dir, './model.log')) file_handler.setLevel(logging.INFO) file_handler.setFormatter(formatter) logging.basicConfig(format=os.uname()[1].split('.')[0] + ' %(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch, file_handler]) if config.test_config: # When using the test_config, reload and overwrite it, so should keep some configs val_bs = config.val_batch_size is_export = config.is_export json_config = json.load(open(config.test_config, 'r')) json_config['is_train'] = False json_config['weights'] = config.weights json_config['multiprocess'] = False json_config['log_dir'] = config.log_dir json_config['val_threads'] = config.val_threads json_config['submit'] = config.submit config = edict(json_config) config.val_batch_size = val_bs config.is_export = is_export config.is_train = False sys.path.append(config.log_dir) # from local_models import load_model else: '''bakup files''' if not os.path.exists(os.path.join(config.log_dir, 'models')): os.mkdir(os.path.join(config.log_dir, 'models')) for filename in os.listdir('./models'): if ".py" in filename: # donnot cp the init file since it will raise import error shutil.copy(os.path.join("./models", filename), os.path.join(config.log_dir, 'models')) elif 'modules' in filename: # copy the moduls folder also if os.path.exists( os.path.join(config.log_dir, 'models/modules')): shutil.rmtree( os.path.join(config.log_dir, 'models/modules')) shutil.copytree(os.path.join('./models', filename), os.path.join(config.log_dir, 'models/modules')) shutil.copy('./main.py', config.log_dir) shutil.copy('./config.py', config.log_dir) shutil.copy('./lib/train.py', config.log_dir) shutil.copy('./lib/test.py', config.log_dir) if config.resume == 'True': new_iter_size = config.max_iter new_bs = config.batch_size config.resume = config.log_dir json_config = json.load(open(config.resume + '/config.json', 'r')) json_config['resume'] = config.resume config = edict(json_config) config.weights = os.path.join( config.log_dir, 'weights.pth') # use the pre-trained weights logging.info('==== resuming from {}, Total {} ======'.format( config.max_iter, new_iter_size)) config.max_iter = new_iter_size config.batch_size = new_bs else: config.resume = None if config.is_cuda and not torch.cuda.is_available(): raise Exception("No GPU found") gpu_list = range(config.num_gpu) device = get_torch_device(config.is_cuda) # torch.set_num_threads(config.threads) # torch.manual_seed(config.seed) # if config.is_cuda: # torch.cuda.manual_seed(config.seed) logging.info('===> Configurations') dconfig = vars(config) for k in dconfig: logging.info(' {}: {}'.format(k, dconfig[k])) DatasetClass = load_dataset(config.dataset) logging.info('===> Initializing dataloader') setup_seed(2021) """ ---- Setting up train, val, test dataloaders ---- Supported datasets: - ScannetSparseVoxelizationDataset - ScannetDataset - SemanticKITTI """ point_scannet = False if config.is_train: if config.dataset == 'ScannetSparseVoxelizationDataset': point_scannet = False train_data_loader = initialize_data_loader( DatasetClass, config, phase=config.train_phase, threads=config.threads, augment_data=True, elastic_distortion=config.train_elastic_distortion, shuffle=True, # shuffle=False, # DEBUG ONLY!!! 
repeat=True, # repeat=False, batch_size=config.batch_size, limit_numpoints=config.train_limit_numpoints) val_data_loader = initialize_data_loader( DatasetClass, config, threads=config.val_threads, phase=config.val_phase, augment_data=False, elastic_distortion=config.test_elastic_distortion, shuffle=False, repeat=False, batch_size=config.val_batch_size, limit_numpoints=False) elif config.dataset == 'ScannetDataset': val_DatasetClass = load_dataset( 'ScannetDatasetWholeScene_evaluation') point_scannet = True # collate_fn = t.cfl_collate_fn_factory(False) # no limit num-points trainset = DatasetClass( root= '/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles', npoints=config.num_points, # split='debug', split='train', with_norm=False, ) train_data_loader = torch.utils.data.DataLoader( dataset=trainset, num_workers=config.threads, # num_workers=0, # for loading big pth file, should use single-thread batch_size=config.batch_size, # collate_fn=collate_fn, # input points, should not have collate-fn worker_init_fn=_init_fn, sampler=InfSampler(trainset, True)) # shuffle=True valset = val_DatasetClass( root= '/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles', scene_list_dir= '/data/eva_share_users/zhaotianchen/scannet/raw/metadata', # split='debug', split='eval', block_points=config.num_points, with_norm=False, delta=1.0, ) val_data_loader = torch.utils.data.DataLoader( dataset=valset, # num_workers=config.threads, num_workers= 0, # for loading big pth file, should use single-thread batch_size=config.val_batch_size, # collate_fn=collate_fn, # input points, should not have collate-fn worker_init_fn=_init_fn) elif config.dataset == "SemanticKITTI": point_scannet = False dataset = SemanticKITTI(root=config.semantic_kitti_path, num_points=None, voxel_size=config.voxel_size, sample_stride=config.sample_stride, submit=False) collate_fn_factory = t.cfl_collate_fn_factory train_data_loader = torch.utils.data.DataLoader( dataset['train'], batch_size=config.batch_size, sampler=InfSampler(dataset['train'], shuffle=True), # shuffle=true, repeat=true num_workers=config.threads, pin_memory=True, collate_fn=collate_fn_factory(config.train_limit_numpoints)) val_data_loader = torch.utils.data.DataLoader( # shuffle=false, repeat=false dataset['test'], batch_size=config.batch_size, num_workers=config.val_threads, pin_memory=True, collate_fn=t.cfl_collate_fn_factory(False)) elif config.dataset == "S3DIS": trainset = S3DIS( config, train=True, ) valset = S3DIS( config, train=False, ) train_data_loader = torch.utils.data.DataLoader( trainset, batch_size=config.batch_size, sampler=InfSampler(trainset, shuffle=True), # shuffle=true, repeat=true num_workers=config.threads, pin_memory=True, collate_fn=t.cfl_collate_fn_factory( config.train_limit_numpoints)) val_data_loader = torch.utils.data.DataLoader( # shuffle=false, repeat=false valset, batch_size=config.batch_size, num_workers=config.val_threads, pin_memory=True, collate_fn=t.cfl_collate_fn_factory(False)) elif config.dataset == 'Nuscenes': config.xyz_input = False # todo: trainset = Nuscenes( config, train=True, ) valset = Nuscenes( config, train=False, ) train_data_loader = torch.utils.data.DataLoader( trainset, batch_size=config.batch_size, sampler=InfSampler(trainset, shuffle=True), # shuffle=true, repeat=true num_workers=config.threads, pin_memory=True, # collate_fn=t.collate_fn_BEV, # used when cylinder voxelize collate_fn=t.cfl_collate_fn_factory(False)) val_data_loader = torch.utils.data.DataLoader( # shuffle=false, repeat=false valset, 
            batch_size=config.batch_size,
            num_workers=config.val_threads,
            pin_memory=True,
            # collate_fn=t.collate_fn_BEV,
            collate_fn=t.cfl_collate_fn_factory(False))
        else:
            print('Dataset {} not supported'.format(config.dataset))
            raise NotImplementedError

        # Setting up num_in_channel and num_labels
        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3
        num_labels = train_data_loader.dataset.NUM_LABELS

        # it = iter(train_data_loader)
        # for _ in range(100):
        #     data = it.__next__()
        #     print(data)

    else:  # not config.is_train
        val_DatasetClass = load_dataset('ScannetDatasetWholeScene_evaluation')

        if config.dataset == 'ScannetSparseVoxelizationDataset':
            if config.is_export:
                # when exporting, we need to export the train results too
                train_data_loader = initialize_data_loader(
                    DatasetClass,
                    config,
                    phase=config.train_phase,
                    threads=config.threads,
                    augment_data=True,
                    elastic_distortion=config.train_elastic_distortion,  # DEBUG: not sure about this
                    shuffle=False,
                    repeat=False,
                    batch_size=config.batch_size,
                    limit_numpoints=config.train_limit_numpoints)
                # the val-like loader, without data augmentation
                # train_data_loader = initialize_data_loader(
                #     DatasetClass,
                #     config,
                #     threads=config.val_threads,
                #     phase=config.train_phase,
                #     augment_data=False,
                #     elastic_distortion=config.test_elastic_distortion,
                #     shuffle=False,
                #     repeat=False,
                #     batch_size=config.val_batch_size,
                #     limit_numpoints=False)

            val_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                threads=config.val_threads,
                phase=config.val_phase,
                augment_data=False,
                elastic_distortion=config.test_elastic_distortion,
                shuffle=False,
                repeat=False,
                batch_size=config.val_batch_size,
                limit_numpoints=False)

            if val_data_loader.dataset.NUM_IN_CHANNEL is not None:
                num_in_channel = val_data_loader.dataset.NUM_IN_CHANNEL
            else:
                num_in_channel = 3
            num_labels = val_data_loader.dataset.NUM_LABELS

        elif config.dataset == 'ScannetDataset':
            '''when using scannet-point, use val instead of test'''
            point_scannet = True
            valset = val_DatasetClass(
                root='/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                scene_list_dir='/data/eva_share_users/zhaotianchen/scannet/raw/metadata',
                split='eval',
                block_points=config.num_points,
                delta=1.0,
                with_norm=False,
            )
            val_data_loader = torch.utils.data.DataLoader(
                dataset=valset,
                # num_workers=config.threads,
                num_workers=0,  # for loading a big pth file, use a single worker
                batch_size=config.val_batch_size,
                # collate_fn=collate_fn,  # input points, should not have a collate_fn
                worker_init_fn=_init_fn,
            )
            num_labels = val_data_loader.dataset.NUM_LABELS
            num_in_channel = 3

        elif config.dataset == "SemanticKITTI":
            dataset = SemanticKITTI(root=config.semantic_kitti_path,
                                    num_points=None,
                                    voxel_size=config.voxel_size,
                                    submit=config.submit)
            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                dataset['test'],
                batch_size=config.val_batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 4
            num_labels = 19

        elif config.dataset == 'S3DIS':
            config.xyz_input = False
            trainset = S3DIS(
                config,
                train=True,
            )
            valset = S3DIS(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset, shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(config.train_limit_numpoints))
            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 9
            num_labels = 13

        elif config.dataset == 'Nuscenes':
            config.xyz_input = False
            trainset = Nuscenes(
                config,
                train=True,
            )
            valset = Nuscenes(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset, shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,
                collate_fn=t.cfl_collate_fn_factory(False))
            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 5
            num_labels = 16

        else:
            print('Dataset {} not supported'.format(config.dataset))
            raise NotImplementedError

    logging.info('===> Building model')
    # if config.model == 'PointTransformer' or config.model == 'MixedTransformer':
    if config.model == 'PointTransformer':
        config.pure_point = True

    NetClass = load_model(config.model)
    if config.pure_point:
        model = NetClass(config,
                         num_class=num_labels,
                         N=config.num_points,
                         normal_channel=num_in_channel)
    else:
        if config.model == 'MixedTransformer':
            model = NetClass(config,
                             num_class=num_labels,
                             N=config.num_points,
                             normal_channel=num_in_channel)
        elif config.model == 'MinkowskiVoxelTransformer':
            model = NetClass(config, num_in_channel, num_labels)
        elif config.model == 'MinkowskiTransformerNet':
            model = NetClass(config, num_in_channel, num_labels)
        elif "Res" in config.model:
            model = NetClass(num_in_channel, num_labels, config)
        else:
            model = NetClass(num_in_channel, num_labels, config)

    logging.info('===> Number of trainable parameters: {}: {}M'.format(
        NetClass.__name__,
        count_parameters(model) / 1e6))

    if hasattr(model, "block1"):
        if hasattr(model.block1[0], 'h'):
            h = model.block1[0].h
            vec_dim = model.block1[0].vec_dim
        else:
            h = None
            vec_dim = None
    else:
        h = None
        vec_dim = None
    # logging.info('===> Model Args:\n PLANES: {} \n LAYERS: {}\n HEADS: {}\n Vec-dim: {}\n'.format(
    #     model.PLANES, model.LAYERS, h, vec_dim))

    logging.info(model)

    # Set the number of threads
    # ME.initialize_nthreads(12, D=3)

    model = model.to(device)
    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        # drop keys containing '_map' since they raise a size mismatch;
        # debug: sometimes the model contains 'map_qk', which is a poor name for a
        # module, since '_map' entries are always buffers
        d_ = {
            k: v
            for k, v in state['state_dict'].items() if '_map' not in k
        }
        # strip the 'module.' prefix left by DataParallel/DDP checkpoints
        d = {}
        for k in d_.keys():
            if 'module.' in k:
                d[k.replace('module.', '')] = d_[k]
            else:
                d[k] = d_[k]
        # del d_
        if config.weights_for_inner_model:
            model.model.load_state_dict(d)
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(d, strict=True)

    if config.is_debug:
        check_data(model, train_data_loader, val_data_loader, config)
        return None
    elif config.is_train:
        if hasattr(config, 'distill') and config.distill:
            assert point_scannet is not True  # only support whole scene for now
            train_distill(model, train_data_loader, val_data_loader, config)
        if config.multiprocess:
            if point_scannet:
                raise NotImplementedError
            else:
                train_mp(NetClass, train_data_loader, val_data_loader, config)
        else:
            if point_scannet:
                train_point(model, train_data_loader, val_data_loader, config)
            else:
                train(model, train_data_loader, val_data_loader, config)
    elif config.is_export:
        if point_scannet:
            raise NotImplementedError
        else:
            # only the whole-scene style is supported for now
            test(model, train_data_loader, config, save_pred=True, split='train')
            test(model, val_data_loader, config, save_pred=True, split='val')
    else:
        assert not config.multiprocess
        # if testing for submission, make a submit directory in the current directory
        submit_dir = os.path.join(os.getcwd(), 'submit', 'sequences')
        if config.submit and not os.path.exists(submit_dir):
            os.makedirs(submit_dir)
            print("Made submission directory: " + submit_dir)
        if point_scannet:
            test_points(model, val_data_loader, config)
        else:
            test(model, val_data_loader, config, submit_dir=submit_dir)
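All of the entry points above log count_parameters(model), which is not defined in this excerpt. A common implementation (an assumption; the project's own helper may differ) is:

def count_parameters(model):
    # count only trainable parameters, as logged by the entry points above
    return sum(p.numel() for p in model.parameters() if p.requires_grad)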