def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
  device = get_torch_device(config.is_cuda)
  dataset = data_loader.dataset
  num_labels = dataset.NUM_LABELS
  global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
  criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
  losses, scores, ious = AverageMeter(), AverageMeter(), 0
  aps = np.zeros((0, num_labels))
  hist = np.zeros((num_labels, num_labels))

  logging.info('===> Start testing')

  global_timer.tic()
  data_iter = data_loader.__iter__()
  max_iter = len(data_loader)
  max_iter_unique = max_iter

  # Fix batch normalization running mean and std
  model.eval()

  # Clear cache (when run in val mode, cleanup training cache)
  torch.cuda.empty_cache()

  if config.save_prediction or config.test_original_pointcloud:
    if config.save_prediction:
      save_pred_dir = config.save_pred_dir
      os.makedirs(save_pred_dir, exist_ok=True)
    else:
      save_pred_dir = tempfile.mkdtemp()
    if os.listdir(save_pred_dir):
      raise ValueError(f'Directory {save_pred_dir} not empty. '
                       'Please remove the existing prediction.')

  with torch.no_grad():
    for iteration in range(max_iter):
      data_timer.tic()
      if config.return_transformation:
        coords, input, target, transformation = data_iter.next()
      else:
        coords, input, target = data_iter.next()
        transformation = None
      data_time = data_timer.toc(False)

      # Preprocess input
      iter_timer.tic()
      if config.wrapper_type != 'None':
        color = input[:, :3].int()
      if config.normalize_color:
        input[:, :3] = input[:, :3] / 255. - 0.5
      sinput = SparseTensor(input, coords).to(device)

      # Feed forward
      inputs = (sinput,) if config.wrapper_type == 'None' else (sinput, coords, color)
      soutput = model(*inputs)
      output = soutput.F
      pred = get_prediction(dataset, output, target).int()
      iter_time = iter_timer.toc(False)

      if config.save_prediction or config.test_original_pointcloud:
        save_predictions(coords, pred, transformation, dataset, config, iteration, save_pred_dir)

      if has_gt:
        if config.evaluate_original_pointcloud:
          raise NotImplementedError('pointcloud')
          output, pred, target = permute_pointcloud(coords, pointcloud, transformation,
                                                    dataset.label_map, output, pred)

        target_np = target.numpy()
        num_sample = target_np.shape[0]
        target = target.to(device)

        cross_ent = criterion(output, target.long())
        losses.update(float(cross_ent), num_sample)
        scores.update(precision_at_one(pred, target), num_sample)
        hist += fast_hist(pred.cpu().numpy().flatten(), target_np.flatten(), num_labels)
        ious = per_class_iu(hist) * 100

        prob = torch.nn.functional.softmax(output, dim=1)
        ap = average_precision(prob.cpu().detach().numpy(), target_np)
        aps = np.vstack((aps, ap))
        # Due to heavy bias in class, there exists class with no test label at all
        with warnings.catch_warnings():
          warnings.simplefilter("ignore", category=RuntimeWarning)
          ap_class = np.nanmean(aps, 0) * 100.

      if iteration % config.test_stat_freq == 0 and iteration > 0:
        reordered_ious = dataset.reorder_result(ious)
        reordered_ap_class = dataset.reorder_result(ap_class)
        class_names = dataset.get_classnames()
        print_info(iteration, max_iter_unique, data_time, iter_time, has_gt, losses, scores,
                   reordered_ious, hist, reordered_ap_class, class_names=class_names)

      if iteration % config.empty_cache_freq == 0:
        # Clear cache
        torch.cuda.empty_cache()

  global_time = global_timer.toc(False)

  reordered_ious = dataset.reorder_result(ious)
  reordered_ap_class = dataset.reorder_result(ap_class)
  class_names = dataset.get_classnames()
  print_info(iteration, max_iter_unique, data_time, iter_time, has_gt, losses, scores,
             reordered_ious, hist, reordered_ap_class, class_names=class_names)

  if config.test_original_pointcloud:
    logging.info('===> Start testing on original pointcloud space.')
    dataset.test_pointcloud(save_pred_dir)

  logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))
  return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(per_class_iu(hist)) * 100

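# --- Hedged reference sketch (not part of the original file) ------------------------------
# fast_hist and per_class_iu used in test() are imported helpers; the definitions below show
# the standard confusion-histogram formulation they are assumed to follow, purely to document
# what the test loop accumulates. Division by zero for classes absent from the test split
# yields NaN, which is why the loop suppresses RuntimeWarning and uses np.nanmean.
def _fast_hist_sketch(pred, label, n):
  # n x n confusion histogram restricted to valid labels in [0, n).
  k = (label >= 0) & (label < n)
  return np.bincount(n * label[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)


def _per_class_iu_sketch(hist):
  # Per-class IoU: diagonal / (row sum + column sum - diagonal).
  return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
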
def test(model, data_loader, config, transform_data_fn=None, has_gt=True, validation=None,
         epoch=None):
  device = get_torch_device(config.is_cuda)
  dataset = data_loader.dataset
  num_labels = dataset.NUM_LABELS
  global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
  criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
  alpha, gamma, eps = 1, 2, 1e-6  # Focal Loss parameters
  losses, scores, ious = AverageMeter(), AverageMeter(), 0
  aps = np.zeros((0, num_labels))
  hist = np.zeros((num_labels, num_labels))

  if not config.is_train:
    checkpoint_fn = config.resume + '/weights.pth'
    if osp.isfile(checkpoint_fn):
      logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
      state = torch.load(checkpoint_fn)
      model.load_state_dict(state['state_dict'])
      logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
    else:
      raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

  if validation:
    logging.info('===> Start validating')
  else:
    logging.info('===> Start testing')

  global_timer.tic()
  data_iter = data_loader.__iter__()
  max_iter = len(data_loader)
  max_iter_unique = max_iter

  all_preds, all_labels, batch_losses, batch_loss = [], [], {}, 0

  # Fix batch normalization running mean and std
  model.eval()

  # Clear cache (when run in val mode, cleanup training cache)
  torch.cuda.empty_cache()

  if config.save_prediction or config.test_original_pointcloud:
    if config.save_prediction:
      save_pred_dir = config.save_pred_dir
      os.makedirs(save_pred_dir, exist_ok=True)
    else:
      save_pred_dir = tempfile.mkdtemp()
    if os.listdir(save_pred_dir):
      raise ValueError(f'Directory {save_pred_dir} not empty. '
                       'Please remove the existing prediction.')

  with torch.no_grad():
    for iteration in range(max_iter):
      data_timer.tic()
      if config.return_transformation:
        coords, input, target, transformation = data_iter.next()
      else:
        coords, input, target = data_iter.next()
        transformation = None
      data_time = data_timer.toc(False)

      # Preprocess input
      iter_timer.tic()
      if config.wrapper_type != 'None':
        color = input[:, :3].int()
      if config.normalize_color:
        input[:, :3] = input[:, :3] / 255. - 0.5
      sinput = SparseTensor(input, coords).to(device)

      # Feed forward
      inputs = (sinput,) if config.wrapper_type == 'None' else (sinput, coords, color)
      soutput = model(*inputs)
      output = soutput.F
      pred = get_prediction(dataset, output, target).int()
      iter_time = iter_timer.toc(False)

      all_preds.append(pred.cpu().detach().numpy())
      all_labels.append(target.cpu().detach().numpy())

      if config.save_prediction or config.test_original_pointcloud:
        save_predictions(coords, pred, transformation, dataset, config, iteration, save_pred_dir)

      if has_gt:
        if config.evaluate_original_pointcloud:
          raise NotImplementedError('pointcloud')
          output, pred, target = permute_pointcloud(coords, pointcloud, transformation,
                                                    dataset.label_map, output, pred)

        target_np = target.numpy()
        num_sample = target_np.shape[0]
        target = target.to(device)

        """# focal loss
        input_soft = nn.functional.softmax(output, dim=1) + eps
        focal_weight = torch.pow(-input_soft + 1., gamma)
        loss = (-alpha * focal_weight * torch.log(input_soft)).mean()"""
        loss = criterion(output, target.long())
        batch_loss += loss

        losses.update(float(loss), num_sample)
        scores.update(precision_at_one(pred, target), num_sample)
        hist += fast_hist(pred.cpu().numpy().flatten(), target_np.flatten(), num_labels)
        ious = per_class_iu(hist) * 100

        prob = torch.nn.functional.softmax(output, dim=1)
        ap = average_precision(prob.cpu().detach().numpy(), target_np)
        aps = np.vstack((aps, ap))
        # Due to heavy bias in class, there exists class with no test label at all
        with warnings.catch_warnings():
          warnings.simplefilter("ignore", category=RuntimeWarning)
          ap_class = np.nanmean(aps, 0) * 100.

      if iteration % config.test_stat_freq == 0 and iteration > 0:
        preds = np.concatenate(all_preds)
        targets = np.concatenate(all_labels)
        to_ignore = [i for i in range(len(targets)) if targets[i] == 255]
        preds_trunc = [preds[i] for i in range(len(preds)) if i not in to_ignore]
        targets_trunc = [targets[i] for i in range(len(targets)) if i not in to_ignore]
        cm = confusion_matrix(targets_trunc, preds_trunc, normalize='true')
        np.savetxt(config.log_dir + '/cm_epoch_{0}.txt'.format(epoch), cm)

        reordered_ious = dataset.reorder_result(ious)
        reordered_ap_class = dataset.reorder_result(ap_class)
        class_names = dataset.get_classnames()
        print_info(iteration, max_iter_unique, data_time, iter_time, has_gt, losses, scores,
                   reordered_ious, hist, reordered_ap_class, class_names=class_names)

      if iteration % config.empty_cache_freq == 0:
        # Clear cache
        torch.cuda.empty_cache()

  batch_losses[epoch] = batch_loss
  global_time = global_timer.toc(False)

  reordered_ious = dataset.reorder_result(ious)
  reordered_ap_class = dataset.reorder_result(ap_class)
  class_names = dataset.get_classnames()
  print_info(iteration, max_iter_unique, data_time, iter_time, has_gt, losses, scores,
             reordered_ious, hist, reordered_ap_class, class_names=class_names)

  if not config.is_train:
    preds = np.concatenate(all_preds)
    targets = np.concatenate(all_labels)
    to_ignore = [i for i in range(len(targets)) if targets[i] == 255]
    preds_trunc = [preds[i] for i in range(len(preds)) if i not in to_ignore]
    targets_trunc = [targets[i] for i in range(len(targets)) if i not in to_ignore]
    cm = confusion_matrix(targets_trunc, preds_trunc, normalize='true')
    np.savetxt(config.log_dir + '/cm.txt', cm)

  if config.test_original_pointcloud:
    logging.info('===> Start testing on original pointcloud space.')
    dataset.test_pointcloud(save_pred_dir)

  logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))

  if validation:
    loss_file_name = "/val_loss.txt"
    with open(config.log_dir + loss_file_name, 'a') as val_loss_file:
      for key in batch_losses:
        val_loss_file.writelines('{0}, {1}\n'.format(batch_losses[key], key))
      val_loss_file.close()
    return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(
        per_class_iu(hist)) * 100, batch_losses
  else:
    return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(per_class_iu(hist)) * 100

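# --- Hedged reference sketch (not part of the original file) ------------------------------
# The focal-loss variant kept in the comment block inside test() above could be factored into
# a helper. This mirrors the commented expression literally (alpha/gamma/eps as defined in
# test()); note that, like that expression, it averages over all class probabilities rather
# than indexing the target class, and it does not handle config.ignore_label.
def _focal_loss_sketch(logits, alpha=1.0, gamma=2.0, eps=1e-6):
  # logits: (N, C) raw network outputs.
  prob = nn.functional.softmax(logits, dim=1) + eps
  focal_weight = torch.pow(1. - prob, gamma)
  return (-alpha * focal_weight * torch.log(prob)).mean()
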
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
  device = config.device_id
  distributed = get_world_size() > 1

  # Set up the train flag for batch normalization
  model.train()

  # Configuration
  writer = SummaryWriter(log_dir=config.log_dir)
  data_timer, iter_timer = Timer(), Timer()
  fw_timer, bw_timer, ddp_timer = Timer(), Timer(), Timer()

  data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
  fw_time_avg, bw_time_avg, ddp_time_avg = AverageMeter(), AverageMeter(), AverageMeter()

  losses, scores = AverageMeter(), AverageMeter()

  optimizer = initialize_optimizer(model.parameters(), config)
  scheduler = initialize_scheduler(optimizer, config)
  criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

  # Train the network
  logging.info('===> Start training on {} GPUs, batch-size={}'.format(
      get_world_size(), config.batch_size * get_world_size()))
  best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

  if config.resume:
    checkpoint_fn = config.resume + '/weights.pth'
    if osp.isfile(checkpoint_fn):
      logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
      state = torch.load(checkpoint_fn,
                         map_location=lambda s, l: default_restore_location(s, 'cpu'))
      curr_iter = state['iteration'] + 1
      epoch = state['epoch']
      load_state(model, state['state_dict'])
      if config.resume_optimizer:
        scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
        optimizer.load_state_dict(state['optimizer'])
      if 'best_val' in state:
        best_val_miou = state['best_val']
        best_val_iter = state['best_val_iter']
      logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
    else:
      raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

  data_iter = data_loader.__iter__()  # (distributed) infinite sampler
  while is_training:
    for iteration in range(len(data_loader) // config.iter_size):
      optimizer.zero_grad()
      data_time, batch_loss, batch_score = 0, 0, 0
      iter_timer.tic()

      # set random seed for every iteration for trackability
      _set_seed(config, curr_iter)

      for sub_iter in range(config.iter_size):
        # Get training data
        data_timer.tic()
        coords, input, target = data_iter.next()

        # For some networks, making the network invariant to even, odd coords is important
        coords[:, :3] += (torch.rand(3) * 100).type_as(coords)

        # Preprocess input
        color = input[:, :3].int()
        if config.normalize_color:
          input[:, :3] = input[:, :3] / 255. - 0.5
        sinput = SparseTensor(input, coords).to(device)

        data_time += data_timer.toc(False)

        # Feed forward
        fw_timer.tic()

        inputs = (sinput,) if config.wrapper_type == 'None' else (sinput, coords, color)
        # model.initialize_coords(*init_args)
        soutput = model(*inputs)
        # The output of the network is not sorted
        target = target.long().to(device)

        loss = criterion(soutput.F, target.long())

        # Compute and accumulate gradient
        loss /= config.iter_size

        pred = get_prediction(data_loader.dataset, soutput.F, target)
        score = precision_at_one(pred, target)

        fw_timer.toc(False)
        bw_timer.tic()

        # bp the loss
        loss.backward()

        bw_timer.toc(False)

        # gather information
        logging_output = {'loss': loss.item(), 'score': score / config.iter_size}

        ddp_timer.tic()
        if distributed:
          logging_output = all_gather_list(logging_output)
          logging_output = {
              w: np.mean([a[w] for a in logging_output]) for w in logging_output[0]
          }

        batch_loss += logging_output['loss']
        batch_score += logging_output['score']
        ddp_timer.toc(False)

      # Update number of steps
      optimizer.step()
      scheduler.step()

      data_time_avg.update(data_time)
      iter_time_avg.update(iter_timer.toc(False))
      fw_time_avg.update(fw_timer.diff)
      bw_time_avg.update(bw_timer.diff)
      ddp_time_avg.update(ddp_timer.diff)

      losses.update(batch_loss, target.size(0))
      scores.update(batch_score, target.size(0))

      if curr_iter >= config.max_iter:
        is_training = False
        break

      if curr_iter % config.stat_freq == 0 or curr_iter == 1:
        lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
        debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
            epoch, curr_iter, len(data_loader) // config.iter_size, losses.avg, lrs)
        debug_str += "Score {:.3f}\tData time: {:.4f}, Forward time: {:.4f}, Backward time: {:.4f}, DDP time: {:.4f}, Total iter time: {:.4f}".format(
            scores.avg, data_time_avg.avg, fw_time_avg.avg, bw_time_avg.avg, ddp_time_avg.avg,
            iter_time_avg.avg)
        logging.info(debug_str)
        # Reset timers
        data_time_avg.reset()
        iter_time_avg.reset()
        # Write logs
        writer.add_scalar('training/loss', losses.avg, curr_iter)
        writer.add_scalar('training/precision_at_1', scores.avg, curr_iter)
        writer.add_scalar('training/learning_rate', scheduler.get_lr()[0], curr_iter)
        losses.reset()
        scores.reset()

      # Save current status, save before val to prevent occational mem overflow
      if curr_iter % config.save_freq == 0:
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)

      # Validation
      if curr_iter % config.val_freq == 0:
        val_miou = validate(model, val_data_loader, writer, curr_iter, config, transform_data_fn)
        if val_miou > best_val_miou:
          best_val_miou = val_miou
          best_val_iter = curr_iter
          checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                     "best_val")
        logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

        # Recover back
        model.train()

      if curr_iter % config.empty_cache_freq == 0:
        # Clear cache
        torch.cuda.empty_cache()

      # End of iteration
      curr_iter += 1

    epoch += 1

  # Explicit memory cleanup
  if hasattr(data_iter, 'cleanup'):
    data_iter.cleanup()

  # Save the final model
  checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
  val_miou = validate(model, val_data_loader, writer, curr_iter, config, transform_data_fn)
  if val_miou > best_val_miou:
    best_val_miou = val_miou
    best_val_iter = curr_iter
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
               "best_val")
  logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

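# --- Hedged reference sketch (not part of the original file) ------------------------------
# All training loops in this file share the same gradient-accumulation pattern:
# config.iter_size sub-iterations contribute to one optimizer step, so the effective batch
# size is batch_size * iter_size (times world size when distributed). The pattern, isolated;
# compute_loss is an illustrative placeholder, not a function from this repo.
def _accumulate_and_step_sketch(optimizer, scheduler, data_iter, compute_loss, iter_size):
  optimizer.zero_grad()
  for _ in range(iter_size):
    loss = compute_loss(data_iter.next())
    (loss / iter_size).backward()  # scale so accumulated gradients average over sub-iterations
  optimizer.step()
  scheduler.step()
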
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
  device = get_torch_device(config.is_cuda)
  # Set up the train flag for batch normalization
  model.train()

  # Configuration
  data_timer, iter_timer = Timer(), Timer()
  data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
  regs, losses, scores = AverageMeter(), AverageMeter(), AverageMeter()

  optimizer = initialize_optimizer(model.parameters(), config)
  scheduler = initialize_scheduler(optimizer, config)
  criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

  # Train the network
  logging.info('===> Start training')
  best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

  if config.resume:
    # Test loaded ckpt first
    v_loss, v_score, v_mAP, v_mIoU = test(model, val_data_loader, config)

    checkpoint_fn = config.resume + '/weights.pth'
    if osp.isfile(checkpoint_fn):
      logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
      state = torch.load(checkpoint_fn)
      curr_iter = state['iteration'] + 1
      epoch = state['epoch']
      # we skip attention maps because the shape won't match because voxel number is different
      # e.g. copying a param with shape (23385, 8, 4) to (43529, 8, 4)
      d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
      # handle those attn maps we don't load from saved dict
      for k in model.state_dict().keys():
        if k in d.keys():
          continue
        d[k] = model.state_dict()[k]
      model.load_state_dict(d)
      if config.resume_optimizer:
        scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
        optimizer.load_state_dict(state['optimizer'])
      if 'best_val' in state:
        best_val_miou = state['best_val']
        best_val_iter = state['best_val_iter']
      logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
    else:
      raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

  data_iter = data_loader.__iter__()
  if config.dataset == "SemanticKITTI":
    num_class = 19
    config.normalize_color = False
    config.xyz_input = False
    val_freq_ = config.val_freq
    config.val_freq = config.val_freq * 10
  elif config.dataset == "S3DIS":
    num_class = 13
    config.normalize_color = False
    config.xyz_input = False
    val_freq_ = config.val_freq
    config.val_freq = config.val_freq
  elif config.dataset == "Nuscenes":
    num_class = 16
    config.normalize_color = False
    config.xyz_input = False
    val_freq_ = config.val_freq
    config.val_freq = config.val_freq * 50
  else:
    num_class = 20
    val_freq_ = config.val_freq

  while is_training:
    total_correct_class = torch.zeros(num_class, device=device)
    total_iou_deno_class = torch.zeros(num_class, device=device)

    for iteration in range(len(data_loader) // config.iter_size):
      optimizer.zero_grad()
      data_time, batch_loss = 0, 0
      iter_timer.tic()

      if curr_iter >= config.max_iter:
        # if curr_iter >= max(config.max_iter, config.epochs*(len(data_loader) // config.iter_size):
        is_training = False
        break
      elif curr_iter >= config.max_iter * (2 / 3):
        config.val_freq = val_freq_ * 2  # valid more freq on lower half

      for sub_iter in range(config.iter_size):
        # Get training data
        data_timer.tic()
        pointcloud = None

        if config.return_transformation:
          coords, input, target, _, _, pointcloud, transformation, _ = data_iter.next()
        else:
          coords, input, target, _, _, _ = data_iter.next()  # ignore unique_map and inverse_map

        if config.use_aux:
          assert target.shape[1] == 2
          aux = target[:, 1]
          target = target[:, 0]
        else:
          aux = None

        # For some networks, making the network invariant to even, odd coords is important
        coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

        # Preprocess input
        if config.normalize_color:
          input[:, :3] = input[:, :3] / input[:, :3].max() - 0.5
          coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5

        # cat xyz into the rgb feature
        if config.xyz_input:
          input = torch.cat([coords_norm, input], dim=1)

        sinput = SparseTensor(input, coords, device=device)
        starget = SparseTensor(
            target.unsqueeze(-1).float(),
            coordinate_map_key=sinput.coordinate_map_key,
            coordinate_manager=sinput.coordinate_manager,
            device=device)  # must share the same coord-manager to align for sinput

        data_time += data_timer.toc(False)
        # model.initialize_coords(*init_args)

        # d = {}
        # d['c'] = sinput.C
        # d['l'] = starget.F
        # torch.save('./plot/test-label.pth')
        # import ipdb; ipdb.set_trace()

        # Set up profiler
        # memory_profiler = CUDAMemoryProfiler(
        #     [model, criterion],
        #     filename="cuda_memory.profile")
        # sys.settrace(memory_profiler)
        # threading.settrace(memory_profiler)

        # with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False, profile_memory=True) as prof0:
        if aux is not None:
          soutput = model(sinput, aux)
        elif config.enable_point_branch:
          soutput = model(sinput, iter_=curr_iter / config.max_iter, enable_point_branch=True)
        else:
          # label-aux, feed it in as additional reg
          soutput = model(sinput, iter_=curr_iter / config.max_iter,
                          aux=starget)  # feed in the progress of training for annealing inside the model

        # The output of the network is not sorted
        target = target.view(-1).long().to(device)
        loss = criterion(soutput.F, target.long())

        # ====== other loss regs =====
        cur_loss = torch.tensor([0.], device=device)  # default so regs.update below is always defined
        if hasattr(model, 'block1'):
          if hasattr(model.block1[0], 'vq_loss'):
            if model.block1[0].vq_loss is not None:
              cur_loss = torch.tensor([0.], device=device)
              for n, m in model.named_children():
                if 'block' in n:
                  cur_loss += m[0].vq_loss  # m is the nn.Sequential obj, m[0] is the TRBlock
              logging.info('Cur Loss: {}, Cur vq_loss: {}'.format(loss, cur_loss))
              loss += cur_loss

          if hasattr(model.block1[0], 'diverse_loss'):
            if model.block1[0].diverse_loss is not None:
              cur_loss = torch.tensor([0.], device=device)
              for n, m in model.named_children():
                if 'block' in n:
                  cur_loss += m[0].diverse_loss  # m is the nn.Sequential obj, m[0] is the TRBlock
              logging.info('Cur Loss: {}, Cur diverse _loss: {}'.format(loss, cur_loss))
              loss += cur_loss

          if hasattr(model.block1[0], 'label_reg'):
            if model.block1[0].label_reg is not None:
              cur_loss = torch.tensor([0.], device=device)
              for n, m in model.named_children():
                if 'block' in n:
                  cur_loss += m[0].label_reg  # m is the nn.Sequential obj, m[0] is the TRBlock
              # logging.info('Cur Loss: {}, Cur diverse _loss: {}'.format(loss, cur_loss))
              loss += cur_loss

        # Compute and accumulate gradient
        loss /= config.iter_size
        batch_loss += loss.item()
        loss.backward()
        # soutput = model(sinput)

      # Update number of steps
      if not config.use_sam:
        optimizer.step()
      else:
        optimizer.first_step(zero_grad=True)
        soutput = model(sinput, iter_=curr_iter / config.max_iter, aux=starget)
        criterion(soutput.F, target.long()).backward()
        optimizer.second_step(zero_grad=True)

      if config.lr_warmup is None:
        scheduler.step()
      else:
        if curr_iter >= config.lr_warmup:
          scheduler.step()
        else:
          for g in optimizer.param_groups:
            g['lr'] = config.lr * (iteration + 1) / config.lr_warmup

      # CLEAR CACHE!
      torch.cuda.empty_cache()

      data_time_avg.update(data_time)
      iter_time_avg.update(iter_timer.toc(False))

      pred = get_prediction(data_loader.dataset, soutput.F, target)
      score = precision_at_one(pred, target, ignore_label=-1)

      regs.update(cur_loss.item(), target.size(0))
      losses.update(batch_loss, target.size(0))
      scores.update(score, target.size(0))

      # calc the train-iou
      for l in range(num_class):
        total_correct_class[l] += ((pred == l) & (target == l)).sum()
        total_iou_deno_class[l] += (((pred == l) & (target != -1)) | (target == l)).sum()

      if curr_iter % config.stat_freq == 0 or curr_iter == 1:
        lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
        IoU = ((total_correct_class) / (total_iou_deno_class + 1e-6)).mean() * 100.
        debug_str = "[{}] ===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
            config.log_dir.split('/')[-2], epoch, curr_iter,
            len(data_loader) // config.iter_size, losses.avg, lrs)
        debug_str += "Score {:.3f}\tIoU {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
            scores.avg, IoU.item(), data_time_avg.avg, iter_time_avg.avg)
        if regs.avg > 0:
          debug_str += "\n Additional Reg Loss {:.3f}".format(regs.avg)
        # print(debug_str)
        logging.info(debug_str)
        # Reset timers
        data_time_avg.reset()
        iter_time_avg.reset()
        # Write logs
        losses.reset()
        scores.reset()

      # Save current status, save before val to prevent occational mem overflow
      if curr_iter % config.save_freq == 0:
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                   save_inter=True)

      # Validation
      if curr_iter % config.val_freq == 0:
        val_miou = validate(model, val_data_loader, None, curr_iter, config, transform_data_fn)
        if val_miou > best_val_miou:
          best_val_miou = val_miou
          best_val_iter = curr_iter
          checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                     "best_val", save_inter=True)
        logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
        # print("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

        # Recover back
        model.train()

      # End of iteration
      curr_iter += 1

    IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
    logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

    epoch += 1

  # Explicit memory cleanup
  if hasattr(data_iter, 'cleanup'):
    data_iter.cleanup()

  # Save the final model
  checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
  v_loss, v_score, v_mAP, val_miou = test(model, val_data_loader, config)
  if val_miou > best_val_miou:
    best_val_miou = val_miou
    best_val_iter = curr_iter
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
               "best_val")
  logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

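# --- Hedged reference sketch (not part of the original file) ------------------------------
# When config.use_sam is set, the loop above performs a sharpness-aware minimization (SAM)
# update. It assumes an optimizer exposing first_step/second_step, as in the common PyTorch
# SAM wrapper; the pattern is isolated here for clarity, and the real loop also passes
# aux=starget on both forward passes.
def _sam_update_sketch(model, optimizer, criterion, sinput, target, progress):
  loss = criterion(model(sinput, iter_=progress).F, target)
  loss.backward()
  optimizer.first_step(zero_grad=True)   # ascend to the adversarial weight perturbation
  criterion(model(sinput, iter_=progress).F, target).backward()
  optimizer.second_step(zero_grad=True)  # undo the perturbation and apply the base update
  return loss
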
def train_distill(model, data_loader, val_data_loader, config, transform_data_fn=None):
  '''
  the distillation training
  some cfgs here
  '''

  # distill_lambda = 1
  # distill_lambda = 0.33
  distill_lambda = 0.67

  # TWO_STAGE=True: Transformer is first trained with L2 loss to match ResNet's activation,
  # and then it finetunes like normal training on the second stage.
  # TWO_STAGE=False: Transformer trains with combined loss
  TWO_STAGE = False
  STAGE_PERCENTAGE = 0.7  # only used when TWO_STAGE is True

  device = get_torch_device(config.is_cuda)
  # Set up the train flag for batch normalization
  model.train()

  # Configuration
  data_timer, iter_timer = Timer(), Timer()
  data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
  losses, scores = AverageMeter(), AverageMeter()

  optimizer = initialize_optimizer(model.parameters(), config)
  scheduler = initialize_scheduler(optimizer, config)
  criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

  # Train the network
  logging.info('===> Start training')
  best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

  # TODO: load the sub-model only
  # FIXME: some dirty hard-written stuff, only supporting current state
  tch_model_cls = load_model('Res16UNet18A')
  tch_model = tch_model_cls(3, 20, config).to(device)

  # checkpoint_fn = "/home/zhaotianchen/project/point-transformer/SpatioTemporalSegmentation-ScanNet/outputs/ScannetSparseVoxelizationDataset/Res16UNet18A/resnet_base/weights.pth"
  checkpoint_fn = "/home/zhaotianchen/project/point-transformer/SpatioTemporalSegmentation-ScanNet/outputs/ScannetSparseVoxelizationDataset/Res16UNet18A/Res18A/weights.pth"  # voxel-size: 0.05
  assert osp.isfile(checkpoint_fn)
  logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
  state = torch.load(checkpoint_fn)
  d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
  tch_model.load_state_dict(d)
  if 'best_val' in state:
    best_val_miou = state['best_val']
    best_val_iter = state['best_val_iter']
  logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))

  if config.resume:
    raise NotImplementedError
    # Test loaded ckpt first
    # checkpoint_fn = config.resume + '/weights.pth'
    # if osp.isfile(checkpoint_fn):
    #   logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
    #   state = torch.load(checkpoint_fn)
    #   curr_iter = state['iteration'] + 1
    #   epoch = state['epoch']
    #   d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
    #   model.load_state_dict(d)
    #   if config.resume_optimizer:
    #     scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
    #     optimizer.load_state_dict(state['optimizer'])
    #   if 'best_val' in state:
    #     best_val_miou = state['best_val']
    #     best_val_iter = state['best_val_iter']
    #   logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
    # else:
    #   raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

  # test after loading the ckpt
  v_loss, v_score, v_mAP, v_mIoU = test(tch_model, val_data_loader, config)
  logging.info('Tch model tested, best_miou: {}'.format(v_mIoU))

  data_iter = data_loader.__iter__()
  while is_training:
    num_class = 20
    total_correct_class = torch.zeros(num_class, device=device)
    total_iou_deno_class = torch.zeros(num_class, device=device)

    total_iteration = len(data_loader) // config.iter_size
    for iteration in range(total_iteration):
      # NOTE: for single stage distillation, L2 loss might be too large at first,
      # so we added a warmup training that doesn't use the L2 loss
      if iteration < 0:
        use_distill = False
      else:
        use_distill = True

      # Stage 1 / Stage 2 boundary
      if TWO_STAGE:
        stage_boundary = int(total_iteration * STAGE_PERCENTAGE)

      optimizer.zero_grad()
      data_time, batch_loss = 0, 0
      iter_timer.tic()

      for sub_iter in range(config.iter_size):
        # Get training data
        data_timer.tic()
        if config.return_transformation:
          coords, input, target, _, _, pointcloud, transformation = data_iter.next()
        else:
          coords, input, target, _, _ = data_iter.next()  # ignore unique_map and inverse_map

        if config.use_aux:
          assert target.shape[1] == 2
          aux = target[:, 1]
          target = target[:, 0]
        else:
          aux = None

        # For some networks, making the network invariant to even, odd coords is important
        coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

        # Preprocess input
        if config.normalize_color:
          input[:, :3] = input[:, :3] / 255. - 0.5
          coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5

        # cat xyz into the rgb feature
        if config.xyz_input:
          input = torch.cat([coords_norm, input], dim=1)

        sinput = SparseTensor(input, coords, device=device)
        # TODO: return both models
        # in order to not break the valid interface, use a get_loss to get the registered loss

        data_time += data_timer.toc(False)
        # model.initialize_coords(*init_args)

        if aux is not None:
          raise NotImplementedError

        # flatten ground truth tensor
        target = target.view(-1).long().to(device)

        if TWO_STAGE:
          if iteration < stage_boundary:
            # Stage 1: train transformer on L2 loss
            soutput, anchor = model(sinput, save_anchor=True)
            # Make sure gradient don't flow to teacher model
            with torch.no_grad():
              _, tch_anchor = tch_model(sinput, save_anchor=True)
            loss = DistillLoss(tch_anchor, anchor)
          else:
            # Stage 2: finetune transformer on Cross-Entropy
            soutput = model(sinput)
            loss = criterion(soutput.F, target.long())
        else:
          if use_distill:  # after warm up
            soutput, anchor = model(sinput, save_anchor=True)
            # if pretrained teacher, do not let the grad flow to teacher to update its params
            with torch.no_grad():
              tch_soutput, tch_anchor = tch_model(sinput, save_anchor=True)
          else:  # warming up
            soutput = model(sinput)

          # The output of the network is not sorted
          loss = criterion(soutput.F, target.long())
          # Add L2 loss if use distillation
          if use_distill:
            distill_loss = DistillLoss(tch_anchor, anchor) * distill_lambda
            loss += distill_loss

        # Compute and accumulate gradient
        loss /= config.iter_size
        batch_loss += loss.item()
        loss.backward()

      # Update number of steps
      optimizer.step()
      scheduler.step()

      # CLEAR CACHE!
      torch.cuda.empty_cache()

      data_time_avg.update(data_time)
      iter_time_avg.update(iter_timer.toc(False))

      pred = get_prediction(data_loader.dataset, soutput.F, target)
      score = precision_at_one(pred, target, ignore_label=-1)
      losses.update(batch_loss, target.size(0))
      scores.update(score, target.size(0))

      # calc the train-iou
      for l in range(num_class):
        total_correct_class[l] += ((pred == l) & (target == l)).sum()
        total_iou_deno_class[l] += (((pred == l) & (target != -1)) | (target == l)).sum()

      if curr_iter >= config.max_iter:
        is_training = False
        break

      if curr_iter % config.stat_freq == 0 or curr_iter == 1:
        lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
        debug_str = "[{}] ===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
            config.log_dir, epoch, curr_iter, len(data_loader) // config.iter_size, losses.avg,
            lrs)
        debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
            scores.avg, data_time_avg.avg, iter_time_avg.avg)
        logging.info(debug_str)
        if use_distill and not TWO_STAGE:
          logging.info('Loss {} Distill Loss:{}'.format(loss, distill_loss))
        # Reset timers
        data_time_avg.reset()
        iter_time_avg.reset()
        losses.reset()
        scores.reset()

      # Save current status, save before val to prevent occational mem overflow
      if curr_iter % config.save_freq == 0:
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                   save_inter=True)

      # Validation
      if curr_iter % config.val_freq == 0:
        val_miou = validate(model, val_data_loader, None, curr_iter, config, transform_data_fn)
        if val_miou > best_val_miou:
          best_val_miou = val_miou
          best_val_iter = curr_iter
          checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                     "best_val", save_inter=True)
        logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

        # Recover back
        model.train()

      # End of iteration
      curr_iter += 1

    IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
    logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

    epoch += 1

  # Explicit memory cleanup
  if hasattr(data_iter, 'cleanup'):
    data_iter.cleanup()

  # Save the final model
  checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
  v_loss, v_score, v_mAP, val_miou = test(model, val_data_loader, config)
  if val_miou > best_val_miou:
    best_val_miou = val_miou
    best_val_iter = curr_iter
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
               "best_val")
  logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

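# --- Hedged reference sketch (not part of the original file) ------------------------------
# DistillLoss used in train_distill() is imported elsewhere in the repo; the docstring above
# describes it as an L2 match between student and teacher anchor activations. The stand-in
# below assumes the anchors are lists of tensors with matching shapes and is illustrative
# only, not the repo's actual definition.
def _distill_loss_sketch(tch_anchors, anchors):
  return sum(nn.functional.mse_loss(a, t.detach()) for a, t in zip(anchors, tch_anchors))
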
def train_point(model, data_loader, val_data_loader, config, transform_data_fn=None):
  device = get_torch_device(config.is_cuda)
  # Set up the train flag for batch normalization
  model.train()

  # Configuration
  data_timer, iter_timer = Timer(), Timer()
  data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
  losses, scores = AverageMeter(), AverageMeter()

  optimizer = initialize_optimizer(model.parameters(), config)
  scheduler = initialize_scheduler(optimizer, config)
  criterion = nn.CrossEntropyLoss(ignore_index=-1)

  # Train the network
  logging.info('===> Start training')
  best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

  if config.resume:
    checkpoint_fn = config.resume + '/weights.pth'
    if osp.isfile(checkpoint_fn):
      logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
      state = torch.load(checkpoint_fn)
      curr_iter = state['iteration'] + 1
      epoch = state['epoch']
      d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
      model.load_state_dict(d)
      if config.resume_optimizer:
        scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
        optimizer.load_state_dict(state['optimizer'])
      if 'best_val' in state:
        best_val_miou = state['best_val']
        best_val_iter = state['best_val_iter']
      logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
    else:
      raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

  data_iter = data_loader.__iter__()
  while is_training:
    num_class = 20
    total_correct_class = torch.zeros(num_class, device=device)
    total_iou_deno_class = torch.zeros(num_class, device=device)

    for iteration in range(len(data_loader) // config.iter_size):
      optimizer.zero_grad()
      data_time, batch_loss = 0, 0
      iter_timer.tic()

      for sub_iter in range(config.iter_size):
        # Get training data
        data = data_iter.next()
        points, target, sample_weight = data
        if config.pure_point:
          sinput = points.transpose(1, 2).cuda().float()

          # DEBUG: use the discrete coord for point-based
          '''
          feats = torch.unbind(points[:,:,:], dim=0)
          voxel_size = config.voxel_size
          coords = torch.unbind(points[:,:,:3]/voxel_size, dim=0)  # 0.05 is the voxel-size
          coords, feats = ME.utils.sparse_collate(coords, feats)
          # assert feats.reshape([16, 4096, -1]) == points[:,:,3:]
          points_ = ME.TensorField(features=feats.float(), coordinates=coords, device=device)
          tmp_voxel = points_.sparse()
          sinput_ = tmp_voxel.slice(points_)
          sinput = torch.cat([sinput_.C[:,1:]*config.voxel_size, sinput_.F[:,3:]],dim=1).reshape([config.batch_size, config.num_points, 6])
          # sinput = sinput_.F.reshape([config.batch_size, config.num_points, 6])
          sinput = sinput.transpose(1,2).cuda().float()

          # sinput = torch.cat([coords[:,1:], feats],dim=1).reshape([config.batch_size, config.num_points, 6])
          # sinput = sinput.transpose(1,2).cuda().float()
          '''

          # For some networks, making the network invariant to even, odd coords is important
          # coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

          # Preprocess input
          # if config.normalize_color:
          #   feats = feats / 255. - 0.5

          # torch.save(points[:,:,:3], './sandbox/tensorfield-c.pth')
          # torch.save(points_.C, './sandbox/points-c.pth')
        else:
          # feats = torch.unbind(points[:,:,3:], dim=0)  # WRONG: should also feed in xyz as input feature
          voxel_size = config.voxel_size
          coords = torch.unbind(points[:, :, :3] / voxel_size, dim=0)  # 0.05 is the voxel-size
          # Normalize the xyz in feature
          # points[:,:,:3] = points[:,:,:3] / points[:,:,:3].mean()
          feats = torch.unbind(points[:, :, :], dim=0)
          coords, feats = ME.utils.sparse_collate(coords, feats)

          # For some networks, making the network invariant to even, odd coords is important
          coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

          # Preprocess input
          # if config.normalize_color:
          #   feats = feats / 255. - 0.5
          # they are the same
          points_ = ME.TensorField(features=feats.float(), coordinates=coords, device=device)
          # points_1 = ME.TensorField(features=feats.float(), coordinates=coords, device=device, quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
          # points_2 = ME.TensorField(features=feats.float(), coordinates=coords, device=device, quantization_mode=ME.SparseTensorQuantizationMode.RANDOM_SUBSAMPLE)
          sinput = points_.sparse()

        data_time += data_timer.toc(False)
        B, npoint = target.shape

        # model.initialize_coords(*init_args)
        soutput = model(sinput)
        if config.pure_point:
          soutput = soutput.reshape([B * npoint, -1])
        else:
          soutput = soutput.slice(points_).F
          # s1 = soutput.slice(points_)
          # print(soutput.quantization_mode)
          # soutput.quantization_mode = ME.SparseTensorQuantizationMode.RANDOM_SUBSAMPLE
          # s2 = soutput.slice(points_)

        # The output of the network is not sorted
        target = (target - 1).view(-1).long().to(device)

        # catch NAN
        if torch.isnan(soutput).sum() > 0:
          import ipdb; ipdb.set_trace()

        loss = criterion(soutput, target)

        if torch.isnan(loss).sum() > 0:
          import ipdb; ipdb.set_trace()

        loss = (loss * sample_weight.to(device)).mean()

        # Compute and accumulate gradient
        loss /= config.iter_size
        batch_loss += loss.item()
        loss.backward()

        # print(model.input_mlp[0].weight.max())
        # print(model.input_mlp[0].weight.grad.max())

      # Update number of steps
      optimizer.step()
      scheduler.step()

      # CLEAR CACHE!
      torch.cuda.empty_cache()

      data_time_avg.update(data_time)
      iter_time_avg.update(iter_timer.toc(False))

      pred = get_prediction(data_loader.dataset, soutput, target)
      score = precision_at_one(pred, target, ignore_label=-1)
      losses.update(batch_loss, target.size(0))
      scores.update(score, target.size(0))

      # Calc the iou
      for l in range(num_class):
        total_correct_class[l] += ((pred == l) & (target == l)).sum()
        total_iou_deno_class[l] += (((pred == l) & (target >= 0)) | (target == l)).sum()

      if curr_iter >= config.max_iter:
        is_training = False
        break

      if curr_iter % config.stat_freq == 0 or curr_iter == 1:
        lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
        debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
            epoch, curr_iter, len(data_loader) // config.iter_size, losses.avg, lrs)
        debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
            scores.avg, data_time_avg.avg, iter_time_avg.avg)
        logging.info(debug_str)
        # Reset timers
        data_time_avg.reset()
        iter_time_avg.reset()
        # Write logs
        losses.reset()
        scores.reset()

      # Save current status, save before val to prevent occational mem overflow
      if curr_iter % config.save_freq == 0:
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                   save_inter=True)

      # Validation:
      # for point-based should use alternate dataloader for eval
      # if curr_iter % config.val_freq == 0:
      #   val_miou = test_points(model, val_data_loader, None, curr_iter, config, transform_data_fn)
      #   if val_miou > best_val_miou:
      #     best_val_miou = val_miou
      #     best_val_iter = curr_iter
      #     checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
      #                "best_val")
      #   logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
      #   # Recover back
      #   model.train()

      # End of iteration
      curr_iter += 1

    IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
    logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

    epoch += 1

  # Explicit memory cleanup
  if hasattr(data_iter, 'cleanup'):
    data_iter.cleanup()

  # Save the final model
  checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
  val_miou = test_points(model, val_data_loader, config)  # assumes test_points returns the final mIoU
  if val_miou > best_val_miou:
    best_val_miou = val_miou
    best_val_iter = curr_iter
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
               "best_val")
  logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

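# --- Hedged reference sketch (not part of the original file) ------------------------------
# The running train-IoU bookkeeping shared by train(), train_point() and train_worker(),
# isolated for clarity. pred and target are 1-D integer tensors on the same device and labels
# below 0 are ignored; per-class IoU is correct / (denominator + 1e-6) and the logged value is
# its mean * 100.
def _update_iou_counts_sketch(pred, target, total_correct_class, total_iou_deno_class, num_class):
  for l in range(num_class):
    total_correct_class[l] += ((pred == l) & (target == l)).sum()
    total_iou_deno_class[l] += (((pred == l) & (target >= 0)) | (target == l)).sum()
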
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
  all_losses = []
  device = get_torch_device(config.is_cuda)
  # Set up the train flag for batch normalization
  model.train()

  # Configuration
  writer = SummaryWriter(log_dir=config.log_dir)
  data_timer, iter_timer = Timer(), Timer()
  data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
  losses, scores, batch_losses = AverageMeter(), AverageMeter(), {}

  optimizer = initialize_optimizer(model.parameters(), config)
  scheduler = initialize_scheduler(optimizer, config)
  criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
  alpha, gamma, eps = 1, 2, 1e-6

  writer = SummaryWriter(log_dir=config.log_dir)

  # Train the network
  logging.info('===> Start training')
  best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

  if config.resume:
    checkpoint_fn = config.resume + '/weights.pth'
    if osp.isfile(checkpoint_fn):
      logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
      state = torch.load(checkpoint_fn)
      curr_iter = state['iteration'] + 1
      epoch = state['epoch']
      model.load_state_dict(state['state_dict'])
      if config.resume_optimizer:
        scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
        optimizer.load_state_dict(state['optimizer'])
      if 'best_val' in state:
        best_val_miou = state['best_val']
        best_val_iter = state['best_val_iter']
      logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
    else:
      raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

  data_iter = data_loader.__iter__()
  while is_training:
    print("********************************** epoch N° {0} ************************".format(epoch))
    for iteration in range(len(data_loader) // config.iter_size):
      print("####### Iteration N° {0}".format(iteration))
      optimizer.zero_grad()
      data_time, batch_loss = 0, 0
      iter_timer.tic()

      for sub_iter in range(config.iter_size):
        print("------------- Sub_iteration N° {0}".format(sub_iter))
        # Get training data
        data_timer.tic()
        coords, input, target = data_iter.next()
        print("len of coords : {0}".format(len(coords)))

        # For some networks, making the network invariant to even, odd coords is important
        coords[:, :3] += (torch.rand(3) * 100).type_as(coords)

        # Preprocess input
        color = input[:, :3].int()
        if config.normalize_color:
          input[:, :3] = input[:, :3] / 255. - 0.5
        sinput = SparseTensor(input, coords).to(device)

        data_time += data_timer.toc(False)

        # Feed forward
        inputs = (sinput,) if config.wrapper_type == 'None' else (sinput, coords, color)
        # model.initialize_coords(*init_args)
        soutput = model(*inputs)
        # The output of the network is not sorted
        target = target.long().to(device)
        print("count of classes : {0}".format(np.unique(target.cpu().numpy(), return_counts=True)))
        print("target : {0}\ntarget_len : {1}".format(target, len(target)))
        print("target [0]: {0}".format(target[0]))

        input_soft = nn.functional.softmax(soutput.F, dim=1) + eps
        print("input_soft[0] : {0}".format(input_soft[0]))
        focal_weight = torch.pow(-input_soft + 1., gamma)
        print("focal_weight : {0}\nweight[0] : {1}".format(focal_weight, focal_weight[0]))
        focal_loss = (-alpha * focal_weight * torch.log(input_soft)).mean()

        loss = criterion(soutput.F, target.long())
        print("focal_loss :{0}\nloss : {1}".format(focal_loss, loss))

        # Compute and accumulate gradient
        loss /= config.iter_size
        # batch_loss += loss
        batch_loss += loss.item()
        print("batch_loss : {0}".format(batch_loss))
        loss.backward()

      # Update number of steps
      optimizer.step()
      scheduler.step()

      data_time_avg.update(data_time)
      iter_time_avg.update(iter_timer.toc(False))

      pred = get_prediction(data_loader.dataset, soutput.F, target)
      score = precision_at_one(pred, target)
      losses.update(batch_loss, target.size(0))
      scores.update(score, target.size(0))

      if curr_iter >= config.max_iter:
        is_training = False
        break

      if curr_iter % config.stat_freq == 0 or curr_iter == 1:
        lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
        debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
            epoch, curr_iter, len(data_loader) // config.iter_size, losses.avg, lrs)
        debug_str += "Score {:.3f}\tData time: {:.4f}, Total iter time: {:.4f}".format(
            scores.avg, data_time_avg.avg, iter_time_avg.avg)
        logging.info(debug_str)
        # Reset timers
        data_time_avg.reset()
        iter_time_avg.reset()
        # Write logs
        writer.add_scalar('training/loss', losses.avg, curr_iter)
        writer.add_scalar('training/precision_at_1', scores.avg, curr_iter)
        writer.add_scalar('training/learning_rate', scheduler.get_lr()[0], curr_iter)
        losses.reset()
        scores.reset()

      # Save current status, save before val to prevent occational mem overflow
      if curr_iter % config.save_freq == 0:
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)

      # Validation
      if curr_iter % config.val_freq == 0:
        val_miou, val_losses = validate(model, val_data_loader, writer, curr_iter, config,
                                        transform_data_fn, epoch)
        if val_miou > best_val_miou:
          best_val_miou = val_miou
          best_val_iter = curr_iter
          checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                     "best_val")
        logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

        # Recover back
        model.train()

      if curr_iter % config.empty_cache_freq == 0:
        # Clear cache
        torch.cuda.empty_cache()

      batch_losses[epoch] = batch_loss

      # End of iteration
      curr_iter += 1

    with open(config.log_dir + "/train_loss.txt", 'a') as train_loss_log:
      train_loss_log.writelines('{0}, {1}\n'.format(batch_losses[epoch], epoch))
      train_loss_log.close()

    epoch += 1

  # Explicit memory cleanup
  if hasattr(data_iter, 'cleanup'):
    data_iter.cleanup()

  # Save the final model
  checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
  val_miou = validate(model, val_data_loader, writer, curr_iter, config, transform_data_fn,
                      epoch)[0]
  if val_miou > best_val_miou:
    best_val_miou = val_miou
    best_val_iter = curr_iter
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
               "best_val")
  logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

def train_worker(gpu, num_devices, NetClass, data_loader, val_data_loader, config, transform_data_fn=None): if gpu is not None: print("Use GPU: {} for training".format(gpu)) rank = gpu addr = 23491 dist.init_process_group(backend="nccl", init_method="tcp://127.0.0.1:{}".format(addr), world_size=num_devices, rank=rank) # replace with DistributedSampler if config.multiprocess: from lib.dataloader_dist import InfSampler sampler = InfSampler(data_loader.dataset) data_loader = DataLoader(dataset=data_loader.dataset, num_workers=data_loader.num_workers, batch_size=data_loader.batch_size, collate_fn=data_loader.collate_fn, worker_init_fn=data_loader.worker_init_fn, sampler=sampler) if data_loader.dataset.NUM_IN_CHANNEL is not None: num_in_channel = data_loader.dataset.NUM_IN_CHANNEL else: num_in_channel = 3 num_labels = data_loader.dataset.NUM_LABELS # load model if config.pure_point: model = NetClass(num_class=config.num_labels, N=config.num_points, normal_channel=config.num_in_channel) else: if config.model == 'MixedTransformer': model = NetClass(config, num_class=num_labels, N=config.num_points, normal_channel=num_in_channel) elif config.model == 'MinkowskiVoxelTransformer': model = NetClass(config, num_in_channel, num_labels) elif config.model == 'MinkowskiTransformerNet': model = NetClass(config, num_in_channel, num_labels) elif "Res" in config.model: model = NetClass(num_in_channel, num_labels, config) else: model = NetClass(num_in_channel, num_labels, config) if config.weights == 'modelzoo': model.preload_modelzoo() elif config.weights.lower() != 'none': state = torch.load(config.weights) # delete the keys containing the attn since it raises size mismatch d = {k: v for k, v in state['state' '_dict'].items() if 'map' not in k} if config.weights_for_inner_model: model.model.load_state_dict(d) else: if config.lenient_weight_loading: matched_weights = load_state_with_same_shape( model, state['state_dict']) model_dict = model.state_dict() model_dict.update(matched_weights) model.load_state_dict(model_dict) else: model.load_state_dict(d, strict=False) torch.cuda.set_device(gpu) model.cuda(gpu) # use model with DDP model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[gpu], find_unused_parameters=False) # Synchronized batch norm model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model) # Set up the train flag for batch normalization model.train() # Configuration data_timer, iter_timer = Timer(), Timer() data_time_avg, iter_time_avg = AverageMeter(), AverageMeter() regs, losses, scores = AverageMeter(), AverageMeter(), AverageMeter() optimizer = initialize_optimizer(model.parameters(), config) scheduler = initialize_scheduler(optimizer, config) criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label) # Train the network if rank == 0: setup_logger(config) logging.info('===> Start training') best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True if config.resume: # Test loaded ckpt first v_loss, v_score, v_mAP, v_mIoU = test(model, val_data_loader, config) checkpoint_fn = config.resume + '/weights.pth' if osp.isfile(checkpoint_fn): logging.info("=> loading checkpoint '{}'".format(checkpoint_fn)) state = torch.load(checkpoint_fn) curr_iter = state['iteration'] + 1 epoch = state['epoch'] # we skip attention maps because the shape won't match because voxel number is different # e.g. 
copyting a param with shape (23385, 8, 4) to (43529, 8, 4) d = { k: v for k, v in state['state_dict'].items() if 'map' not in k } # handle those attn maps we don't load from saved dict for k in model.state_dict().keys(): if k in d.keys(): continue d[k] = model.state_dict()[k] model.load_state_dict(d) if config.resume_optimizer: scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter) optimizer.load_state_dict(state['optimizer']) if 'best_val' in state: best_val_miou = state['best_val'] best_val_iter = state['best_val_iter'] logging.info("=> loaded checkpoint '{}' (epoch {})".format( checkpoint_fn, state['epoch'])) else: raise ValueError( "=> no checkpoint found at '{}'".format(checkpoint_fn)) data_iter = data_loader.__iter__() device = gpu # multitrain fed in the device if config.dataset == "SemanticKITTI": num_class = 19 config.normalize_color = False config.xyz_input = False val_freq_ = config.val_freq config.val_freq = config.val_freq * 10 # origianl val_freq_ elif config.dataset == 'S3DIS': num_class = 13 config.normalize_color = False config.xyz_input = False val_freq_ = config.val_freq elif config.dataset == "Nuscenes": num_class = 16 config.normalize_color = False config.xyz_input = False val_freq_ = config.val_freq config.val_freq = config.val_freq * 50 else: val_freq_ = config.val_freq num_class = 20 while is_training: total_correct_class = torch.zeros(num_class, device=device) total_iou_deno_class = torch.zeros(num_class, device=device) for iteration in range(len(data_loader) // config.iter_size): optimizer.zero_grad() data_time, batch_loss = 0, 0 iter_timer.tic() if curr_iter >= config.max_iter: # if curr_iter >= max(config.max_iter, config.epochs*(len(data_loader) // config.iter_size): is_training = False break elif curr_iter >= config.max_iter * (2 / 3): config.val_freq = val_freq_ * 2 # valid more freq on lower half for sub_iter in range(config.iter_size): # Get training data data_timer.tic() if config.return_transformation: coords, input, target, _, _, pointcloud, transformation = data_iter.next( ) else: coords, input, target, _, _ = data_iter.next( ) # ignore unique_map and inverse_map if config.use_aux: assert target.shape[1] == 2 aux = target[:, 1] target = target[:, 0] else: aux = None # For some networks, making the network invariant to even, odd coords is important coords[:, 1:] += (torch.rand(3) * 100).type_as(coords) # Preprocess input if config.normalize_color: input[:, :3] = input[:, :3] / input[:, :3].max() - 0.5 coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5 # cat xyz into the rgb feature if config.xyz_input: input = torch.cat([coords_norm, input], dim=1) # print(device) sinput = SparseTensor(input, coords, device=device) # d = {} # d['coord'] = sinput.C # d['feat'] = sinput.F # torch.save(d, 'voxel.pth') # import ipdb; ipdb.set_trace() data_time += data_timer.toc(False) # model.initialize_coords(*init_args) if aux is not None: soutput = model(sinput, aux) elif config.enable_point_branch: soutput = model(sinput, iter_=curr_iter / config.max_iter, enable_point_branch=True) else: soutput = model( sinput, iter_=curr_iter / config.max_iter ) # feed in the progress of training for annealing inside the model # soutput = model(sinput) # The output of the network is not sorted target = target.view(-1).long().to(device) loss = criterion(soutput.F, target.long()) # ====== other loss regs ===== cur_loss = torch.tensor([0.], device=device) if hasattr(model, 'module.block1'): cur_loss = torch.tensor([0.], device=device) if 
hasattr(model.module.block1[0], 'vq_loss'): if model.block1[0].vq_loss is not None: cur_loss = torch.tensor([0.], device=device) for n, m in model.named_children(): if 'block' in n: cur_loss += m[ 0].vq_loss # m is the nn.Sequential obj, m[0] is the TRBlock logging.info( 'Cur Loss: {}, Cur vq_loss: {}'.format( loss, cur_loss)) loss += cur_loss if hasattr(model.module.block1[0], 'diverse_loss'): if model.block1[0].diverse_loss is not None: cur_loss = torch.tensor([0.], device=device) for n, m in model.named_children(): if 'block' in n: cur_loss += m[ 0].diverse_loss # m is the nn.Sequential obj, m[0] is the TRBlock logging.info( 'Cur Loss: {}, Cur diverse _loss: {}'.format( loss, cur_loss)) loss += cur_loss if hasattr(model.module.block1[0], 'label_reg'): if model.block1[0].label_reg is not None: cur_loss = torch.tensor([0.], device=device) for n, m in model.named_children(): if 'block' in n: cur_loss += m[ 0].label_reg # m is the nn.Sequential obj, m[0] is the TRBlock # logging.info('Cur Loss: {}, Cur diverse _loss: {}'.format(loss, cur_loss)) loss += cur_loss # Compute and accumulate gradient loss /= config.iter_size batch_loss += loss.item() if not config.use_sam: loss.backward() else: with model.no_sync(): loss.backward() # Update number of steps if not config.use_sam: optimizer.step() else: optimizer.first_step(zero_grad=True) soutput = model(sinput, iter_=curr_iter / config.max_iter, aux=starget) criterion(soutput.F, target.long()).backward() optimizer.second_step(zero_grad=True) if config.lr_warmup is None: scheduler.step() else: if curr_iter >= config.lr_warmup: scheduler.step() else: for g in optimizer.param_groups: g['lr'] = config.lr * (iteration + 1) / config.lr_warmup # CLEAR CACHE! torch.cuda.empty_cache() data_time_avg.update(data_time) iter_time_avg.update(iter_timer.toc(False)) pred = get_prediction(data_loader.dataset, soutput.F, target) score = precision_at_one(pred, target, ignore_label=-1) regs.update(cur_loss.item(), target.size(0)) losses.update(batch_loss, target.size(0)) scores.update(score, target.size(0)) # calc the train-iou for l in range(num_class): total_correct_class[l] += ((pred == l) & (target == l)).sum() total_iou_deno_class[l] += (((pred == l) & (target != -1)) | (target == l)).sum() if curr_iter % config.stat_freq == 0 or curr_iter == 1: lrs = ', '.join( ['{:.3e}'.format(g['lr']) for g in optimizer.param_groups]) IoU = ((total_correct_class) / (total_iou_deno_class + 1e-6)).mean() * 100. 
        debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
            epoch, curr_iter, len(data_loader) // config.iter_size, losses.avg, lrs)
        debug_str += "Score {:.3f}\tIoU {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
            scores.avg, IoU.item(), data_time_avg.avg, iter_time_avg.avg)
        if regs.avg > 0:
          debug_str += "\n Additional Reg Loss {:.3f}".format(regs.avg)
        if rank == 0:
          logging.info(debug_str)

        # Reset timers
        data_time_avg.reset()
        iter_time_avg.reset()
        # Write logs
        losses.reset()
        scores.reset()

      # only save status on the 1st gpu
      if rank == 0:
        # Save current status; save before val to prevent occasional mem overflow
        if curr_iter % config.save_freq == 0:
          checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                     save_inter=True)

        # Validation
        if curr_iter % config.val_freq == 0:
          # feed in None for the SummaryWriter arg
          val_miou = validate(model, val_data_loader, None, curr_iter, config, transform_data_fn)
          if val_miou > best_val_miou:
            best_val_miou = val_miou
            best_val_iter = curr_iter
            checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                       "best_val", save_inter=True)
          logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

          # Recover back
          model.train()

      # End of iteration
      curr_iter += 1

    IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
    if rank == 0:
      logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

    epoch += 1

  # Explicit memory cleanup
  if hasattr(data_iter, 'cleanup'):
    data_iter.cleanup()

  # Save the final model
  if rank == 0:
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
    v_loss, v_score, v_mAP, val_mIoU = test(model, val_data_loader, config)
    if val_mIoU > best_val_miou:  # the original compared the stale in-loop `val_miou` here
      best_val_miou = val_mIoU
      best_val_iter = curr_iter
      logging.info("Final best mIoU: {} at iter {}".format(val_mIoU, curr_iter))
      checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                 "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
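
# ---------------------------------------------------------------------------
# Illustration only (never called): the `config.use_sam` branch above assumes a
# SAM-style optimizer exposing `first_step` / `second_step`, i.e. two
# forward/backward passes per update. The sketch below shows that pattern in
# isolation; `sam_optimizer` is a hypothetical stand-in for such a wrapper and
# is not defined in this repo.
def _sam_update_sketch(model, sam_optimizer, criterion, sinput, target):
  """One sharpness-aware update, mirroring the use_sam branch above."""
  loss = criterion(model(sinput).F, target.long())
  loss.backward()
  sam_optimizer.first_step(zero_grad=True)   # climb to the locally worst nearby weights
  criterion(model(sinput).F, target.long()).backward()  # gradient at the perturbed point
  sam_optimizer.second_step(zero_grad=True)  # restore weights and apply the real step
  return loss.item()

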
def test(model, data_loader, config, transform_data_fn=None, has_gt=True, save_pred=False,
         split=None, submit_dir=None):
  device = get_torch_device(config.is_cuda)
  dataset = data_loader.dataset
  num_labels = dataset.NUM_LABELS
  global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
  criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
  losses, scores, ious = AverageMeter(), AverageMeter(), 0
  aps = np.zeros((0, num_labels))
  hist = np.zeros((num_labels, num_labels))

  # some cfgs concerning the usage of instance-level information
  config.save_pred = save_pred
  if split is not None:
    assert save_pred
  if config.save_pred:
    save_dict = {}
    save_dict['pred'] = []
    save_dict['coord'] = []

  logging.info('===> Start testing')

  global_timer.tic()
  data_iter = data_loader.__iter__()
  max_iter = len(data_loader)
  max_iter_unique = max_iter

  # Fix batch normalization running mean and std
  model.eval()

  # Clear cache (when run in val mode, cleanup training cache)
  torch.cuda.empty_cache()

  # SemanticKITTI label inverse mapping
  if config.submit:
    remap_lut = Remap().getRemapLUT()

  with torch.no_grad():
    # running sums for the iou
    total_correct = np.zeros(num_labels)
    total_seen = np.zeros(num_labels)
    total_positive = np.zeros(num_labels)
    point_nums = np.zeros([19])

    for iteration in range(max_iter):
      data_timer.tic()
      if config.return_transformation:
        coords, input, target, unique_map_list, inverse_map_list, pointcloud, transformation, filename = data_iter.next()
      else:
        coords, input, target, unique_map_list, inverse_map_list, filename = data_iter.next()
      data_time = data_timer.toc(False)

      if config.use_aux:
        assert target.shape[1] == 2
        aux = target[:, 1]
        target = target[:, 0]
      else:
        aux = None

      # Preprocess input
      iter_timer.tic()
      if config.normalize_color:
        input[:, :3] = input[:, :3] / input[:, :3].max() - 0.5
        coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5

      # cat xyz into the rgb feature
      if config.xyz_input:
        input = torch.cat([coords_norm, input], dim=1)

      sinput = ME.SparseTensor(input, coords, device=device)

      # Feed forward
      if aux is not None:
        soutput = model(sinput)
      else:
        soutput = model(sinput, iter_=iteration / max_iter,
                        enable_point_branch=config.enable_point_branch)
      output = soutput.F
      if torch.isnan(output).sum() > 0:
        import ipdb; ipdb.set_trace()

      pred = get_prediction(dataset, output, target).int()
      assert sum([int(t.shape[0]) for t in unique_map_list]) == len(pred), \
          "number of points in unique_map doesn't match prediction, do not enable preprocessing"
      iter_time = iter_timer.toc(False)

      if config.save_pred or config.submit:
        # troublesome processing for splitting each batch's data, and export
        batch_ids = sinput.C[:, 0]
        splits_at = torch.stack(
            [torch.where(batch_ids == i)[0][-1] for i in torch.unique(batch_ids)]).int()
        splits_at = splits_at + 1
        splits_at_leftshift_one = splits_at.roll(shifts=1)
        splits_at_leftshift_one[0] = 0
        # len_per_batch = splits_at - splits_at_leftshift_one
        len_sum = 0
        batch_id = 0
        for start, end in zip(splits_at_leftshift_one, splits_at):
          len_sum += len(pred[int(start):int(end)])
          pred_this_batch = pred[int(start):int(end)]
          coord_this_batch = coords[int(start):int(end)]  # the original sliced `pred` here by mistake
          if config.save_pred:
            save_dict['pred'].append(pred_this_batch[inverse_map_list[batch_id]])
          else:
            # save submit result
            submission_path = filename[batch_id].replace(config.semantic_kitti_path, submit_dir).replace('velodyne', 'predictions').replace('.bin', '.label')
            parent_dir = Path(submission_path).parent.absolute()
            if not os.path.exists(parent_dir):
              os.makedirs(parent_dir)
            label_pred = pred_this_batch[inverse_map_list[batch_id]].cpu().numpy()
            label_pred = remap_lut[label_pred].astype(np.uint32)
            label_pred.tofile(submission_path)
            print(submission_path)
          batch_id += 1
        assert len_sum == len(pred)

      # Unpack it to original length
      REVERT_WHOLE_POINTCLOUD = True
      if REVERT_WHOLE_POINTCLOUD:
        whole_pred = []
        whole_target = []
        for batch_ in range(config.batch_size):
          batch_mask_ = (soutput.C[:, 0] == batch_).cpu().numpy()
          if batch_mask_.sum() == 0:  # skip empty batches
            continue
          try:
            whole_pred_ = soutput.F[batch_mask_][inverse_map_list[batch_]]
          except:
            import ipdb; ipdb.set_trace()
          whole_target_ = target[batch_mask_][inverse_map_list[batch_]]
          whole_pred.append(whole_pred_)
          whole_target.append(whole_target_)
        whole_pred = torch.cat(whole_pred, dim=0)
        whole_target = torch.cat(whole_target, dim=0)
        pred = get_prediction(dataset, whole_pred, whole_target).int()
        output = whole_pred
        target = whole_target

      if has_gt:
        target_np = target.numpy()
        num_sample = target_np.shape[0]
        target = target.to(device)
        output = output.to(device)
        cross_ent = criterion(output, target.long())
        losses.update(float(cross_ent), num_sample)
        scores.update(precision_at_one(pred, target), num_sample)
        # within fast_hist, labels must be >= 0 and < num_labels to filter out 255 / -1
        hist += fast_hist(pred.cpu().numpy().flatten(), target_np.flatten(), num_labels)
        ious = per_class_iu(hist) * 100

        prob = torch.nn.functional.softmax(output, dim=-1)
        pred = pred[target != -1]
        target = target[target != -1]

        # debug for SemanticKITTI: spvnas way of calculating the miou
        # for _ in range(num_labels):
        #   total_seen[_] += torch.sum(target == _)
        #   total_correct[_] += torch.sum((pred == target) & (target == _))
        #   total_positive[_] += torch.sum(pred == _)
        # ious_ = []
        # for _ in range(num_labels):
        #   if total_seen[_] == 0:
        #     ious_.append(1)
        #   else:
        #     ious_.append(total_correct[_] / (total_seen[_] + total_positive[_] - total_correct[_]))
        # ious_ = torch.stack(ious_, dim=-1).cpu().numpy() * 100
        # print(np.nanmean(per_class_iu(hist)), np.nanmean(ious_))
        # ious = np.array(ious_) * 100

        # calc the ratio of total points
        # for i_ in range(19):
        #   point_nums[i_] += (target == i_).sum().detach()

        ap = average_precision(prob.cpu().detach().numpy(), target_np)
        aps = np.vstack((aps, ap))
        # Due to heavy bias in class, there exists class with no test label at all
        with warnings.catch_warnings():
          warnings.simplefilter("ignore", category=RuntimeWarning)
          ap_class = np.nanmean(aps, 0) * 100.
      if iteration % config.test_stat_freq == 0 and iteration > 0 and not config.submit:
        reordered_ious = dataset.reorder_result(ious)
        reordered_ap_class = dataset.reorder_result(ap_class)
        # dirty fix: SemanticKITTI provides no get_classnames
        if hasattr(dataset, 'get_classnames'):
          class_names = dataset.get_classnames()
        else:  # SemanticKITTI
          class_names = None
        print_info(
            iteration,
            max_iter_unique,
            data_time,
            iter_time,
            has_gt,
            losses,
            scores,
            reordered_ious,
            hist,
            reordered_ap_class,
            class_names=class_names)

      if iteration % 5 == 0:
        # Clear cache
        torch.cuda.empty_cache()

  if config.save_pred:
    # torch.save(save_dict, os.path.join(config.log_dir, 'preds_{}_with_coord.pth'.format(split)))
    torch.save(save_dict, os.path.join(config.log_dir, 'preds_{}.pth'.format(split)))
    print("===> saved prediction result")

  global_time = global_timer.toc(False)

  save_map(model, config)

  reordered_ious = dataset.reorder_result(ious)
  reordered_ap_class = dataset.reorder_result(ap_class)
  if hasattr(dataset, 'get_classnames'):
    class_names = dataset.get_classnames()
  else:
    class_names = None
  print_info(
      iteration,
      max_iter_unique,
      data_time,
      iter_time,
      has_gt,
      losses,
      scores,
      reordered_ious,
      hist,
      reordered_ap_class,
      class_names=class_names)

  logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))

  # Explicit memory cleanup
  if hasattr(data_iter, 'cleanup'):
    data_iter.cleanup()

  return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(per_class_iu(hist)) * 100
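
# ---------------------------------------------------------------------------
# Illustration only: `fast_hist` / `per_class_iu` above accumulate a
# num_labels x num_labels confusion matrix and read per-class IoU off it.
# A self-contained sketch of the same computation (local names only; relies on
# the module-level `import numpy as np`):
def _confusion_hist_sketch(pred, target, num_labels):
  mask = (target >= 0) & (target < num_labels)  # drop ignore labels such as 255 / -1
  return np.bincount(
      num_labels * target[mask].astype(int) + pred[mask],
      minlength=num_labels**2).reshape(num_labels, num_labels)


def _per_class_iou_sketch(hist):
  # diagonal = true positives; row/column sums = ground-truth / predicted counts per class
  return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-12)

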
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
  device = get_torch_device(config.is_cuda)
  # Set up the train flag for batch normalization
  model.train()

  # Configuration
  writer = SummaryWriter(log_dir=config.log_dir)
  data_timer, iter_timer = Timer(), Timer()
  data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
  losses, scores = AverageMeter(), AverageMeter()

  optimizer = initialize_optimizer(model.parameters(), config)
  scheduler = initialize_scheduler(optimizer, config)
  criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

  # Train the network
  logging.info('===> Start training')
  best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

  if config.resume:
    checkpoint_fn = config.resume + '/weights.pth'
    if osp.isfile(checkpoint_fn):
      logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
      state = torch.load(checkpoint_fn)
      curr_iter = state['iteration'] + 1
      epoch = state['epoch']
      model.load_state_dict(state['state_dict'])
      if config.resume_optimizer:
        scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
        optimizer.load_state_dict(state['optimizer'])
      if 'best_val' in state:
        best_val_miou = state['best_val']
        best_val_iter = state['best_val_iter']
      logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
    else:
      raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

  data_iter = data_loader.__iter__()
  while is_training:
    for iteration in range(len(data_loader) // config.iter_size):
      optimizer.zero_grad()
      data_time, batch_loss = 0, 0
      iter_timer.tic()

      for sub_iter in range(config.iter_size):
        # Get training data
        data_timer.tic()
        if config.return_transformation:
          coords, input, target, pointcloud, transformation = data_iter.next()
        else:
          coords, input, target = data_iter.next()

        # For some networks, making the network invariant to even, odd coords is important
        coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

        # Preprocess input
        if config.normalize_color:
          input[:, :3] = input[:, :3] / 255. - 0.5
        sinput = SparseTensor(input, coords).to(device)

        data_time += data_timer.toc(False)

        # model.initialize_coords(*init_args)
        soutput = model(sinput)
        # The output of the network is not sorted
        target = target.long().to(device)

        loss = criterion(soutput.F, target.long())

        # Compute and accumulate gradient
        loss /= config.iter_size
        batch_loss += loss.item()
        loss.backward()

      # Update number of steps
      optimizer.step()
      scheduler.step()

      data_time_avg.update(data_time)
      iter_time_avg.update(iter_timer.toc(False))

      pred = get_prediction(data_loader.dataset, soutput.F, target)
      score = precision_at_one(pred, target)
      losses.update(batch_loss, target.size(0))
      scores.update(score, target.size(0))

      if curr_iter >= config.max_iter:
        is_training = False
        break

      if curr_iter % config.stat_freq == 0 or curr_iter == 1:
        lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
        debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
            epoch, curr_iter, len(data_loader) // config.iter_size, losses.avg, lrs)
        debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
            scores.avg, data_time_avg.avg, iter_time_avg.avg)
        logging.info(debug_str)
        # Reset timers
        data_time_avg.reset()
        iter_time_avg.reset()
        # Write logs
        writer.add_scalar('training/loss', losses.avg, curr_iter)
        writer.add_scalar('training/precision_at_1', scores.avg, curr_iter)
        writer.add_scalar('training/learning_rate', scheduler.get_lr()[0], curr_iter)
        losses.reset()
        scores.reset()

      # Save current status; save before val to prevent occasional mem overflow
      if curr_iter % config.save_freq == 0:
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)

      # Validation
      if curr_iter % config.val_freq == 0:
        val_miou = validate(model, val_data_loader, writer, curr_iter, config, transform_data_fn)
        if val_miou > best_val_miou:
          best_val_miou = val_miou
          best_val_iter = curr_iter
          checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
                     "best_val")
        logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

        # Recover back
        model.train()

      # End of iteration
      curr_iter += 1

    epoch += 1

  # Explicit memory cleanup
  if hasattr(data_iter, 'cleanup'):
    data_iter.cleanup()

  # Save the final model
  checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter)
  val_miou = validate(model, val_data_loader, writer, curr_iter, config, transform_data_fn)
  if val_miou > best_val_miou:
    best_val_miou = val_miou
    best_val_iter = curr_iter
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
               "best_val")
  logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))
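
# ---------------------------------------------------------------------------
# Illustration only: both training loops above accumulate gradients over
# `config.iter_size` sub-batches before a single optimizer step. The generic
# pattern, with hypothetical names (`batches` is any iterable of (input,
# target) pairs), looks like this:
def _grad_accumulation_sketch(model, optimizer, criterion, batches, iter_size):
  optimizer.zero_grad()
  for i, (x, y) in enumerate(batches):
    loss = criterion(model(x), y) / iter_size  # scale so the summed gradient matches one big batch
    loss.backward()                            # gradients accumulate in .grad across sub-batches
    if (i + 1) % iter_size == 0:
      optimizer.step()
      optimizer.zero_grad()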