def __init__(self, emb_dim=768, hid_size=32, layers=1, weights_mat=None, tr_labs=None,
             b_size=24, cp_dir='models/checkpoints/cim', lr=0.001, start_epoch=0,
             patience=3, step=1, gamma=0.75, n_eps=10, cim_type='cim', context='art'):
    self.start_epoch = start_epoch
    self.cp_dir = cp_dir
    self.device, self.use_cuda = get_torch_device()

    self.emb_dim = emb_dim
    self.hidden_size = hid_size
    self.batch_size = b_size

    # The class weights could be derived from tr_labs instead of being hard-coded.
    if cim_type == 'cim':
        self.criterion = CrossEntropyLoss(weight=torch.tensor([.20, .80], device=self.device), reduction='sum')
    else:
        self.criterion = CrossEntropyLoss(weight=torch.tensor([.25, .75], device=self.device), reduction='sum')
    # self.criterion = NLLLoss(weight=torch.tensor([.15, .85], device=self.device))
    # n_pos = len([l for l in tr_labs if l == 1])
    # class_weight = 1 - (n_pos / len(tr_labs))
    # print(class_weight)
    # NOTE: in the original, reduction='sum' was mistakenly passed to torch.tensor;
    # it belongs to the loss constructor:
    # self.criterion = nn.BCEWithLogitsLoss(
    #     pos_weight=torch.tensor([.85], dtype=torch.float, device=self.device), reduction='sum')

    if start_epoch > 0:
        self.model = self.load_model()
    else:
        self.model = ContextAwareModel(input_size=self.emb_dim, hidden_size=self.hidden_size,
                                       bilstm_layers=layers, weights_matrix=weights_mat,
                                       device=self.device, cam_type=cim_type, context=context)
    self.model = self.model.to(self.device)
    if self.use_cuda:
        self.model.cuda()

    # Empty now; set during or after training.
    self.train_time = 0
    self.prev_val_f1 = 0
    self.cp_name = None  # depends on split type and current fold
    self.full_patience = patience
    self.current_patience = self.full_patience
    self.test_perf = []
    self.test_perf_string = ''

    # Set the optimizer.
    nr_train_instances = len(tr_labs)
    nr_train_batches = int(nr_train_instances / b_size)
    half_tr_bs = int(nr_train_instances / 2)
    self.optimizer = AdamW(self.model.parameters(), lr=lr, eps=1e-8)

    # Set a scheduler if desired.
    # self.scheduler = lr_scheduler.CyclicLR(self.optimizer, base_lr=lr, step_size_up=half_tr_bs,
    #                                        cycle_momentum=False, max_lr=lr * 30)
    num_train_warmup_steps = int(0.1 * (nr_train_batches * n_eps))  # warmup proportion of 0.1
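# The snippet above computes num_train_warmup_steps but is cut off before a
# warmup scheduler is constructed. A minimal sketch of how that value is
# typically consumed, assuming the transformers library's
# get_linear_schedule_with_warmup (an assumption; the original may build a
# different scheduler):
from transformers import get_linear_schedule_with_warmup

def build_warmup_scheduler(optimizer, nr_train_batches, n_eps, warmup_proportion=0.1):
    # Linear warmup for the first warmup_proportion of steps, then linear decay.
    total_steps = nr_train_batches * n_eps
    num_warmup = int(warmup_proportion * total_steps)
    return get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup, num_training_steps=total_steps)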
def main():
    config = get_config()
    if config.resume:
        json_config = json.load(open(config.resume + '/config.json', 'r'))
        json_config['resume'] = config.resume
        config = edict(json_config)

    if config.is_cuda and not torch.cuda.is_available():
        raise Exception("No GPU found")
    device = get_torch_device(config.is_cuda)

    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    if config.test_original_pointcloud:
        if not DatasetClass.IS_FULL_POINTCLOUD_EVAL:
            raise ValueError('This dataset does not support full pointcloud evaluation.')

    if config.evaluate_original_pointcloud:
        if not config.return_transformation:
            raise ValueError('Pointcloud evaluation requires config.return_transformation=true.')

    if config.return_transformation ^ config.evaluate_original_pointcloud:
        raise ValueError('Rotation evaluation requires config.evaluate_original_pointcloud=true and '
                         'config.return_transformation=true.')

    logging.info('===> Initializing dataloader')
    if config.is_train:
        train_data_loader = initialize_data_loader(
            DatasetClass, config, phase=config.train_phase, threads=config.threads,
            augment_data=True, shuffle=True, repeat=True, batch_size=config.batch_size,
            limit_numpoints=config.train_limit_numpoints)
        val_data_loader = initialize_data_loader(
            DatasetClass, config, threads=config.val_threads, phase=config.val_phase,
            augment_data=False, shuffle=True, repeat=False, batch_size=config.val_batch_size,
            limit_numpoints=False)
        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color
        num_labels = train_data_loader.dataset.NUM_LABELS
    else:
        test_data_loader = initialize_data_loader(
            DatasetClass, config, threads=config.threads, phase=config.test_phase,
            augment_data=False, shuffle=False, repeat=False, batch_size=config.test_batch_size,
            limit_numpoints=False)
        if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color
        num_labels = test_data_loader.dataset.NUM_LABELS

    logging.info('===> Building model')
    NetClass = load_model(config.model)
    if config.wrapper_type == 'None':
        model = NetClass(num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            NetClass.__name__, count_parameters(model)))
    else:
        wrapper = load_wrapper(config.wrapper_type)
        model = wrapper(NetClass, num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            wrapper.__name__ + NetClass.__name__, count_parameters(model)))
    logging.info(model)
    model = model.to(device)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    elif config.weights.lower() != 'none':  # Load weights if specified by the parameter.
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        if config.weights_for_inner_model:
            model.model.load_state_dict(state['state_dict'])
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(state['state_dict'])

    if config.is_train:
        train(model, train_data_loader, val_data_loader, config)
    else:
        test(model, test_data_loader, config)
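# count_parameters is referenced above but not defined in this snippet; a
# minimal sketch of the conventional implementation (an assumption about the
# helper's exact body):
def count_parameters(model):
    # Sum the element counts of all trainable tensors.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)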
def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
    device = get_torch_device(config.is_cuda)
    dataset = data_loader.dataset
    num_labels = dataset.NUM_LABELS
    global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
    losses, scores, ious = AverageMeter(), AverageMeter(), 0
    aps = np.zeros((0, num_labels))
    hist = np.zeros((num_labels, num_labels))

    logging.info('===> Start testing')
    global_timer.tic()
    data_iter = data_loader.__iter__()
    max_iter = len(data_loader)
    max_iter_unique = max_iter

    # Fix batch normalization running mean and std
    model.eval()

    # Clear cache (when run in val mode, cleanup training cache)
    torch.cuda.empty_cache()

    if config.save_prediction or config.test_original_pointcloud:
        if config.save_prediction:
            save_pred_dir = config.save_pred_dir
            os.makedirs(save_pred_dir, exist_ok=True)
        else:
            save_pred_dir = tempfile.mkdtemp()
        if os.listdir(save_pred_dir):
            raise ValueError(f'Directory {save_pred_dir} not empty. '
                             'Please remove the existing prediction.')

    with torch.no_grad():
        for iteration in range(max_iter):
            data_timer.tic()
            if config.return_transformation:
                coords, input, target, transformation = next(data_iter)
            else:
                coords, input, target = next(data_iter)
                transformation = None
            data_time = data_timer.toc(False)

            # Preprocess input
            iter_timer.tic()
            if config.wrapper_type != 'None':
                color = input[:, :3].int()
            if config.normalize_color:
                input[:, :3] = input[:, :3] / 255. - 0.5
            sinput = SparseTensor(input, coords).to(device)

            # Feed forward
            inputs = (sinput,) if config.wrapper_type == 'None' else (sinput, coords, color)
            soutput = model(*inputs)
            output = soutput.F

            pred = get_prediction(dataset, output, target).int()
            iter_time = iter_timer.toc(False)

            if config.save_prediction or config.test_original_pointcloud:
                save_predictions(coords, pred, transformation, dataset, config, iteration, save_pred_dir)

            if has_gt:
                if config.evaluate_original_pointcloud:
                    raise NotImplementedError('pointcloud')
                    # Unreachable, kept from the original:
                    output, pred, target = permute_pointcloud(
                        coords, pointcloud, transformation, dataset.label_map, output, pred)

                target_np = target.numpy()
                num_sample = target_np.shape[0]
                target = target.to(device)

                cross_ent = criterion(output, target.long())
                losses.update(float(cross_ent), num_sample)
                scores.update(precision_at_one(pred, target), num_sample)
                hist += fast_hist(pred.cpu().numpy().flatten(), target_np.flatten(), num_labels)
                ious = per_class_iu(hist) * 100

                prob = torch.nn.functional.softmax(output, dim=1)
                ap = average_precision(prob.cpu().detach().numpy(), target_np)
                aps = np.vstack((aps, ap))
                # Due to heavy class imbalance, some classes may have no test labels at all.
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    ap_class = np.nanmean(aps, 0) * 100.

            if iteration % config.test_stat_freq == 0 and iteration > 0:
                reordered_ious = dataset.reorder_result(ious)
                reordered_ap_class = dataset.reorder_result(ap_class)
                class_names = dataset.get_classnames()
                print_info(iteration, max_iter_unique, data_time, iter_time, has_gt, losses,
                           scores, reordered_ious, hist, reordered_ap_class, class_names=class_names)

            if iteration % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

    global_time = global_timer.toc(False)

    reordered_ious = dataset.reorder_result(ious)
    reordered_ap_class = dataset.reorder_result(ap_class)
    class_names = dataset.get_classnames()
    print_info(iteration, max_iter_unique, data_time, iter_time, has_gt, losses, scores,
               reordered_ious, hist, reordered_ap_class, class_names=class_names)

    if config.test_original_pointcloud:
        logging.info('===> Start testing on original pointcloud space.')
        dataset.test_pointcloud(save_pred_dir)

    logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))
    return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(per_class_iu(hist)) * 100
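# fast_hist and per_class_iu are used above but not defined in this snippet;
# a minimal sketch of the conventional confusion-matrix-based implementations
# (an assumption about the helpers' exact bodies):
import numpy as np

def fast_hist(pred, label, n):
    # Only count entries whose label is a valid class id in [0, n);
    # this filters out ignore values such as 255 or -1.
    k = (label >= 0) & (label < n)
    return np.bincount(n * label[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)

def per_class_iu(hist):
    # Per-class IoU: diagonal over (row sum + column sum - diagonal).
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))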
EMB_DIM = 512 if EMB_TYPE == 'use' else 768
SAMPLER = args.sampler
TASK_NAME = '_'.join([CONTEXT_TYPE, CIM_TYPE, EMB_BASE])
N_EPOCHS = args.epochs if not DEBUG else 5
PATIENCE = args.patience
BATCH_SIZE = args.batch_size
LR = args.learning_rate
HIDDEN = args.hidden_size
BILSTM_LAYERS = args.bilstm_layers
NUM_LABELS = 2
MAX_DOC_LEN = 76

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device, USE_CUDA = get_torch_device()

# DIRECTORIES
DATA_DIR = 'data/inputs/cim/'
DATA_FP = os.path.join(DATA_DIR, f'{TASK_NAME}_basil.tsv')
CHECKPOINT_DIR = f'models/checkpoints/cim/{TASK_NAME}'
PREDICTION_DIR = f'data/predictions/{TASK_NAME}/'
REPORTS_DIR = f'reports/cim/{TASK_NAME}'
TABLE_DIR = f"reports/cim/tables/{TASK_NAME}"
MAIN_TABLE_FP = os.path.join(TABLE_DIR, f'{TASK_NAME}_results.csv')

table_columns = 'model,sampler,seed,bs,lr,model_loc,fold,voter,epoch,set_type,loss,fn,fp,tn,tp,acc,prec,rec,f1'
main_results_table = pd.DataFrame(columns=table_columns.split(','))

if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)
if not os.path.exists(REPORTS_DIR):
    os.makedirs(REPORTS_DIR)
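# A hedged sketch of how a result row could be appended under the schema
# defined above; the values here are placeholders and the original appending
# code is not part of this snippet:
import pandas as pd

def append_result_row(table, row_dict, fp):
    # Align the dict against the existing columns and persist the table.
    table = pd.concat([table, pd.DataFrame([row_dict])], ignore_index=True)
    table.to_csv(fp, index=False)
    return table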
def test(model, data_loader, config, transform_data_fn=None, has_gt=True, validation=None, epoch=None):
    device = get_torch_device(config.is_cuda)
    dataset = data_loader.dataset
    num_labels = dataset.NUM_LABELS
    global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
    alpha, gamma, eps = 1, 2, 1e-6  # focal loss parameters
    losses, scores, ious = AverageMeter(), AverageMeter(), 0
    aps = np.zeros((0, num_labels))
    hist = np.zeros((num_labels, num_labels))

    if not config.is_train:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            model.load_state_dict(state['state_dict'])
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

    if validation:
        logging.info('===> Start validating')
    else:
        logging.info('===> Start testing')

    global_timer.tic()
    data_iter = data_loader.__iter__()
    max_iter = len(data_loader)
    max_iter_unique = max_iter
    all_preds, all_labels, batch_losses, batch_loss = [], [], {}, 0

    # Fix batch normalization running mean and std
    model.eval()

    # Clear cache (when run in val mode, cleanup training cache)
    torch.cuda.empty_cache()

    if config.save_prediction or config.test_original_pointcloud:
        if config.save_prediction:
            save_pred_dir = config.save_pred_dir
            os.makedirs(save_pred_dir, exist_ok=True)
        else:
            save_pred_dir = tempfile.mkdtemp()
        if os.listdir(save_pred_dir):
            raise ValueError(f'Directory {save_pred_dir} not empty. '
                             'Please remove the existing prediction.')

    with torch.no_grad():
        for iteration in range(max_iter):
            data_timer.tic()
            if config.return_transformation:
                coords, input, target, transformation = next(data_iter)
            else:
                coords, input, target = next(data_iter)
                transformation = None
            data_time = data_timer.toc(False)

            # Preprocess input
            iter_timer.tic()
            if config.wrapper_type != 'None':
                color = input[:, :3].int()
            if config.normalize_color:
                input[:, :3] = input[:, :3] / 255. - 0.5
            sinput = SparseTensor(input, coords).to(device)

            # Feed forward
            inputs = (sinput,) if config.wrapper_type == 'None' else (sinput, coords, color)
            soutput = model(*inputs)
            output = soutput.F

            pred = get_prediction(dataset, output, target).int()
            iter_time = iter_timer.toc(False)

            all_preds.append(pred.cpu().detach().numpy())
            all_labels.append(target.cpu().detach().numpy())

            if config.save_prediction or config.test_original_pointcloud:
                save_predictions(coords, pred, transformation, dataset, config, iteration, save_pred_dir)

            if has_gt:
                if config.evaluate_original_pointcloud:
                    raise NotImplementedError('pointcloud')
                    # Unreachable, kept from the original:
                    output, pred, target = permute_pointcloud(
                        coords, pointcloud, transformation, dataset.label_map, output, pred)

                target_np = target.numpy()
                num_sample = target_np.shape[0]
                target = target.to(device)

                # Focal loss, kept as a commented-out alternative to cross-entropy:
                # input_soft = nn.functional.softmax(output, dim=1) + eps
                # focal_weight = torch.pow(-input_soft + 1., gamma)
                # loss = (-alpha * focal_weight * torch.log(input_soft)).mean()
                loss = criterion(output, target.long())
                batch_loss += loss.item()

                losses.update(float(loss), num_sample)
                scores.update(precision_at_one(pred, target), num_sample)
                hist += fast_hist(pred.cpu().numpy().flatten(), target_np.flatten(), num_labels)
                ious = per_class_iu(hist) * 100

                prob = torch.nn.functional.softmax(output, dim=1)
                ap = average_precision(prob.cpu().detach().numpy(), target_np)
                aps = np.vstack((aps, ap))
                # Due to heavy class imbalance, some classes may have no test labels at all.
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    ap_class = np.nanmean(aps, 0) * 100.

            if iteration % config.test_stat_freq == 0 and iteration > 0:
                preds = np.concatenate(all_preds)
                targets = np.concatenate(all_labels)
                to_ignore = [i for i in range(len(targets)) if targets[i] == 255]
                preds_trunc = [preds[i] for i in range(len(preds)) if i not in to_ignore]
                targets_trunc = [targets[i] for i in range(len(targets)) if i not in to_ignore]
                cm = confusion_matrix(targets_trunc, preds_trunc, normalize='true')
                np.savetxt(config.log_dir + '/cm_epoch_{0}.txt'.format(epoch), cm)

                reordered_ious = dataset.reorder_result(ious)
                reordered_ap_class = dataset.reorder_result(ap_class)
                class_names = dataset.get_classnames()
                print_info(iteration, max_iter_unique, data_time, iter_time, has_gt, losses,
                           scores, reordered_ious, hist, reordered_ap_class, class_names=class_names)

            if iteration % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

    batch_losses[epoch] = batch_loss
    global_time = global_timer.toc(False)

    reordered_ious = dataset.reorder_result(ious)
    reordered_ap_class = dataset.reorder_result(ap_class)
    class_names = dataset.get_classnames()
    print_info(iteration, max_iter_unique, data_time, iter_time, has_gt, losses, scores,
               reordered_ious, hist, reordered_ap_class, class_names=class_names)

    if not config.is_train:
        preds = np.concatenate(all_preds)
        targets = np.concatenate(all_labels)
        to_ignore = [i for i in range(len(targets)) if targets[i] == 255]
        preds_trunc = [preds[i] for i in range(len(preds)) if i not in to_ignore]
        targets_trunc = [targets[i] for i in range(len(targets)) if i not in to_ignore]
        cm = confusion_matrix(targets_trunc, preds_trunc, normalize='true')
        np.savetxt(config.log_dir + '/cm.txt', cm)

    if config.test_original_pointcloud:
        logging.info('===> Start testing on original pointcloud space.')
        dataset.test_pointcloud(save_pred_dir)

    logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))

    if validation:
        loss_file_name = "/val_loss.txt"
        with open(config.log_dir + loss_file_name, 'a') as val_loss_file:
            for key in batch_losses:
                val_loss_file.writelines('{0}, {1}\n'.format(batch_losses[key], key))
        return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(per_class_iu(hist)) * 100, batch_losses
    else:
        return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(per_class_iu(hist)) * 100
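# The focal-loss variant stays commented out in the function above; a
# self-contained sketch of that exact formulation (a sketch, using the same
# alpha/gamma/eps constants; note it averages over all class terms rather than
# gathering the target class, mirroring the commented-out code):
import torch
import torch.nn.functional as F

def focal_loss(logits, alpha=1.0, gamma=2.0, eps=1e-6):
    # Down-weight confident predictions by (1 - p)^gamma before the log term.
    prob = F.softmax(logits, dim=1) + eps
    focal_weight = (1. - prob).pow(gamma)
    return (-alpha * focal_weight * prob.log()).mean()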
def train_distill(model, data_loader, val_data_loader, config, transform_data_fn=None):
    '''
    Distillation training; a few configs are set here.
    '''
    # distill_lambda = 1
    # distill_lambda = 0.33
    distill_lambda = 0.67

    # TWO_STAGE=True: the Transformer is first trained with an L2 loss to match the
    # ResNet's activations, then fine-tuned like normal training in the second stage.
    # TWO_STAGE=False: the Transformer trains with the combined loss.
    TWO_STAGE = False
    # STAGE_PERCENTAGE = 0.7

    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    # TODO: load the sub-model only
    # FIXME: hard-coded teacher checkpoint; only the current setup is supported
    tch_model_cls = load_model('Res16UNet18A')
    tch_model = tch_model_cls(3, 20, config).to(device)
    # checkpoint_fn = "/home/zhaotianchen/project/point-transformer/SpatioTemporalSegmentation-ScanNet/outputs/ScannetSparseVoxelizationDataset/Res16UNet18A/resnet_base/weights.pth"
    checkpoint_fn = "/home/zhaotianchen/project/point-transformer/SpatioTemporalSegmentation-ScanNet/outputs/ScannetSparseVoxelizationDataset/Res16UNet18A/Res18A/weights.pth"  # voxel-size: 0.05
    assert osp.isfile(checkpoint_fn)
    logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
    state = torch.load(checkpoint_fn)
    d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
    tch_model.load_state_dict(d)
    if 'best_val' in state:
        best_val_miou = state['best_val']
        best_val_iter = state['best_val_iter']
    logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))

    if config.resume:
        raise NotImplementedError
        # Test the loaded checkpoint first
        # checkpoint_fn = config.resume + '/weights.pth'
        # if osp.isfile(checkpoint_fn):
        #     logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
        #     state = torch.load(checkpoint_fn)
        #     curr_iter = state['iteration'] + 1
        #     epoch = state['epoch']
        #     d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
        #     model.load_state_dict(d)
        #     if config.resume_optimizer:
        #         scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
        #         optimizer.load_state_dict(state['optimizer'])
        #     if 'best_val' in state:
        #         best_val_miou = state['best_val']
        #         best_val_iter = state['best_val_iter']
        #     logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
        # else:
        #     raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

    # Test after loading the checkpoint
    v_loss, v_score, v_mAP, v_mIoU = test(tch_model, val_data_loader, config)
    logging.info('Teacher model tested, best_miou: {}'.format(v_mIoU))

    data_iter = data_loader.__iter__()
    while is_training:
        num_class = 20
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)
        total_iteration = len(data_loader) // config.iter_size

        for iteration in range(total_iteration):
            # NOTE: for single-stage distillation the L2 loss might be too large at
            # first, so a warmup that skips it was added; with the threshold of 0
            # below, the warmup is currently disabled (use_distill is always True).
            use_distill = iteration >= 0

            # Stage 1 / Stage 2 boundary
            if TWO_STAGE:
                # Requires STAGE_PERCENTAGE above to be uncommented.
                stage_boundary = int(total_iteration * STAGE_PERCENTAGE)

            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            for sub_iter in range(config.iter_size):
                # Get training data
                data_timer.tic()
                if config.return_transformation:
                    coords, input, target, _, _, pointcloud, transformation = next(data_iter)
                else:
                    coords, input, target, _, _ = next(data_iter)  # ignore unique_map and inverse_map

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks, making the network invariant to even/odd coords is important
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / 255. - 0.5
                    coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5

                # cat xyz into the rgb feature
                if config.xyz_input:
                    input = torch.cat([coords_norm, input], dim=1)

                sinput = SparseTensor(input, coords, device=device)
                # TODO: return both models; to avoid breaking the valid interface,
                # use a get_loss to fetch the registered loss.

                data_time += data_timer.toc(False)

                # model.initialize_coords(*init_args)
                if aux is not None:
                    raise NotImplementedError

                # Flatten the ground-truth tensor
                target = target.view(-1).long().to(device)

                if TWO_STAGE:
                    if iteration < stage_boundary:
                        # Stage 1: train the transformer on the L2 loss
                        soutput, anchor = model(sinput, save_anchor=True)
                        # Make sure gradients don't flow into the teacher model
                        with torch.no_grad():
                            _, tch_anchor = tch_model(sinput, save_anchor=True)
                        loss = DistillLoss(tch_anchor, anchor)
                    else:
                        # Stage 2: fine-tune the transformer on cross-entropy
                        soutput = model(sinput)
                        loss = criterion(soutput.F, target.long())
                else:
                    if use_distill:  # after warmup
                        soutput, anchor = model(sinput, save_anchor=True)
                        # With a pretrained teacher, don't let gradients flow back to its params
                        with torch.no_grad():
                            tch_soutput, tch_anchor = tch_model(sinput, save_anchor=True)
                    else:  # warming up
                        soutput = model(sinput)

                    # The output of the network is not sorted
                    loss = criterion(soutput.F, target.long())
                    # Add the L2 loss when distillation is used
                    if use_distill:
                        distill_loss = DistillLoss(tch_anchor, anchor) * distill_lambda
                        loss += distill_loss

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()

            # Update number of steps
            optimizer.step()
            scheduler.step()

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target, ignore_label=-1)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # Calc the train IoU
            for l in range(num_class):
                total_correct_class[l] += ((pred == l) & (target == l)).sum()
                total_iou_deno_class[l] += (((pred == l) & (target != -1)) | (target == l)).sum()

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "[{}] ===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    config.log_dir, epoch, curr_iter, len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                if use_distill and not TWO_STAGE:
                    logging.info('Loss {} Distill Loss: {}'.format(loss, distill_loss))
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                losses.reset()
                scores.reset()

            # Save the current status; save before val to prevent occasional memory overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                           save_inter=True)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, None, curr_iter, config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                               "best_val", save_inter=True)
                logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
                # Recover the train flag
                model.train()

            # End of iteration
            curr_iter += 1

        IoU = total_correct_class / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % (IoU.mean() * 100.))
        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
    v_loss, v_score, v_mAP, val_miou = test(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
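# DistillLoss is called above but not defined in this snippet. Given the
# comment that stage 1 trains the Transformer "with L2 loss to match ResNet's
# activation", a minimal sketch (an assumption about its body; anchors are
# taken to be lists of matching intermediate feature tensors):
import torch.nn.functional as F

def DistillLoss(tch_anchors, anchors):
    # Mean squared error between student and (detached) teacher activations.
    return sum(F.mse_loss(a, t.detach()) for a, t in zip(anchors, tch_anchors))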
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    regs, losses, scores = AverageMeter(), AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        # Test the loaded checkpoint first
        v_loss, v_score, v_mAP, v_mIoU = test(model, val_data_loader, config)

        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            # Skip the attention maps: their shapes won't match because the voxel
            # number differs, e.g. copying a param of shape (23385, 8, 4) to (43529, 8, 4).
            d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
            # For the attn maps we don't load from the saved dict, keep the model's own values.
            for k in model.state_dict().keys():
                if k in d.keys():
                    continue
                d[k] = model.state_dict()[k]
            model.load_state_dict(d)
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = data_loader.__iter__()

    if config.dataset == "SemanticKITTI":
        num_class = 19
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq * 10
    elif config.dataset == "S3DIS":
        num_class = 13
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
    elif config.dataset == "Nuscenes":
        num_class = 16
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq * 50
    else:
        num_class = 20
        val_freq_ = config.val_freq

    while is_training:
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            if curr_iter >= config.max_iter:
                # if curr_iter >= max(config.max_iter, config.epochs * (len(data_loader) // config.iter_size)):
                is_training = False
                break
            elif curr_iter >= config.max_iter * (2 / 3):
                # Validate more frequently during the last third of training.
                config.val_freq = val_freq_ * 2

            # Initialized here so the regs meter below is well-defined even for
            # models without a block1 attribute.
            cur_loss = torch.tensor([0.], device=device)

            for sub_iter in range(config.iter_size):
                # Get training data
                data_timer.tic()
                pointcloud = None

                if config.return_transformation:
                    coords, input, target, _, _, pointcloud, transformation, _ = next(data_iter)
                else:
                    coords, input, target, _, _, _ = next(data_iter)  # ignore unique_map and inverse_map

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks, making the network invariant to even/odd coords is important
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / input[:, :3].max() - 0.5
                    coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5

                # cat xyz into the rgb feature
                if config.xyz_input:
                    input = torch.cat([coords_norm, input], dim=1)

                sinput = SparseTensor(input, coords, device=device)
                starget = SparseTensor(
                    target.unsqueeze(-1).float(),
                    coordinate_map_key=sinput.coordinate_map_key,
                    coordinate_manager=sinput.coordinate_manager,
                    device=device)  # must share the same coord-manager to align with sinput

                data_time += data_timer.toc(False)

                # model.initialize_coords(*init_args)
                # d = {}
                # d['c'] = sinput.C
                # d['l'] = starget.F
                # torch.save('./plot/test-label.pth')
                # import ipdb; ipdb.set_trace()

                # Set up the profiler
                # memory_profiler = CUDAMemoryProfiler([model, criterion], filename="cuda_memory.profile")
                # sys.settrace(memory_profiler)
                # threading.settrace(memory_profiler)

                # with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False, profile_memory=True) as prof0:
                if aux is not None:
                    soutput = model(sinput, aux)
                elif config.enable_point_branch:
                    soutput = model(sinput, iter_=curr_iter / config.max_iter, enable_point_branch=True)
                else:
                    # label-aux: feed the labels in as an additional reg, along with
                    # the training progress for annealing inside the model.
                    soutput = model(sinput, iter_=curr_iter / config.max_iter, aux=starget)

                # The output of the network is not sorted
                target = target.view(-1).long().to(device)
                loss = criterion(soutput.F, target.long())

                # ====== other loss regs =====
                if hasattr(model, 'block1'):
                    if hasattr(model.block1[0], 'vq_loss'):
                        if model.block1[0].vq_loss is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.named_children():
                                if 'block' in n:
                                    cur_loss += m[0].vq_loss  # m is the nn.Sequential obj, m[0] is the TRBlock
                            logging.info('Cur Loss: {}, Cur vq_loss: {}'.format(loss, cur_loss))
                            loss += cur_loss

                    if hasattr(model.block1[0], 'diverse_loss'):
                        if model.block1[0].diverse_loss is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.named_children():
                                if 'block' in n:
                                    cur_loss += m[0].diverse_loss  # m is the nn.Sequential obj, m[0] is the TRBlock
                            logging.info('Cur Loss: {}, Cur diverse_loss: {}'.format(loss, cur_loss))
                            loss += cur_loss

                    if hasattr(model.block1[0], 'label_reg'):
                        if model.block1[0].label_reg is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.named_children():
                                if 'block' in n:
                                    cur_loss += m[0].label_reg  # m is the nn.Sequential obj, m[0] is the TRBlock
                            # logging.info('Cur Loss: {}, Cur label_reg: {}'.format(loss, cur_loss))
                            loss += cur_loss

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()
                # soutput = model(sinput)

            # Update number of steps
            if not config.use_sam:
                optimizer.step()
            else:
                # SAM: the first step perturbs the weights, the second applies the update
                optimizer.first_step(zero_grad=True)
                soutput = model(sinput, iter_=curr_iter / config.max_iter, aux=starget)
                criterion(soutput.F, target.long()).backward()
                optimizer.second_step(zero_grad=True)

            if config.lr_warmup is None:
                scheduler.step()
            else:
                if curr_iter >= config.lr_warmup:
                    scheduler.step()
                else:
                    # Linear warmup: scale the LR up to config.lr over lr_warmup iterations.
                    for g in optimizer.param_groups:
                        g['lr'] = config.lr * (iteration + 1) / config.lr_warmup

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target, ignore_label=-1)

            regs.update(cur_loss.item(), target.size(0))
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # Calc the train IoU
            for l in range(num_class):
                total_correct_class[l] += ((pred == l) & (target == l)).sum()
                total_iou_deno_class[l] += (((pred == l) & (target != -1)) | (target == l)).sum()

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
                IoU = (total_correct_class / (total_iou_deno_class + 1e-6)).mean() * 100.
                debug_str = "[{}] ===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    config.log_dir.split('/')[-2], epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tIoU {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, IoU.item(), data_time_avg.avg, iter_time_avg.avg)
                if regs.avg > 0:
                    debug_str += "\n Additional Reg Loss {:.3f}".format(regs.avg)
                # print(debug_str)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                losses.reset()
                scores.reset()

            # Save the current status; save before val to prevent occasional memory overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                           save_inter=True)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, None, curr_iter, config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                               "best_val", save_inter=True)
                logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
                # print("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
                # Recover the train flag
                model.train()

            # End of iteration
            curr_iter += 1

        IoU = total_correct_class / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % (IoU.mean() * 100.))
        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
    v_loss, v_score, v_mAP, val_miou = test(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
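# config.use_sam above relies on an optimizer exposing first_step/second_step,
# as in Sharpness-Aware Minimization (SAM). A condensed sketch of that
# interface, modeled on the common open-source implementation (an assumption;
# not necessarily the exact optimizer used here):
import torch

class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local worst case: w + e(w)
                self.state[p]["e_w"] = e_w
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]["e_w"])  # move back to w
        self.base_optimizer.step()  # apply the actual update
        if zero_grad:
            self.zero_grad()

    def _grad_norm(self):
        return torch.norm(torch.stack([
            p.grad.norm(p=2) for group in self.param_groups
            for p in group["params"] if p.grad is not None]), p=2)

# Usage mirrors the loop above: backward(), first_step(zero_grad=True),
# a second forward/backward at the perturbed weights, second_step(zero_grad=True).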
def train_point(model, data_loader, val_data_loader, config, transform_data_fn=None):
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=-1)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
            model.load_state_dict(d)
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = data_loader.__iter__()
    while is_training:
        num_class = 20
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            for sub_iter in range(config.iter_size):
                # Get training data
                data = next(data_iter)
                points, target, sample_weight = data

                if config.pure_point:
                    sinput = points.transpose(1, 2).cuda().float()
                    # DEBUG: use the discrete coords for the point-based branch
                    '''
                    feats = torch.unbind(points[:, :, :], dim=0)
                    voxel_size = config.voxel_size
                    coords = torch.unbind(points[:, :, :3] / voxel_size, dim=0)  # 0.05 is the voxel size
                    coords, feats = ME.utils.sparse_collate(coords, feats)
                    # assert feats.reshape([16, 4096, -1]) == points[:, :, 3:]
                    points_ = ME.TensorField(features=feats.float(), coordinates=coords, device=device)
                    tmp_voxel = points_.sparse()
                    sinput_ = tmp_voxel.slice(points_)
                    sinput = torch.cat([sinput_.C[:, 1:] * config.voxel_size, sinput_.F[:, 3:]], dim=1).reshape([config.batch_size, config.num_points, 6])
                    # sinput = sinput_.F.reshape([config.batch_size, config.num_points, 6])
                    sinput = sinput.transpose(1, 2).cuda().float()
                    # sinput = torch.cat([coords[:, 1:], feats], dim=1).reshape([config.batch_size, config.num_points, 6])
                    # sinput = sinput.transpose(1, 2).cuda().float()
                    '''
                    # For some networks, making the network invariant to even/odd coords is important
                    # coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)
                    # Preprocess input
                    # if config.normalize_color:
                    #     feats = feats / 255. - 0.5
                    # torch.save(points[:, :, :3], './sandbox/tensorfield-c.pth')
                    # torch.save(points_.C, './sandbox/points-c.pth')
                else:
                    # feats = torch.unbind(points[:, :, 3:], dim=0)  # WRONG: should also feed in xyz as input feature
                    voxel_size = config.voxel_size
                    coords = torch.unbind(points[:, :, :3] / voxel_size, dim=0)  # 0.05 is the voxel size
                    # Normalize the xyz in feature
                    # points[:, :, :3] = points[:, :, :3] / points[:, :, :3].mean()
                    feats = torch.unbind(points[:, :, :], dim=0)
                    coords, feats = ME.utils.sparse_collate(coords, feats)

                    # For some networks, making the network invariant to even/odd coords is important
                    coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                    # Preprocess input
                    # if config.normalize_color:
                    #     feats = feats / 255. - 0.5  # they are the same

                    points_ = ME.TensorField(features=feats.float(), coordinates=coords, device=device)
                    # points_1 = ME.TensorField(features=feats.float(), coordinates=coords, device=device, quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
                    # points_2 = ME.TensorField(features=feats.float(), coordinates=coords, device=device, quantization_mode=ME.SparseTensorQuantizationMode.RANDOM_SUBSAMPLE)
                    sinput = points_.sparse()

                data_time += data_timer.toc(False)
                B, npoint = target.shape

                # model.initialize_coords(*init_args)
                soutput = model(sinput)
                if config.pure_point:
                    soutput = soutput.reshape([B * npoint, -1])
                else:
                    soutput = soutput.slice(points_).F
                    # s1 = soutput.slice(points_)
                    # print(soutput.quantization_mode)
                    # soutput.quantization_mode = ME.SparseTensorQuantizationMode.RANDOM_SUBSAMPLE
                    # s2 = soutput.slice(points_)

                # The output of the network is not sorted
                target = (target - 1).view(-1).long().to(device)

                # Catch NaNs
                if torch.isnan(soutput).sum() > 0:
                    import ipdb; ipdb.set_trace()
                loss = criterion(soutput, target)
                if torch.isnan(loss).sum() > 0:
                    import ipdb; ipdb.set_trace()
                # NOTE: per-sample weighting assumes the criterion yields per-sample
                # losses (i.e. reduction='none'); with the default 'mean' this scales
                # a scalar by the mean sample weight.
                loss = (loss * sample_weight.to(device)).mean()

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()
                # print(model.input_mlp[0].weight.max())
                # print(model.input_mlp[0].weight.grad.max())

            # Update number of steps
            optimizer.step()
            scheduler.step()

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput, target)
            score = precision_at_one(pred, target, ignore_label=-1)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # Calc the IoU
            for l in range(num_class):
                total_correct_class[l] += ((pred == l) & (target == l)).sum()
                total_iou_deno_class[l] += (((pred == l) & (target >= 0)) | (target == l)).sum()

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter, len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                losses.reset()
                scores.reset()

            # Save the current status; save before val to prevent occasional memory overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                           save_inter=True)

            # Validation: point-based models should use an alternate dataloader for eval
            # if curr_iter % config.val_freq == 0:
            #     val_miou = test_points(model, val_data_loader, None, curr_iter, config, transform_data_fn)
            #     if val_miou > best_val_miou:
            #         best_val_miou = val_miou
            #         best_val_iter = curr_iter
            #         checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
            #                    "best_val")
            #     logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
            #     # Recover the train flag
            #     model.train()

            # End of iteration
            curr_iter += 1

        IoU = total_correct_class / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % (IoU.mean() * 100.))
        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
    # NOTE: the original discarded the result of test_points, leaving val_miou
    # undefined below; capture the returned mIoU so the comparison is well-defined.
    val_miou = test_points(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
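# precision_at_one is used throughout these training loops; a minimal sketch
# of the usual implementation (an assumption about the helper's exact body):
def precision_at_one(pred, target, ignore_label=255):
    """Overall top-1 accuracy in percent, skipping ignored labels."""
    correct = pred.eq(target)
    correct = correct[target != ignore_label]
    if correct.numel() == 0:
        return float('nan')
    return correct.float().sum().mul(100.0 / correct.numel()).item()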
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    all_losses = []
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    writer = SummaryWriter(log_dir=config.log_dir)
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores, batch_losses = AverageMeter(), AverageMeter(), {}

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
    alpha, gamma, eps = 1, 2, 1e-6  # focal loss parameters

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            model.load_state_dict(state['state_dict'])
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = data_loader.__iter__()
    while is_training:
        print("********************************** epoch N° {0} ************************".format(epoch))
        for iteration in range(len(data_loader) // config.iter_size):
            print("####### Iteration N° {0}".format(iteration))
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            for sub_iter in range(config.iter_size):
                print("------------- Sub_iteration N° {0}".format(sub_iter))
                # Get training data
                data_timer.tic()
                coords, input, target = next(data_iter)
                print("len of coords : {0}".format(len(coords)))

                # For some networks, making the network invariant to even/odd coords is important
                coords[:, :3] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                color = input[:, :3].int()
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / 255. - 0.5
                sinput = SparseTensor(input, coords).to(device)
                data_time += data_timer.toc(False)

                # Feed forward
                inputs = (sinput,) if config.wrapper_type == 'None' else (sinput, coords, color)
                # model.initialize_coords(*init_args)
                soutput = model(*inputs)

                # The output of the network is not sorted
                target = target.long().to(device)
                print("count of classes : {0}".format(np.unique(target.cpu().numpy(), return_counts=True)))
                print("target : {0}\ntarget_len : {1}".format(target, len(target)))
                print("target [0]: {0}".format(target[0]))

                # Focal loss, computed for inspection alongside the cross-entropy
                input_soft = nn.functional.softmax(soutput.F, dim=1) + eps
                print("input_soft[0] : {0}".format(input_soft[0]))
                focal_weight = torch.pow(-input_soft + 1., gamma)
                print("focal_weight : {0}\nweight[0] : {1}".format(focal_weight, focal_weight[0]))
                focal_loss = (-alpha * focal_weight * torch.log(input_soft)).mean()

                loss = criterion(soutput.F, target.long())
                print("focal_loss :{0}\nloss : {1}".format(focal_loss, loss))

                # Compute and accumulate gradient
                loss /= config.iter_size
                # batch_loss += loss
                batch_loss += loss.item()
                print("batch_loss : {0}".format(batch_loss))
                loss.backward()

            # Update number of steps
            optimizer.step()
            scheduler.step()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter, len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Total iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                writer.add_scalar('training/loss', losses.avg, curr_iter)
                writer.add_scalar('training/precision_at_1', scores.avg, curr_iter)
                writer.add_scalar('training/learning_rate', scheduler.get_lr()[0], curr_iter)
                losses.reset()
                scores.reset()

            # Save the current status; save before val to prevent occasional memory overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou, val_losses = validate(model, val_data_loader, writer, curr_iter, config,
                                                transform_data_fn, epoch)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                               "best_val")
                logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
                # Recover the train flag
                model.train()

            if curr_iter % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

            batch_losses[epoch] = batch_loss
            # End of iteration
            curr_iter += 1

        with open(config.log_dir + "/train_loss.txt", 'a') as train_loss_log:
            train_loss_log.writelines('{0}, {1}\n'.format(batch_losses[epoch], epoch))
        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
    val_miou = validate(model, val_data_loader, writer, curr_iter, config, transform_data_fn, epoch)[0]
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
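# AverageMeter is used by all of these loops; a minimal sketch of the standard
# running-average helper (an assumption about the project's exact version):
class AverageMeter(object):
    """Tracks the current value, running sum, count, and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0, 0, 0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count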
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    writer = SummaryWriter(log_dir=config.log_dir)
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            model.load_state_dict(state['state_dict'])
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = data_loader.__iter__()
    while is_training:
        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            try:  # torch issue #16998
                for sub_iter in range(config.iter_size):
                    # Get training data
                    data_timer.tic()
                    if config.return_transformation:
                        coords, input, target, pointcloud, transformation = next(data_iter)
                    else:
                        coords, input, target = next(data_iter)

                    # For some networks, making the network invariant to even/odd coords is important
                    coords[:, :3] += (torch.rand(3) * 100).type_as(coords)

                    # Preprocess input
                    color = input[:, :3].int()
                    if config.normalize_color:
                        input[:, :3] = input[:, :3] / 255. - 0.5
                    sinput = SparseTensor(input, coords).to(device)
                    data_time += data_timer.toc(False)

                    # Feed forward
                    inputs = (sinput,) if config.wrapper_type == 'None' else (sinput, coords, color)
                    # model.initialize_coords(*init_args)
                    soutput = model(*inputs)

                    # The output of the network is not sorted
                    target = target.long().to(device)
                    loss = criterion(soutput.F, target.long())

                    # Compute and accumulate gradient
                    loss /= config.iter_size
                    batch_loss += loss.item()
                    loss.backward()
            except Exception as e:
                logging.error(e)
                continue

            # Update number of steps
            optimizer.step()
            scheduler.step()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter, len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                writer.add_scalar('training/loss', losses.avg, curr_iter)
                writer.add_scalar('training/precision_at_1', scores.avg, curr_iter)
                writer.add_scalar('training/learning_rate', scheduler.get_lr()[0], curr_iter)
                losses.reset()
                scores.reset()

            # Save the current status; save before val to prevent occasional memory overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, writer, curr_iter, config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                               "best_val")
                logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
                # Recover the train flag
                model.train()

            if curr_iter % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

            # End of iteration
            curr_iter += 1

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
    val_miou = validate(model, val_data_loader, writer, curr_iter, config, transform_data_fn)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
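# The checkpoint() helper is called throughout but not defined in this
# snippet; a minimal sketch consistent with the keys the resume code reads
# back ('state_dict', 'optimizer', 'epoch', 'iteration', 'best_val',
# 'best_val_iter') — an assumption about the real helper:
import os
import torch

def checkpoint(model, optimizer, epoch, iteration, config,
               best_val_miou, best_val_iter, postfix=None, save_inter=False):
    state = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'iteration': iteration,
        'best_val': best_val_miou,
        'best_val_iter': best_val_iter,
    }
    name = 'weights' if postfix is None else 'weights_' + postfix
    if save_inter:
        name += '_iter{}'.format(iteration)
    torch.save(state, os.path.join(config.log_dir, name + '.pth'))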
def test(model, data_loader, config, transform_data_fn=None, has_gt=True, save_pred=False, split=None, submit_dir=None): device = get_torch_device(config.is_cuda) dataset = data_loader.dataset num_labels = dataset.NUM_LABELS global_timer, data_timer, iter_timer = Timer(), Timer(), Timer() criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label) losses, scores, ious = AverageMeter(), AverageMeter(), 0 aps = np.zeros((0, num_labels)) hist = np.zeros((num_labels, num_labels)) # some cfgs concerning the usage of instance-level information config.save_pred = save_pred if split is not None: assert save_pred if config.save_pred: save_dict = {} save_dict['pred'] = [] save_dict['coord'] = [] logging.info('===> Start testing') global_timer.tic() data_iter = data_loader.__iter__() max_iter = len(data_loader) max_iter_unique = max_iter # Fix batch normalization running mean and std model.eval() # Clear cache (when run in val mode, cleanup training cache) torch.cuda.empty_cache() # semantic kitti label inverse mapping if config.submit: remap_lut = Remap().getRemapLUT() with torch.no_grad(): # Calc of the iou total_correct = np.zeros(num_labels) total_seen = np.zeros(num_labels) total_positive = np.zeros(num_labels) point_nums = np.zeros([19]) for iteration in range(max_iter): data_timer.tic() if config.return_transformation: coords, input, target, unique_map_list, inverse_map_list, pointcloud, transformation, filename = data_iter.next() else: coords, input, target, unique_map_list, inverse_map_list, filename = data_iter.next() data_time = data_timer.toc(False) if config.use_aux: assert target.shape[1] == 2 aux = target[:,1] target = target[:,0] else: aux = None # Preprocess input iter_timer.tic() if config.normalize_color: input[:, :3] = input[:, :3] / input[:,:3].max() - 0.5 coords_norm = coords[:,1:] / coords[:,1:].max() - 0.5 XYZ_INPUT = config.xyz_input # cat xyz into the rgb feature if XYZ_INPUT: input = torch.cat([coords_norm, input], dim=1) sinput = ME.SparseTensor(input, coords, device=device) # Feed forward if aux is not None: soutput = model(sinput) else: soutput = model(sinput, iter_ = iteration / max_iter, enable_point_branch=config.enable_point_branch) output = soutput.F if torch.isnan(output).sum() > 0: import ipdb; ipdb.set_trace() pred = get_prediction(dataset, output, target).int() assert sum([int(t.shape[0]) for t in unique_map_list]) == len(pred), "number of points in unique_map doesn't match predition, do not enable preprocessing" iter_time = iter_timer.toc(False) if config.save_pred or config.submit: # troublesome processing for splitting each batch's data, and export batch_ids = sinput.C[:,0] splits_at = torch.stack([torch.where(batch_ids == i)[0][-1] for i in torch.unique(batch_ids)]).int() splits_at = splits_at + 1 splits_at_leftshift_one = splits_at.roll(shifts=1) splits_at_leftshift_one[0] = 0 # len_per_batch = splits_at - splits_at_leftshift_one len_sum = 0 batch_id = 0 for start, end in zip(splits_at_leftshift_one, splits_at): len_sum += len(pred[int(start):int(end)]) pred_this_batch = pred[int(start):int(end)] coord_this_batch = pred[int(start):int(end)] if config.save_pred: save_dict['pred'].append(pred_this_batch[inverse_map_list[batch_id]]) else: # save submit result submission_path = filename[batch_id].replace(config.semantic_kitti_path, submit_dir).replace('velodyne', 'predictions').replace('.bin', '.label') parent_dir = Path(submission_path).parent.absolute() if not os.path.exists(parent_dir): os.makedirs(parent_dir) label_pred = 
                        label_pred = pred_this_batch[inverse_map_list[batch_id]].cpu().numpy()
                        label_pred = remap_lut[label_pred].astype(np.uint32)
                        label_pred.tofile(submission_path)
                        print(submission_path)
                    batch_id += 1
                assert len_sum == len(pred)

            # Unpack it to the original length
            REVERT_WHOLE_POINTCLOUD = True
            # print('{}/{}'.format(iteration, max_iter))
            if REVERT_WHOLE_POINTCLOUD:

                whole_pred = []
                whole_target = []
                for batch_ in range(config.batch_size):
                    batch_mask_ = (soutput.C[:, 0] == batch_).cpu().numpy()
                    if batch_mask_.sum() == 0:  # skip empty batches
                        continue
                    try:
                        whole_pred_ = soutput.F[batch_mask_][inverse_map_list[batch_]]
                    except Exception:
                        import ipdb; ipdb.set_trace()  # debug hook: inverse map mismatch
                    whole_target_ = target[batch_mask_][inverse_map_list[batch_]]
                    whole_pred.append(whole_pred_)
                    whole_target.append(whole_target_)
                whole_pred = torch.cat(whole_pred, dim=0)
                whole_target = torch.cat(whole_target, dim=0)

                pred = get_prediction(dataset, whole_pred, whole_target).int()
                output = whole_pred
                target = whole_target

            if has_gt:
                target_np = target.numpy()
                num_sample = target_np.shape[0]
                target = target.to(device)
                output = output.to(device)

                cross_ent = criterion(output, target.long())
                losses.update(float(cross_ent), num_sample)
                scores.update(precision_at_one(pred, target), num_sample)
                # within fast_hist, labels must be >= 0 and < num_labels to filter out 255 / -1
                hist += fast_hist(pred.cpu().numpy().flatten(), target_np.flatten(), num_labels)
                ious = per_class_iu(hist) * 100

                prob = torch.nn.functional.softmax(output, dim=-1)
                pred = pred[target != -1]
                target = target[target != -1]

                # debug for SemanticKITTI: the SPVNAS way of calculating mIoU
                # for _ in range(num_labels):
                #     total_seen[_] += torch.sum(target == _)
                #     total_correct[_] += torch.sum((pred == target) & (target == _))
                #     total_positive[_] += torch.sum(pred == _)
                # ious_ = []
                # for _ in range(num_labels):
                #     if total_seen[_] == 0:
                #         ious_.append(1)
                #     else:
                #         ious_.append(total_correct[_] / (total_seen[_] + total_positive[_] - total_correct[_]))
                # ious_ = torch.stack(ious_, dim=-1).cpu().numpy() * 100
                # print(np.nanmean(per_class_iu(hist)), np.nanmean(ious_))
                # ious = np.array(ious_) * 100

                # calculate the ratio of total points
                # for i_ in range(19):
                #     point_nums[i_] += (target == i_).sum().detach()

                # skip calculating APs
                ap = average_precision(prob.cpu().detach().numpy(), target_np)
                aps = np.vstack((aps, ap))
                # Due to heavy class imbalance, some classes may have no test labels at all
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    ap_class = np.nanmean(aps, 0) * 100.
            if iteration % config.test_stat_freq == 0 and iteration > 0 and not config.submit:
                reordered_ious = dataset.reorder_result(ious)
                reordered_ap_class = dataset.reorder_result(ap_class)
                # dirty fix: SemanticKITTI has no class-name getter
                if hasattr(dataset, "class_names"):
                    class_names = dataset.get_classnames()
                else:  # SemanticKITTI
                    class_names = None
                print_info(
                    iteration,
                    max_iter_unique,
                    data_time,
                    iter_time,
                    has_gt,
                    losses,
                    scores,
                    reordered_ious,
                    hist,
                    reordered_ap_class,
                    class_names=class_names)

            if iteration % 5 == 0:
                # Clear cache
                torch.cuda.empty_cache()

    if config.save_pred:
        # torch.save(save_dict, os.path.join(config.log_dir, 'preds_{}_with_coord.pth'.format(split)))
        torch.save(save_dict, os.path.join(config.log_dir, 'preds_{}.pth'.format(split)))
        print("===> saved prediction result")

    global_time = global_timer.toc(False)

    save_map(model, config)

    reordered_ious = dataset.reorder_result(ious)
    reordered_ap_class = dataset.reorder_result(ap_class)
    if hasattr(dataset, "class_names"):
        class_names = dataset.get_classnames()
    else:
        class_names = None
    print_info(
        iteration,
        max_iter_unique,
        data_time,
        iter_time,
        has_gt,
        losses,
        scores,
        reordered_ious,
        hist,
        reordered_ap_class,
        class_names=class_names)

    logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(per_class_iu(hist)) * 100
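# `fast_hist` and `per_class_iu` above accumulate a (num_labels x num_labels) confusion
# matrix over all points and derive per-class IoU from it. A minimal numpy sketch of
# that idea, assuming the helpers follow the common MinkowskiNet-style implementation
# (not verified against this repo's lib):

import numpy as np


def fast_hist_sketch(pred, label, n):
    # keep only valid labels (filters out ignore values such as 255 / -1)
    k = (label >= 0) & (label < n)
    # bincount over n*label + pred builds the flattened confusion matrix
    return np.bincount(n * label[k].astype(int) + pred[k].astype(int),
                       minlength=n ** 2).reshape(n, n)


def per_class_iu_sketch(hist):
    # IoU_c = TP_c / (TP_c + FP_c + FN_c); the diagonal holds TP,
    # row/column sums hold the ground-truth / prediction totals
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))

# usage: mIoU = np.nanmean(per_class_iu_sketch(hist)) * 100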
def main_worker(gpu, ngpus_per_node, config):
    config.gpu = gpu

    # if config.is_cuda and not torch.cuda.is_available():
    #     raise Exception("No GPU found")
    if config.gpu is not None:
        print("Use GPU: {} for training".format(config.gpu))
    device = get_torch_device(config.is_cuda)

    if config.distributed:
        if config.dist_url == "env://" and config.rank == -1:
            config.rank = int(os.environ["RANK"])
        if config.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            config.rank = config.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=config.dist_backend,
                                init_method=config.dist_url,
                                world_size=config.world_size,
                                rank=config.rank)

    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    if config.test_original_pointcloud:
        if not DatasetClass.IS_FULL_POINTCLOUD_EVAL:
            raise ValueError('This dataset does not support full pointcloud evaluation.')

    if config.evaluate_original_pointcloud:
        if not config.return_transformation:
            raise ValueError('Pointcloud evaluation requires config.return_transformation=true.')

    if (config.return_transformation ^ config.evaluate_original_pointcloud):
        raise ValueError('Rotation evaluation requires config.evaluate_original_pointcloud=true and '
                         'config.return_transformation=true.')

    logging.info('===> Initializing dataloader')
    if config.is_train:
        train_data_loader, train_sampler = initialize_data_loader(
            DatasetClass,
            config,
            phase=config.train_phase,
            num_workers=config.num_workers,
            augment_data=True,
            shuffle=True,
            repeat=True,
            batch_size=config.batch_size,
            limit_numpoints=config.train_limit_numpoints)

        val_data_loader, val_sampler = initialize_data_loader(
            DatasetClass,
            config,
            num_workers=config.num_val_workers,
            phase=config.val_phase,
            augment_data=False,
            shuffle=True,
            repeat=False,
            batch_size=config.val_batch_size,
            limit_numpoints=False)

        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = train_data_loader.dataset.NUM_LABELS
    else:
        test_data_loader, val_sampler = initialize_data_loader(
            DatasetClass,
            config,
            num_workers=config.num_workers,
            phase=config.test_phase,
            augment_data=False,
            shuffle=False,
            repeat=False,
            batch_size=config.test_batch_size,
            limit_numpoints=False)

        if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = test_data_loader.dataset.NUM_LABELS

    logging.info('===> Building model')
    NetClass = load_model(config.model)
    if config.wrapper_type == 'None':
        model = NetClass(num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            NetClass.__name__, count_parameters(model)))
    else:
        wrapper = load_wrapper(config.wrapper_type)
        model = wrapper(NetClass, num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            wrapper.__name__ + NetClass.__name__, count_parameters(model)))
    logging.info(model)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        if config.weights_for_inner_model:
            model.model.load_state_dict(state['state_dict'])
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                init_model_from_weights(model, state, freeze_bb=False)

    if config.distributed:
        # For multiprocessing distributed, the DistributedDataParallel constructor
        # should always set the single device scope; otherwise
        # DistributedDataParallel will use all available devices.
        if config.gpu is not None:
            torch.cuda.set_device(config.gpu)
            model.cuda(config.gpu)
            # When using a single GPU per process and per DistributedDataParallel,
            # we need to divide the batch size ourselves based on the total number
            # of GPUs we have
            config.batch_size = int(config.batch_size / ngpus_per_node)
            config.num_workers = int((config.num_workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)

    if config.is_train:
        train(model, train_data_loader, val_data_loader, config,
              train_sampler=train_sampler, ngpus_per_node=ngpus_per_node)
    else:
        test(model, test_data_loader, config)
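# `main_worker(gpu, ngpus_per_node, config)` is shaped for torch.multiprocessing.spawn:
# one process per GPU, each receiving its GPU index as the first argument. A minimal
# launcher sketch under that assumption; the env:// defaults and `num_nodes` attribute
# here are illustrative, not taken from this repo's config:

import os
import torch
import torch.multiprocessing as mp


def launch_distributed(config, main_worker):
    ngpus_per_node = torch.cuda.device_count()
    # world_size is the total number of processes across all nodes
    config.world_size = ngpus_per_node * getattr(config, 'num_nodes', 1)
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    # spawn calls main_worker(gpu, ngpus_per_node, config) for gpu in [0, ngpus_per_node)
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config))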
def main():
    config = get_config()

    if config.test_config:
        json_config = json.load(open(config.test_config, 'r'))
        json_config['is_train'] = False
        json_config['weights'] = config.weights
        config = edict(json_config)
    elif config.resume:
        json_config = json.load(open(config.resume + '/config.json', 'r'))
        json_config['resume'] = config.resume
        config = edict(json_config)

    if config.is_cuda and not torch.cuda.is_available():
        raise Exception("No GPU found")
    device = get_torch_device(config.is_cuda)

    # torch.set_num_threads(config.threads)
    # torch.manual_seed(config.seed)
    # if config.is_cuda:
    #     torch.cuda.manual_seed(config.seed)

    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)

    logging.info('===> Initializing dataloader')

    if config.is_train:
        setup_seed(2021)
        train_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            phase=config.train_phase,
            # threads=config.threads,
            threads=4,
            augment_data=True,
            elastic_distortion=config.train_elastic_distortion,
            # elastic_distortion=False,
            # shuffle=True,
            shuffle=False,
            # repeat=True,
            repeat=False,
            batch_size=config.batch_size,
            # batch_size=8,
            limit_numpoints=config.train_limit_numpoints)
        # dat = iter(train_data_loader).__next__()
        # import ipdb; ipdb.set_trace()

        val_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            # threads=0,
            threads=config.val_threads,
            phase=config.val_phase,
            augment_data=False,
            elastic_distortion=config.test_elastic_distortion,
            shuffle=False,
            repeat=False,
            # batch_size=config.val_batch_size,
            batch_size=8,
            limit_numpoints=False)
        # dat = iter(val_data_loader).__next__()
        # import ipdb; ipdb.set_trace()

        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3

        num_labels = train_data_loader.dataset.NUM_LABELS
    else:
        test_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            threads=config.threads,
            phase=config.test_phase,
            augment_data=False,
            elastic_distortion=config.test_elastic_distortion,
            shuffle=False,
            repeat=False,
            batch_size=config.test_batch_size,
            limit_numpoints=False)

        if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3

        num_labels = test_data_loader.dataset.NUM_LABELS

    logging.info('===> Building model')
    NetClass = load_model(config.model)
    model = NetClass(num_in_channel, num_labels, config)
    logging.info('===> Number of trainable parameters: {}: {}'.format(
        NetClass.__name__, count_parameters(model)))
    logging.info(model)

    # Set the number of threads
    # ME.initialize_nthreads(12, D=3)

    model = model.to(device)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        if config.weights_for_inner_model:
            model.model.load_state_dict(state['state_dict'])
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(state['state_dict'])

    if config.is_train:
        train(model, train_data_loader, val_data_loader, config)
    else:
        test(model, test_data_loader, config)
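# The test_config / resume branches above reload a saved config.json and then overwrite
# a few runtime-only fields from the command line before wrapping it in an EasyDict.
# A small sketch of that merge pattern, assuming easydict's EasyDict (imported as
# `edict` in this repo); the path and field names in the usage line are illustrative:

import json
from easydict import EasyDict as edict


def load_config_with_overrides(json_path, **overrides):
    with open(json_path, 'r') as f:
        cfg = edict(json.load(f))    # saved training-time configuration
    for key, value in overrides.items():
        cfg[key] = value             # command-line values win over the saved ones
    return cfg

# usage: cfg = load_config_with_overrides('log/config.json', is_train=False, weights='weights.pth')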
def main():
    config = get_config()

    ch = logging.StreamHandler(sys.stdout)
    logging.getLogger().setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(os.path.join(config.log_dir, 'model.log'))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logging.basicConfig(format=os.uname()[1].split('.')[0] + ' %(asctime)s %(message)s',
                        datefmt='%m/%d %H:%M:%S',
                        handlers=[ch, file_handler])

    if config.test_config:
        # When using test_config, the saved config is reloaded and overwritten,
        # so a few runtime configs have to be kept
        val_bs = config.val_batch_size
        is_export = config.is_export
        json_config = json.load(open(config.test_config, 'r'))
        json_config['is_train'] = False
        json_config['weights'] = config.weights
        json_config['multiprocess'] = False
        json_config['log_dir'] = config.log_dir
        json_config['val_threads'] = config.val_threads
        json_config['submit'] = config.submit
        config = edict(json_config)
        config.val_batch_size = val_bs
        config.is_export = is_export
        config.is_train = False

        sys.path.append(config.log_dir)
        # from local_models import load_model
    else:
        '''back up source files'''
        if not os.path.exists(os.path.join(config.log_dir, 'models')):
            os.mkdir(os.path.join(config.log_dir, 'models'))
        for filename in os.listdir('./models'):
            if ".py" in filename:
                # do not copy the __init__ file, since it would raise an import error
                shutil.copy(os.path.join("./models", filename),
                            os.path.join(config.log_dir, 'models'))
            elif 'modules' in filename:
                # also copy the modules folder
                if os.path.exists(os.path.join(config.log_dir, 'models/modules')):
                    shutil.rmtree(os.path.join(config.log_dir, 'models/modules'))
                shutil.copytree(os.path.join('./models', filename),
                                os.path.join(config.log_dir, 'models/modules'))
        shutil.copy('./main.py', config.log_dir)
        shutil.copy('./config.py', config.log_dir)
        shutil.copy('./lib/train.py', config.log_dir)
        shutil.copy('./lib/test.py', config.log_dir)

    if config.resume == 'True':
        new_iter_size = config.max_iter
        new_bs = config.batch_size
        config.resume = config.log_dir
        json_config = json.load(open(config.resume + '/config.json', 'r'))
        json_config['resume'] = config.resume
        config = edict(json_config)
        config.weights = os.path.join(config.log_dir, 'weights.pth')  # use the pre-trained weights
        logging.info('==== resuming from {}, total {} iters ===='.format(config.max_iter, new_iter_size))
        config.max_iter = new_iter_size
        config.batch_size = new_bs
    else:
        config.resume = None

    if config.is_cuda and not torch.cuda.is_available():
        raise Exception("No GPU found")
    gpu_list = range(config.num_gpu)
    device = get_torch_device(config.is_cuda)

    # torch.set_num_threads(config.threads)
    # torch.manual_seed(config.seed)
    # if config.is_cuda:
    #     torch.cuda.manual_seed(config.seed)

    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)

    logging.info('===> Initializing dataloader')

    setup_seed(2021)

    """
    ---- Setting up train, val, test dataloaders ----
    Supported datasets:
    - ScannetSparseVoxelizationDataset
    - ScannetDataset
    - SemanticKITTI
    """
    point_scannet = False

    if config.is_train:
        if config.dataset == 'ScannetSparseVoxelizationDataset':
            point_scannet = False
            train_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                phase=config.train_phase,
                threads=config.threads,
                augment_data=True,
                elastic_distortion=config.train_elastic_distortion,
                shuffle=True,
                # shuffle=False,  # DEBUG ONLY!!!
                repeat=True,
                # repeat=False,
                batch_size=config.batch_size,
                limit_numpoints=config.train_limit_numpoints)

            val_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                threads=config.val_threads,
                phase=config.val_phase,
                augment_data=False,
                elastic_distortion=config.test_elastic_distortion,
                shuffle=False,
                repeat=False,
                batch_size=config.val_batch_size,
                limit_numpoints=False)

        elif config.dataset == 'ScannetDataset':
            val_DatasetClass = load_dataset('ScannetDatasetWholeScene_evaluation')
            point_scannet = True

            # collate_fn = t.cfl_collate_fn_factory(False)  # no limit on num-points
            trainset = DatasetClass(
                root='/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                npoints=config.num_points,
                # split='debug',
                split='train',
                with_norm=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                dataset=trainset,
                num_workers=config.threads,
                # num_workers=0,  # for loading a big pth file, a single thread should be used
                batch_size=config.batch_size,
                # collate_fn=collate_fn,  # raw-point input should not have a collate_fn
                worker_init_fn=_init_fn,
                sampler=InfSampler(trainset, True))  # shuffle=True

            valset = val_DatasetClass(
                root='/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                scene_list_dir='/data/eva_share_users/zhaotianchen/scannet/raw/metadata',
                # split='debug',
                split='eval',
                block_points=config.num_points,
                with_norm=False,
                delta=1.0,
            )
            val_data_loader = torch.utils.data.DataLoader(
                dataset=valset,
                # num_workers=config.threads,
                num_workers=0,  # for loading a big pth file, a single thread should be used
                batch_size=config.val_batch_size,
                # collate_fn=collate_fn,  # raw-point input should not have a collate_fn
                worker_init_fn=_init_fn)

        elif config.dataset == "SemanticKITTI":
            point_scannet = False
            dataset = SemanticKITTI(root=config.semantic_kitti_path,
                                    num_points=None,
                                    voxel_size=config.voxel_size,
                                    sample_stride=config.sample_stride,
                                    submit=False)
            collate_fn_factory = t.cfl_collate_fn_factory
            train_data_loader = torch.utils.data.DataLoader(
                dataset['train'],
                batch_size=config.batch_size,
                sampler=InfSampler(dataset['train'], shuffle=True),  # shuffle=True, repeat=True
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=collate_fn_factory(config.train_limit_numpoints))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=False, repeat=False
                dataset['test'],
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))

        elif config.dataset == "S3DIS":
            trainset = S3DIS(config, train=True)
            valset = S3DIS(config, train=False)
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset, shuffle=True),  # shuffle=True, repeat=True
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(config.train_limit_numpoints))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=False, repeat=False
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))

        elif config.dataset == 'Nuscenes':
            config.xyz_input = False
            # TODO:
            trainset = Nuscenes(config, train=True)
            valset = Nuscenes(config, train=False)
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset, shuffle=True),  # shuffle=True, repeat=True
                num_workers=config.threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,  # used with cylinder voxelization
                collate_fn=t.cfl_collate_fn_factory(False))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=False, repeat=False
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,
                collate_fn=t.cfl_collate_fn_factory(False))
        else:
            print('Dataset {} not supported'.format(config.dataset))
            raise NotImplementedError

        # Setting up num_in_channel and num_labels
        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3

        num_labels = train_data_loader.dataset.NUM_LABELS

        # it = iter(train_data_loader)
        # for _ in range(100):
        #     data = it.__next__()
        #     print(data)

    else:  # not config.is_train
        val_DatasetClass = load_dataset('ScannetDatasetWholeScene_evaluation')

        if config.dataset == 'ScannetSparseVoxelizationDataset':
            if config.is_export:
                # when exporting, the train results need to be exported too
                train_data_loader = initialize_data_loader(
                    DatasetClass,
                    config,
                    phase=config.train_phase,
                    threads=config.threads,
                    augment_data=True,
                    elastic_distortion=config.train_elastic_distortion,  # DEBUG: not sure about this
                    shuffle=False,
                    repeat=False,
                    batch_size=config.batch_size,
                    limit_numpoints=config.train_limit_numpoints)
                # the val-like variant, with no data augmentation:
                # train_data_loader = initialize_data_loader(
                #     DatasetClass,
                #     config,
                #     threads=config.val_threads,
                #     phase=config.train_phase,
                #     augment_data=False,
                #     elastic_distortion=config.test_elastic_distortion,
                #     shuffle=False,
                #     repeat=False,
                #     batch_size=config.val_batch_size,
                #     limit_numpoints=False)

            val_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                threads=config.val_threads,
                phase=config.val_phase,
                augment_data=False,
                elastic_distortion=config.test_elastic_distortion,
                shuffle=False,
                repeat=False,
                batch_size=config.val_batch_size,
                limit_numpoints=False)

            if val_data_loader.dataset.NUM_IN_CHANNEL is not None:
                num_in_channel = val_data_loader.dataset.NUM_IN_CHANNEL
            else:
                num_in_channel = 3

            num_labels = val_data_loader.dataset.NUM_LABELS

        elif config.dataset == 'ScannetDataset':
            '''when using scannet-point, use val instead of test'''
            point_scannet = True
            valset = val_DatasetClass(
                root='/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                scene_list_dir='/data/eva_share_users/zhaotianchen/scannet/raw/metadata',
                split='eval',
                block_points=config.num_points,
                delta=1.0,
                with_norm=False,
            )
            val_data_loader = torch.utils.data.DataLoader(
                dataset=valset,
                # num_workers=config.threads,
                num_workers=0,  # for loading a big pth file, a single thread should be used
                batch_size=config.val_batch_size,
                # collate_fn=collate_fn,  # raw-point input should not have a collate_fn
                worker_init_fn=_init_fn,
            )
            num_labels = val_data_loader.dataset.NUM_LABELS
            num_in_channel = 3

        elif config.dataset == "SemanticKITTI":
            dataset = SemanticKITTI(root=config.semantic_kitti_path,
                                    num_points=None,
                                    voxel_size=config.voxel_size,
                                    submit=config.submit)
            val_data_loader = torch.utils.data.DataLoader(  # shuffle=False, repeat=False
                dataset['test'],
                batch_size=config.val_batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 4
            num_labels = 19

        elif config.dataset == 'S3DIS':
            config.xyz_input = False
            trainset = S3DIS(config, train=True)
            valset = S3DIS(config, train=False)
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset, shuffle=True),  # shuffle=True, repeat=True
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(config.train_limit_numpoints))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=False, repeat=False
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 9
            num_labels = 13

        elif config.dataset == 'Nuscenes':
            config.xyz_input = False
            trainset = Nuscenes(config, train=True)
            valset = Nuscenes(config, train=False)
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset, shuffle=True),  # shuffle=True, repeat=True
                num_workers=config.threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,
                collate_fn=t.cfl_collate_fn_factory(False))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=False, repeat=False
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 5
            num_labels = 16
        else:
            print('Dataset {} not supported'.format(config.dataset))
            raise NotImplementedError

    logging.info('===> Building model')

    # if config.model == 'PointTransformer' or config.model == 'MixedTransformer':
    if config.model == 'PointTransformer':
        config.pure_point = True

    NetClass = load_model(config.model)
    if config.pure_point:
        model = NetClass(config, num_class=num_labels, N=config.num_points,
                         normal_channel=num_in_channel)
    else:
        if config.model == 'MixedTransformer':
            model = NetClass(config, num_class=num_labels, N=config.num_points,
                             normal_channel=num_in_channel)
        elif config.model == 'MinkowskiVoxelTransformer':
            model = NetClass(config, num_in_channel, num_labels)
        elif config.model == 'MinkowskiTransformerNet':
            model = NetClass(config, num_in_channel, num_labels)
        elif "Res" in config.model:
            model = NetClass(num_in_channel, num_labels, config)
        else:
            model = NetClass(num_in_channel, num_labels, config)

    logging.info('===> Number of trainable parameters: {}: {}M'.format(
        NetClass.__name__, count_parameters(model) / 1e6))

    if hasattr(model, "block1"):
        if hasattr(model.block1[0], 'h'):
            h = model.block1[0].h
            vec_dim = model.block1[0].vec_dim
        else:
            h = None
            vec_dim = None
    else:
        h = None
        vec_dim = None
    # logging.info('===> Model Args:\n PLANES: {} \n LAYERS: {}\n HEADS: {}\n Vec-dim: {}\n'.format(
    #     model.PLANES, model.LAYERS, h, vec_dim))
    logging.info(model)

    # Set the number of threads
    # ME.initialize_nthreads(12, D=3)

    model = model.to(device)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        # drop the attention-map keys (containing '_map'), since they raise a size mismatch
        d_ = {k: v for k, v in state['state_dict'].items() if '_map' not in k}
        # debug: the model sometimes contains 'map_qk', which is not a proper name for
        # a module, since '_map' keys are always buffers
        d = {}
        for k in d_.keys():
            if 'module.' in k:
                d[k.replace('module.', '')] = d_[k]
            else:
                d[k] = d_[k]
        # del d_

        if config.weights_for_inner_model:
            model.model.load_state_dict(d)
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(d, strict=True)

    if config.is_debug:
        check_data(model, train_data_loader, val_data_loader, config)
        return None
    elif config.is_train:
        if hasattr(config, 'distill') and config.distill:
            assert point_scannet is not True  # distillation only supports the whole-scene pipeline for now
            train_distill(model, train_data_loader, val_data_loader, config)
        if config.multiprocess:
            if point_scannet:
                raise NotImplementedError
            else:
                train_mp(NetClass, train_data_loader, val_data_loader, config)
        else:
            if point_scannet:
                train_point(model, train_data_loader, val_data_loader, config)
            else:
                train(model, train_data_loader, val_data_loader, config)
    elif config.is_export:
        if point_scannet:
            raise NotImplementedError
        else:
            # only the whole-scene style is supported for now
            test(model, train_data_loader, config, save_pred=True, split='train')
            test(model, val_data_loader, config, save_pred=True, split='val')
    else:
        assert config.multiprocess == False
        # if testing for submission, make a submit directory in the current directory
        submit_dir = os.path.join(os.getcwd(), 'submit', 'sequences')
        if config.submit and not os.path.exists(submit_dir):
            os.makedirs(submit_dir)
            print("Made submission directory: " + submit_dir)
        if point_scannet:
            test_points(model, val_data_loader, config)
        else:
            test(model, val_data_loader, config, submit_dir=submit_dir)
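# Throughout the test path, `unique_map_list` / `inverse_map_list` carry the voxelization
# bookkeeping: unique_map selects one representative point per occupied voxel (what the
# sparse network consumes), and inverse_map scatters per-voxel predictions back onto every
# original point. A minimal numpy sketch of that round trip, assuming the maps come from
# a np.unique-style deduplication; the exact voxelizer in this repo may differ:

import numpy as np


def voxelize_round_trip(points, voxel_size=0.05):
    # quantize coordinates to integer voxel indices
    voxels = np.floor(points / voxel_size).astype(np.int64)
    # unique_map: first point index per voxel; inverse_map: voxel index per point
    _, unique_map, inverse_map = np.unique(voxels, axis=0,
                                           return_index=True, return_inverse=True)
    inverse_map = inverse_map.reshape(-1)          # guard against numpy-version shape changes
    voxel_points = points[unique_map]              # one point per voxel
    per_voxel_pred = np.arange(len(voxel_points))  # stand-in for network predictions
    per_point_pred = per_voxel_pred[inverse_map]   # revert to the full point cloud
    assert len(per_point_pred) == len(points)
    return per_point_pred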