def validate(pipeline_model, data_loader, config, writer, curr_iter, best_val,
             best_val_iter, optimizer, epoch):
    """Run one validation pass and track the best metric seen so far.

    Evaluates ``pipeline_model`` on ``data_loader``, logs the metrics to the
    tensorboard ``writer`` under 'validation', and — when the model's target
    metric improves — saves a 'best_val' checkpoint.  The model is put back
    into train mode before returning.

    Returns:
        (best_val, best_val_iter): the (possibly updated) best metric value
        and the iteration at which it was achieved.
    """
    metrics = test(pipeline_model, data_loader, config)
    update_writer(writer, metrics, curr_iter, 'validation')

    metric_value = pipeline_model.get_metric(metrics)
    improved = metric_value > best_val
    if improved:
        best_val, best_val_iter = metric_value, curr_iter
        checkpoint(pipeline_model, optimizer, epoch, curr_iter, config,
                   best_val, best_val_iter, 'best_val')

    logging.info(
        f'Current best {pipeline_model.TARGET_METRIC}: {best_val:.3f} at iter {best_val_iter}'
    )

    # Recover back
    pipeline_model.train()
    return best_val, best_val_iter
def train(pipeline_model, data_loader, val_data_loader, config):
    """Multi-GPU training loop for a detection pipeline model.

    Replicates the model across up to ``config.max_ngpu`` GPUs (with
    MinkowskiEngine sync-batchnorm when more than one device is used),
    iterates over ``data_loader`` until ``config.max_iter``, periodically
    logs meters, saves checkpoints, and runs :func:`validate` on
    ``val_data_loader``.

    Fixes vs. the previous revision:
      * ``SummaryWriter`` was constructed twice on the same log dir, leaking
        the first writer's event file — now constructed once.
      * ``curr_iter`` was assigned twice from the same checkpoint field in
        the resume path — the redundant assignment is removed.
    """
    # Set up the train flag for batch normalization
    pipeline_model.train()

    num_devices = torch.cuda.device_count()
    num_devices = min(config.max_ngpu, num_devices)
    devices = list(range(num_devices))
    target_device = devices[0]
    pipeline_model.to(target_device)
    if num_devices > 1:
        pipeline_model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(
            pipeline_model, devices)

    # Configuration
    writer = SummaryWriter(logdir=config.log_dir)
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    meters = collections.defaultdict(AverageMeter)
    hists = pipeline_model.initialize_hists()

    optimizer = pipeline_model.initialize_optimizer(config)
    scheduler = pipeline_model.initialize_scheduler(optimizer, config)

    # Train the network
    logging.info('===> Start training')
    best_val, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        if osp.isfile(config.resume):
            logging.info("=> loading checkpoint '{}'".format(config.resume))
            state = torch.load(config.resume)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            pipeline_model.load_state_dict(state['state_dict'])
            if config.resume_optimizer:
                # Rebuild the scheduler so its step counter matches the
                # restored iteration before loading optimizer state.
                scheduler = pipeline_model.initialize_scheduler(
                    optimizer, config, last_step=curr_iter)
                pipeline_model.load_optimizer(optimizer, state['optimizer'])
            if 'best_val' in state:
                best_val = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                config.resume, state['epoch']))
        else:
            logging.info("=> no checkpoint found at '{}'".format(config.resume))

    data_iter = data_loader.__iter__()
    while is_training:
        for iteration in range(len(data_loader)):
            pipeline_model.reset_gradient(optimizer)
            iter_timer.tic()

            pipelines = parallel.replicate(pipeline_model, devices)

            # Get training data; each replica loads its own batch.
            data_timer.tic()
            inputs = []
            for pipeline, device in zip(pipelines, devices):
                with torch.cuda.device(device):
                    while True:
                        datum = pipeline.load_datum(data_iter, has_gt=True)
                        num_boxes = sum(
                            box.shape[0] for box in datum['bboxes_coords'])
                        # Optionally resample until a batch with boxes appears.
                        if config.skip_empty_boxes and num_boxes == 0:
                            continue
                        break
                inputs.append(datum)
            data_time_avg.update(data_timer.toc(False))

            outputs = parallel.parallel_apply(
                pipelines, [(x, True) for x in inputs], devices=devices)
            losses = parallel.parallel_apply(
                [pipeline.loss for pipeline in pipelines],
                tuple(zip(inputs, outputs)), devices=devices)
            losses = parallel.gather(losses, target_device)
            # Average each loss term across device replicas.
            losses = dict([(k, v.mean()) for k, v in losses.items()])
            meters, hists = pipeline_model.update_meters(meters, hists, losses)

            # Compute and accumulate gradient
            losses['loss'].backward()

            # Update number of steps
            pipeline_model.step_optimizer(losses, optimizer, scheduler, iteration)

            iter_time_avg.update(iter_timer.toc(False))

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join([
                    '{:.3e}'.format(x) for x in scheduler['default'].get_lr()
                ])
                debug_str = "===> Epoch[{}]({}/{}): LR: {}\n".format(
                    epoch, curr_iter, len(data_loader), lrs)
                debug_str += log_meters(meters, log_perclass_meters=False)
                debug_str += f"\n data time: {data_time_avg.avg:.3f}"
                debug_str += f" iter time: {iter_time_avg.avg:.3f}"
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                update_writer(writer, meters, curr_iter, 'training')
                writer.add_scalar('training/learning_rate',
                                  scheduler['default'].get_lr()[0], curr_iter)
                # Reset meters
                reset_meters(meters, hists)

            # Save current status, save before val to prevent occational mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(pipeline_model, optimizer, epoch, curr_iter, config,
                           best_val, best_val_iter)

            if config.heldout_save_freq > 0 and curr_iter % config.heldout_save_freq == 0:
                checkpoint(pipeline_model, optimizer, epoch, curr_iter, config,
                           best_val, best_val_iter, heldout_save=True)

            # Validation: sync-batchnorm must be unconverted for single-process eval.
            if curr_iter % config.val_freq == 0:
                if num_devices > 1:
                    unconvert_sync_batchnorm(pipeline_model)
                best_val, best_val_iter = validate(pipeline_model,
                                                   val_data_loader, config,
                                                   writer, curr_iter, best_val,
                                                   best_val_iter, optimizer,
                                                   epoch)
                if num_devices > 1:
                    pipeline_model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(
                        pipeline_model, devices)

            if curr_iter % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

            # End of iteration
            curr_iter += 1

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    if num_devices > 1:
        unconvert_sync_batchnorm(pipeline_model)
    validate(pipeline_model, val_data_loader, config, writer, curr_iter,
             best_val, best_val_iter, optimizer, epoch)
    if num_devices > 1:
        pipeline_model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(
            pipeline_model, devices)
    checkpoint(pipeline_model, optimizer, epoch, curr_iter, config, best_val,
               best_val_iter)
def train_model(args, metadata, device='cuda'):
    """Train an iVAE or iFlow model on a synthetic dataset.

    Builds the data loader (CPU or GPU-preloaded), constructs the model
    selected by ``args.i_what`` ('iVAE' or 'iFlow'), optimizes with Adam +
    ReduceLROnPlateau, logs losses and MCC performance to tensorboard, and
    checkpoints periodically plus once at the end.

    Raises:
        ValueError: if ``args.i_what`` is neither 'iVAE' nor 'iFlow'
            (previously this fell through and crashed later with an opaque
            AttributeError on ``None``).

    Returns:
        The trained model.
    """
    print('training on {}'.format(
        torch.cuda.get_device_name(device) if args.cuda else 'cpu'))

    # load data
    if not args.preload:
        dset = SyntheticDataset(args.file, 'cpu')
        train_loader = DataLoader(dset, shuffle=True, batch_size=args.batch_size)
        data_dim, latent_dim, aux_dim = dset.get_dims()
        args.N = len(dset)
        metadata.update(dset.get_metadata())
    else:
        train_loader = DataLoaderGPU(args.file, shuffle=True,
                                     batch_size=args.batch_size)
        data_dim, latent_dim, aux_dim = train_loader.get_dims()
        args.N = train_loader.dataset_len
        metadata.update(train_loader.get_metadata())

    if args.max_iter is None:
        args.max_iter = len(train_loader) * args.epochs

    if args.latent_dim is not None:
        latent_dim = args.latent_dim
        metadata.update({"train_latent_dim": latent_dim})

    # define model and optimizer
    if args.i_what == 'iVAE':
        model = iVAE(latent_dim,
                     data_dim,
                     aux_dim,
                     n_layers=args.depth,
                     activation='lrelu',
                     device=device,
                     hidden_dim=args.hidden_dim,
                     anneal=args.anneal,
                     file=metadata['file'],  # dataset location for easier checkpoint loading
                     seed=1,
                     epochs=args.epochs)
    elif args.i_what == 'iFlow':
        metadata.update({"device": device})
        model = iFlow(args=metadata).to(device)
    else:
        raise ValueError("unknown model type: {!r}".format(args.i_what))

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_drop_factor,
        patience=args.lr_patience,
        verbose=True)

    ste = time.time()
    # NOTE(review): `st` is presumably a module-level start timestamp — set outside this view.
    print('setup time: {}s'.format(ste - st))

    # setup loggers
    logger = Logger(logdir=LOG_FOLDER)
    exp_id = logger.get_id()
    tensorboard_run_name = TENSORBOARD_RUN_FOLDER + 'exp' + str(exp_id) + '_'.join(
        map(str, ['', args.batch_size, args.max_iter, args.lr, args.hidden_dim,
                  args.depth, args.anneal]))
    writer = SummaryWriter(logdir=tensorboard_run_name)

    if args.i_what == 'iFlow':
        logger.add('log_normalizer')
        logger.add('neg_log_det')
        logger.add('neg_trace')
    logger.add('loss')
    logger.add('perf')
    print('Beginning training for exp: {}'.format(exp_id))

    # Checkpoint every ~20% of training; clamp to 1 so max_iter < 5 cannot
    # cause a ZeroDivisionError in the modulo below.
    ckpt_every = max(1, int(args.max_iter / 5))

    # training loop
    epoch = 0
    model.train()
    while epoch < args.epochs:
        est = time.time()
        for itr, (x, u, z) in enumerate(train_loader):
            acc_itr = itr + epoch * len(train_loader)
            # assumes x: data batch, u: one-hot auxiliary labels, z: true sources — TODO confirm shapes

            optimizer.zero_grad()
            if args.cuda and not args.preload:
                x = x.cuda(device=device, non_blocking=True)
                u = u.cuda(device=device, non_blocking=True)

            if args.i_what == 'iVAE':
                elbo, z_est = model.elbo(x, u)
                loss = elbo.mul(-1)  # maximize ELBO == minimize its negation
            elif args.i_what == 'iFlow':
                (log_normalizer, neg_trace, neg_log_det), z_est = \
                    model.neg_log_likelihood(x, u)
                loss = log_normalizer + neg_trace + neg_log_det

            loss.backward()
            optimizer.step()

            logger.update('loss', loss.item())
            if args.i_what == 'iFlow':
                logger.update('log_normalizer', log_normalizer.item())
                logger.update('neg_trace', neg_trace.item())
                logger.update('neg_log_det', neg_log_det.item())

            perf = mcc(z.cpu().numpy(), z_est.cpu().detach().numpy())
            logger.update('perf', perf)

            if acc_itr % args.log_freq == 0:
                logger.log()
                writer.add_scalar('data/performance', logger.get_last('perf'), acc_itr)
                writer.add_scalar('data/loss', logger.get_last('loss'), acc_itr)
                if args.i_what == 'iFlow':
                    writer.add_scalar('data/log_normalizer',
                                      logger.get_last('log_normalizer'), acc_itr)
                    writer.add_scalar('data/neg_trace',
                                      logger.get_last('neg_trace'), acc_itr)
                    writer.add_scalar('data/neg_log_det',
                                      logger.get_last('neg_log_det'), acc_itr)
                scheduler.step(logger.get_last('loss'))

            if acc_itr % ckpt_every == 0 and not args.no_log:
                checkpoint(TORCH_CHECKPOINT_FOLDER, exp_id, acc_itr, model,
                           optimizer, logger.get_last('loss'),
                           logger.get_last('perf'))

        epoch += 1
        eet = time.time()
        if args.i_what == 'iVAE':
            print('epoch {}: {:.4f}s;\tloss: {:.4f};\tperf: {:.4f}'.format(
                epoch, eet - est, logger.get_last('loss'),
                logger.get_last('perf')))
        elif args.i_what == 'iFlow':
            print('epoch {}: {:.4f}s;\tloss: {:.4f} (l1: {:.4f}, l2: {:.4f}, l3: {:.4f});\tperf: {:.4f}'.format(
                epoch, eet - est, logger.get_last('loss'),
                logger.get_last('log_normalizer'), logger.get_last('neg_trace'),
                logger.get_last('neg_log_det'), logger.get_last('perf')))

    et = time.time()
    print('training time: {}s'.format(et - ste))

    # Save final model
    checkpoint(PT_MODELS_FOLDER, "", 'final', model, optimizer,
               logger.get_last('loss'), logger.get_last('perf'))
    writer.close()

    if not args.no_log:
        logger.add_metadata(**metadata)
        logger.save_to_json()
        logger.save_to_npz()

    print('total time: {}s'.format(et - st))
    return model
writer.add_scalar('data/log_normalizer', logger.get_last('log_normalizer'), acc_itr) writer.add_scalar('data/neg_trace', logger.get_last('neg_trace'), acc_itr) writer.add_scalar('data/neg_log_det', logger.get_last('neg_log_det'), acc_itr) scheduler.step(logger.get_last('loss')) #scheduler.step(-perf) if acc_itr % int(args.max_iter / 5) == 0 and not args.no_log: checkpoint(TORCH_CHECKPOINT_FOLDER, \ exp_id, \ acc_itr, \ model, \ optimizer, \ logger.get_last('loss'), \ logger.get_last('perf')) """ if args.i_what == 'iVAE': print('----epoch {} iter {}:\tloss: {:.4f};\tperf: {:.4f}'.format(\ epoch, \ itr, \ loss.item(), \ perf)) elif args.i_what == 'iFlow': print('----epoch {} iter {}:\tloss: {:.4f} (l1: {:.4f}, l2: {:.4f}, l3: {:.4f});\tperf: {:.4f}'.format(\ epoch, \ itr, \ loss.item(), \
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    """Distributed semantic-segmentation training loop (sparse tensors).

    Runs gradient accumulation over ``config.iter_size`` sub-iterations,
    all-gathers loss/score across workers when world size > 1, logs timings
    for data/forward/backward/DDP phases, checkpoints, and validates.

    Fixes vs. the previous revision:
      * ``SummaryWriter`` was constructed twice on the same log dir — now once.
      * ``data_iter.next()`` replaced with ``next(data_iter)``: the py2-style
        ``.next()`` alias was removed from PyTorch DataLoader iterators.
    """
    device = config.device_id
    distributed = get_world_size() > 1

    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    writer = SummaryWriter(log_dir=config.log_dir)
    data_timer, iter_timer = Timer(), Timer()
    fw_timer, bw_timer, ddp_timer = Timer(), Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    fw_time_avg, bw_time_avg, ddp_time_avg = AverageMeter(), AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training on {} GPUs, batch-size={}'.format(
        get_world_size(), config.batch_size * get_world_size()))
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            # Restore to CPU first; parameters are moved to device lazily.
            state = torch.load(
                checkpoint_fn,
                map_location=lambda s, l: default_restore_location(s, 'cpu'))
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            load_state(model, state['state_dict'])
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer, config,
                                                 last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = data_loader.__iter__()  # (distributed) infinite sampler
    while is_training:
        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss, batch_score = 0, 0, 0
            iter_timer.tic()

            # set random seed for every iteration for trackability
            _set_seed(config, curr_iter)

            for sub_iter in range(config.iter_size):
                # Get training data
                data_timer.tic()
                coords, input, target = next(data_iter)

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, :3] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                color = input[:, :3].int()
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / 255. - 0.5
                sinput = SparseTensor(input, coords).to(device)

                data_time += data_timer.toc(False)

                # Feed forward
                fw_timer.tic()
                inputs = (sinput,) if config.wrapper_type == 'None' else (
                    sinput, coords, color)
                soutput = model(*inputs)
                # The output of the network is not sorted
                target = target.long().to(device)
                loss = criterion(soutput.F, target)

                # Compute and accumulate gradient
                loss /= config.iter_size
                pred = get_prediction(data_loader.dataset, soutput.F, target)
                score = precision_at_one(pred, target)
                fw_timer.toc(False)

                bw_timer.tic()
                # bp the loss
                loss.backward()
                bw_timer.toc(False)

                # gather information
                logging_output = {
                    'loss': loss.item(),
                    'score': score / config.iter_size
                }

                ddp_timer.tic()
                if distributed:
                    # Average loss/score across all workers for logging.
                    logging_output = all_gather_list(logging_output)
                    logging_output = {
                        w: np.mean([a[w] for a in logging_output])
                        for w in logging_output[0]
                    }
                batch_loss += logging_output['loss']
                batch_score += logging_output['score']
                ddp_timer.toc(False)

            # Update number of steps
            optimizer.step()
            scheduler.step()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))
            fw_time_avg.update(fw_timer.diff)
            bw_time_avg.update(bw_timer.diff)
            ddp_time_avg.update(ddp_timer.diff)

            losses.update(batch_loss, target.size(0))
            scores.update(batch_score, target.size(0))

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Forward time: {:.4f}, Backward time: {:.4f}, DDP time: {:.4f}, Total iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, fw_time_avg.avg,
                    bw_time_avg.avg, ddp_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                writer.add_scalar('training/loss', losses.avg, curr_iter)
                writer.add_scalar('training/precision_at_1', scores.avg, curr_iter)
                writer.add_scalar('training/learning_rate',
                                  scheduler.get_lr()[0], curr_iter)
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occational mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config,
                           best_val_miou, best_val_iter)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, writer, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config,
                               best_val_miou, best_val_iter, "best_val")
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))
                # Recover back
                model.train()

            if curr_iter % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

            # End of iteration
            curr_iter += 1

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    val_miou = validate(model, val_data_loader, writer, curr_iter, config,
                        transform_data_fn)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    """Single-process segmentation training loop (aux-target / SAM variant).

    Supports dataset-specific class counts and validation frequencies,
    auxiliary label targets, an optional point branch, optional SAM
    two-step optimization, lr warmup, and per-epoch train IoU reporting.

    Fixes vs. the previous revision:
      * ``cur_loss`` was only bound when the model had a ``block1`` attribute,
        but ``regs.update(cur_loss.item(), ...)`` ran unconditionally —
        a NameError for models without ``block1``. It is now always
        initialized to zero.
      * ``data_iter.next()`` replaced with ``next(data_iter)`` (the py2-style
        ``.next()`` alias was removed from PyTorch DataLoader iterators).
      * ``coords_norm`` is computed unconditionally (side-effect free) so
        ``config.xyz_input=True`` with ``normalize_color=False`` cannot hit
        an unbound local.
    """
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    regs, losses, scores = AverageMeter(), AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        # Test loaded ckpt first
        v_loss, v_score, v_mAP, v_mIoU = test(model, val_data_loader, config)

        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            # we skip attention maps because the shape won't match because voxel number is different
            # e.g. copyting a param with shape (23385, 8, 4) to (43529, 8, 4)
            d = {
                k: v
                for k, v in state['state_dict'].items() if 'map' not in k
            }
            # handle those attn maps we don't load from saved dict
            for k in model.state_dict().keys():
                if k in d.keys():
                    continue
                d[k] = model.state_dict()[k]
            model.load_state_dict(d)
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer, config,
                                                 last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = data_loader.__iter__()

    # Dataset-specific class count / input handling / validation frequency.
    if config.dataset == "SemanticKITTI":
        num_class = 19
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq * 10
    elif config.dataset == "S3DIS":
        num_class = 13
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
    elif config.dataset == "Nuscenes":
        num_class = 16
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq * 50
    else:
        num_class = 20
        val_freq_ = config.val_freq

    while is_training:
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            if curr_iter >= config.max_iter:
                is_training = False
                break
            elif curr_iter >= config.max_iter * (2 / 3):
                config.val_freq = val_freq_ * 2  # valid more freq on lower half

            for sub_iter in range(config.iter_size):
                # Get training data
                data_timer.tic()
                pointcloud = None

                if config.return_transformation:
                    coords, input, target, _, _, pointcloud, transformation, _ = next(data_iter)
                else:
                    coords, input, target, _, _, _ = next(data_iter)  # ignore unique_map and inverse_map

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / input[:, :3].max() - 0.5
                coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5

                # cat xyz into the rgb feature
                if config.xyz_input:
                    input = torch.cat([coords_norm, input], dim=1)

                sinput = SparseTensor(input, coords, device=device)
                starget = SparseTensor(
                    target.unsqueeze(-1).float(),
                    coordinate_map_key=sinput.coordinate_map_key,
                    coordinate_manager=sinput.coordinate_manager,
                    device=device
                )  # must share the same coord-manager to align for sinput

                data_time += data_timer.toc(False)

                if aux is not None:
                    soutput = model(sinput, aux)
                elif config.enable_point_branch:
                    soutput = model(sinput,
                                    iter_=curr_iter / config.max_iter,
                                    enable_point_branch=True)
                else:
                    # label-aux, feed it in as additional reg
                    soutput = model(
                        sinput, iter_=curr_iter / config.max_iter, aux=starget
                    )  # feed in the progress of training for annealing inside the model

                # The output of the network is not sorted
                target = target.view(-1).long().to(device)
                loss = criterion(soutput.F, target.long())

                # ====== other loss regs =====
                # Always defined so regs.update below cannot NameError when the
                # model has no 'block1'.
                cur_loss = torch.tensor([0.], device=device)
                if hasattr(model, 'block1'):
                    if hasattr(model.block1[0], 'vq_loss'):
                        if model.block1[0].vq_loss is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.named_children():
                                if 'block' in n:
                                    # m is the nn.Sequential obj, m[0] is the TRBlock
                                    cur_loss += m[0].vq_loss
                            logging.info('Cur Loss: {}, Cur vq_loss: {}'.format(
                                loss, cur_loss))
                            loss += cur_loss
                    if hasattr(model.block1[0], 'diverse_loss'):
                        if model.block1[0].diverse_loss is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.named_children():
                                if 'block' in n:
                                    # m is the nn.Sequential obj, m[0] is the TRBlock
                                    cur_loss += m[0].diverse_loss
                            logging.info(
                                'Cur Loss: {}, Cur diverse _loss: {}'.format(
                                    loss, cur_loss))
                            loss += cur_loss
                    if hasattr(model.block1[0], 'label_reg'):
                        if model.block1[0].label_reg is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.named_children():
                                if 'block' in n:
                                    # m is the nn.Sequential obj, m[0] is the TRBlock
                                    cur_loss += m[0].label_reg
                            loss += cur_loss

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()

            # Update number of steps
            if not config.use_sam:
                optimizer.step()
            else:
                # SAM: ascent step, re-forward on the last sub-batch, descent step.
                optimizer.first_step(zero_grad=True)
                soutput = model(sinput,
                                iter_=curr_iter / config.max_iter,
                                aux=starget)
                criterion(soutput.F, target.long()).backward()
                optimizer.second_step(zero_grad=True)

            if config.lr_warmup is None:
                scheduler.step()
            else:
                if curr_iter >= config.lr_warmup:
                    scheduler.step()
                # NOTE(review): this lr override runs every iteration while
                # lr_warmup is set (also after warmup ends) — confirm intended.
                for g in optimizer.param_groups:
                    g['lr'] = config.lr * (iteration + 1) / config.lr_warmup

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target, ignore_label=-1)

            regs.update(cur_loss.item(), target.size(0))
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # calc the train-iou
            for l in range(num_class):
                total_correct_class[l] += ((pred == l) & (target == l)).sum()
                total_iou_deno_class[l] += (((pred == l) & (target != -1)) |
                                            (target == l)).sum()

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                IoU = ((total_correct_class) /
                       (total_iou_deno_class + 1e-6)).mean() * 100.
                debug_str = "[{}] ===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    config.log_dir.split('/')[-2], epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tIoU {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, IoU.item(), data_time_avg.avg,
                    iter_time_avg.avg)
                if regs.avg > 0:
                    debug_str += "\n Additional Reg Loss {:.3f}".format(
                        regs.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occational mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config,
                           best_val_miou, best_val_iter, save_inter=True)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, None, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config,
                               best_val_miou, best_val_iter, "best_val",
                               save_inter=True)
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))
                # Recover back
                model.train()

            # End of iteration
            curr_iter += 1

        IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    v_loss, v_score, v_mAP, val_miou = test(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
def train_distill(model, data_loader, val_data_loader, config,
                  transform_data_fn=None):
    """Distillation training: a transformer student guided by a ResNet teacher.

    A pretrained Res16UNet18A teacher is loaded from a hard-coded checkpoint
    (FIXME below); the student trains on cross-entropy plus an L2
    activation-matching DistillLoss scaled by ``distill_lambda``.

    Fixes vs. the previous revision:
      * ``STAGE_PERCENTAGE`` was commented out yet referenced when
        ``TWO_STAGE`` is True — a latent NameError; it is now defined.
      * ``data_iter.next()`` replaced with ``next(data_iter)`` (the py2-style
        ``.next()`` alias was removed from PyTorch DataLoader iterators).
      * ``coords_norm`` is computed unconditionally (side-effect free) so
        ``config.xyz_input=True`` with ``normalize_color=False`` cannot hit
        an unbound local.
    """
    # --- distillation cfgs ---
    # distill_lambda = 1
    # distill_lambda = 0.33
    distill_lambda = 0.67

    # TWO_STAGE=True: Transformer is first trained with L2 loss to match ResNet's
    # activation, and then it fintunes like normal training on the second stage.
    # TWO_STAGE=False: Transformer trains with combined loss
    TWO_STAGE = False
    # Fraction of iterations spent in stage 1 when TWO_STAGE is enabled.
    STAGE_PERCENTAGE = 0.7

    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    # TODO: load the sub-model only
    # FIXME: some dirty hard-written stuff, only supporting current state
    tch_model_cls = load_model('Res16UNet18A')
    tch_model = tch_model_cls(3, 20, config).to(device)
    checkpoint_fn = "/home/zhaotianchen/project/point-transformer/SpatioTemporalSegmentation-ScanNet/outputs/ScannetSparseVoxelizationDataset/Res16UNet18A/Res18A/weights.pth"  # voxel-size: 0.05
    assert osp.isfile(checkpoint_fn)
    logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
    state = torch.load(checkpoint_fn)
    # Skip attention-map params whose shapes are voxel-count dependent.
    d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
    tch_model.load_state_dict(d)
    if 'best_val' in state:
        best_val_miou = state['best_val']
        best_val_iter = state['best_val_iter']
    logging.info("=> loaded checkpoint '{}' (epoch {})".format(
        checkpoint_fn, state['epoch']))

    if config.resume:
        raise NotImplementedError

    # test after loading the ckpt
    v_loss, v_score, v_mAP, v_mIoU = test(tch_model, val_data_loader, config)
    logging.info('Tch model tested, bes_miou: {}'.format(v_mIoU))

    data_iter = data_loader.__iter__()
    while is_training:
        num_class = 20
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        total_iteration = len(data_loader) // config.iter_size
        for iteration in range(total_iteration):
            # NOTE: for single stage distillation, L2 loss might be too large at first
            # so we added a warmup training that don't use L2 loss
            if iteration < 0:
                use_distill = False
            else:
                use_distill = True

            # Stage 1 / Stage 2 boundary
            if TWO_STAGE:
                stage_boundary = int(total_iteration * STAGE_PERCENTAGE)

            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            for sub_iter in range(config.iter_size):
                # Get training data
                data_timer.tic()
                if config.return_transformation:
                    coords, input, target, _, _, pointcloud, transformation = next(data_iter)
                else:
                    coords, input, target, _, _ = next(data_iter)  # ignore unique_map and inverse_map

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / 255. - 0.5
                coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5

                # cat xyz into the rgb feature
                if config.xyz_input:
                    input = torch.cat([coords_norm, input], dim=1)

                sinput = SparseTensor(input, coords, device=device)

                # TODO: return both-models
                # in order to not breaking the valid interface, use a get_loss to get the regsitered loss
                data_time += data_timer.toc(False)

                if aux is not None:
                    raise NotImplementedError

                # flatten ground truth tensor
                target = target.view(-1).long().to(device)

                if TWO_STAGE:
                    if iteration < stage_boundary:
                        # Stage 1: train transformer on L2 loss
                        soutput, anchor = model(sinput, save_anchor=True)
                        # Make sure gradient don't flow to teacher model
                        with torch.no_grad():
                            _, tch_anchor = tch_model(sinput, save_anchor=True)
                        loss = DistillLoss(tch_anchor, anchor)
                    else:
                        # Stage 2: finetune transformer on Cross-Entropy
                        soutput = model(sinput)
                        loss = criterion(soutput.F, target.long())
                else:
                    if use_distill:  # after warm up
                        soutput, anchor = model(sinput, save_anchor=True)
                        # if pretrained teacher, do not let the grad flow to
                        # teacher to update its params
                        with torch.no_grad():
                            tch_soutput, tch_anchor = tch_model(
                                sinput, save_anchor=True)
                    else:  # warming up
                        soutput = model(sinput)

                    # The output of the network is not sorted
                    loss = criterion(soutput.F, target.long())
                    # Add L2 loss if use distillation
                    if use_distill:
                        distill_loss = DistillLoss(tch_anchor,
                                                   anchor) * distill_lambda
                        loss += distill_loss

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()

            # Update number of steps
            optimizer.step()
            scheduler.step()

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target, ignore_label=-1)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # calc the train-iou
            for l in range(num_class):
                total_correct_class[l] += ((pred == l) & (target == l)).sum()
                total_iou_deno_class[l] += (((pred == l) & (target != -1)) |
                                            (target == l)).sum()

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "[{}] ===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    config.log_dir, epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                if use_distill and not TWO_STAGE:
                    logging.info('Loss {} Distill Loss:{}'.format(
                        loss, distill_loss))
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occational mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config,
                           best_val_miou, best_val_iter, save_inter=True)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, None, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config,
                               best_val_miou, best_val_iter, "best_val",
                               save_inter=True)
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))
                # Recover back
                model.train()

            # End of iteration
            curr_iter += 1

        IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    v_loss, v_score, v_mAP, val_miou = test(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
def train_point(model, data_loader, val_data_loader, config, transform_data_fn=None):
  """Train a point-based (or TensorField-voxelized) model.

  Iterates ``data_loader`` until ``config.max_iter``, accumulating gradients
  over ``config.iter_size`` sub-iterations per optimizer step, logging a
  running per-class IoU each epoch, and checkpointing every
  ``config.save_freq`` iterations.  After training, runs ``test_points`` on
  ``val_data_loader`` and saves a ``best_val`` checkpoint if it improves.

  Args:
    model: network to train (moved to ``device`` implicitly via ``.cuda()``).
    data_loader: training loader yielding ``(points, target, sample_weight)``.
    val_data_loader: loader used for the final evaluation.
    config: experiment configuration namespace.
    transform_data_fn: unused here; kept for signature parity with siblings.
  """
  device = get_torch_device(config.is_cuda)
  # Set up the train flag for batch normalization
  model.train()

  # Configuration
  data_timer, iter_timer = Timer(), Timer()
  data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
  losses, scores = AverageMeter(), AverageMeter()

  optimizer = initialize_optimizer(model.parameters(), config)
  scheduler = initialize_scheduler(optimizer, config)
  # Labels are shifted by -1 below, so former class 0 becomes -1 and is ignored.
  criterion = nn.CrossEntropyLoss(ignore_index=-1)

  # Train the network
  logging.info('===> Start training')
  best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

  if config.resume:
    checkpoint_fn = config.resume + '/weights.pth'
    if osp.isfile(checkpoint_fn):
      logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
      state = torch.load(checkpoint_fn)
      curr_iter = state['iteration'] + 1
      epoch = state['epoch']
      # Skip attention-map buffers: their shapes depend on voxel count and
      # would raise a size mismatch on load.
      d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
      model.load_state_dict(d)
      if config.resume_optimizer:
        scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
        optimizer.load_state_dict(state['optimizer'])
      if 'best_val' in state:
        best_val_miou = state['best_val']
        best_val_iter = state['best_val_iter']
      logging.info("=> loaded checkpoint '{}' (epoch {})".format(
          checkpoint_fn, state['epoch']))
    else:
      raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

  data_iter = data_loader.__iter__()
  while is_training:
    num_class = 20
    total_correct_class = torch.zeros(num_class, device=device)
    total_iou_deno_class = torch.zeros(num_class, device=device)

    for iteration in range(len(data_loader) // config.iter_size):
      optimizer.zero_grad()
      data_time, batch_loss = 0, 0
      iter_timer.tic()

      for sub_iter in range(config.iter_size):
        # Get training data.  BUGFIX: start the data timer here — the original
        # accumulated toc() without a matching tic(), so data_time measured the
        # wrong span (every sibling train function tics before the fetch).
        data_timer.tic()
        # BUGFIX: Py3 iterator protocol (`.next()` is Python 2 only).
        points, target, sample_weight = next(data_iter)

        if config.pure_point:
          # (B, npoint, C) -> (B, C, npoint) as expected by point networks.
          sinput = points.transpose(1, 2).cuda().float()
        else:
          voxel_size = config.voxel_size
          coords = torch.unbind(points[:, :, :3] / voxel_size, dim=0)
          # Feed the full point attributes (xyz included) as features.
          feats = torch.unbind(points[:, :, :], dim=0)
          coords, feats = ME.utils.sparse_collate(coords, feats)
          # For some networks, making the network invariant to even, odd
          # coords is important.
          coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)
          points_ = ME.TensorField(features=feats.float(),
                                   coordinates=coords,
                                   device=device)
          sinput = points_.sparse()

        data_time += data_timer.toc(False)
        B, npoint = target.shape

        soutput = model(sinput)
        if config.pure_point:
          soutput = soutput.reshape([B * npoint, -1])
        else:
          # Map voxel predictions back onto the original points.
          soutput = soutput.slice(points_).F

        # The output of the network is not sorted; shift labels so that the
        # first class maps to ignore_index (-1).
        target = (target - 1).view(-1).long().to(device)

        # Catch NaNs early (interactive debugging hook, kept from original).
        if torch.isnan(soutput).sum() > 0:
          import ipdb
          ipdb.set_trace()
        loss = criterion(soutput, target)
        if torch.isnan(loss).sum() > 0:
          import ipdb
          ipdb.set_trace()
        # NOTE(review): criterion reduces to a scalar before the per-sample
        # weighting below, so this scales by mean(sample_weight) — confirm
        # intended (kept as in original).
        loss = (loss * sample_weight.to(device)).mean()

        # Compute and accumulate gradient
        loss /= config.iter_size
        batch_loss += loss.item()
        loss.backward()

      # Update number of steps
      optimizer.step()
      scheduler.step()

      # CLEAR CACHE!
      torch.cuda.empty_cache()

      data_time_avg.update(data_time)
      iter_time_avg.update(iter_timer.toc(False))

      pred = get_prediction(data_loader.dataset, soutput, target)
      score = precision_at_one(pred, target, ignore_label=-1)
      losses.update(batch_loss, target.size(0))
      scores.update(score, target.size(0))

      # Calc the iou (labels were shifted, so valid targets are >= 0).
      for cls in range(num_class):
        total_correct_class[cls] += ((pred == cls) & (target == cls)).sum()
        total_iou_deno_class[cls] += (((pred == cls) & (target >= 0)) |
                                      (target == cls)).sum()

      if curr_iter >= config.max_iter:
        is_training = False
        break

      if curr_iter % config.stat_freq == 0 or curr_iter == 1:
        lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
        debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
            epoch, curr_iter, len(data_loader) // config.iter_size,
            losses.avg, lrs)
        debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
            scores.avg, data_time_avg.avg, iter_time_avg.avg)
        logging.info(debug_str)
        # Reset timers
        data_time_avg.reset()
        iter_time_avg.reset()
        losses.reset()
        scores.reset()

      # Save current status, save before val to prevent occational mem overflow
      if curr_iter % config.save_freq == 0:
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, save_inter=True)

      # NOTE: in-training validation is disabled for the point-based path —
      # it needs a dedicated point-eval dataloader (test_points).

      # End of iteration
      curr_iter += 1

    IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
    logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))
    epoch += 1

  # Explicit memory cleanup
  if hasattr(data_iter, 'cleanup'):
    data_iter.cleanup()

  # Save the final model
  checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
             best_val_iter)
  # BUGFIX: the original discarded this return value and then read the
  # undefined name `val_miou` (NameError).  test_points returns the mIoU.
  val_miou = test_points(model, val_data_loader, config)
  if val_miou > best_val_miou:
    best_val_miou = val_miou
    best_val_iter = curr_iter
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter, "best_val")
  logging.info("Current best mIoU: {:.3f} at iter {}".format(
      best_val_miou, best_val_iter))
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
  """Train a sparse-voxel model with verbose per-step debug printing.

  Standard cross-entropy training loop with gradient accumulation
  (``config.iter_size``), TensorBoard logging, periodic checkpointing and
  validation, and a per-epoch loss line appended to ``train_loss.txt``.
  A focal-loss value is computed and printed for inspection but is NOT used
  for the backward pass (kept from the original for its debug output).

  Args:
    model: network to train.
    data_loader: training loader yielding ``(coords, feats, target)``.
    val_data_loader: validation loader passed to ``validate``.
    config: experiment configuration namespace.
    transform_data_fn: forwarded to ``validate``.
  """
  device = get_torch_device(config.is_cuda)
  # Set up the train flag for batch normalization
  model.train()

  # Configuration.  BUGFIX: the original constructed SummaryWriter twice,
  # leaking the first writer's event file; construct it once.
  writer = SummaryWriter(log_dir=config.log_dir)
  data_timer, iter_timer = Timer(), Timer()
  data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
  losses, scores, batch_losses = AverageMeter(), AverageMeter(), {}

  optimizer = initialize_optimizer(model.parameters(), config)
  scheduler = initialize_scheduler(optimizer, config)
  criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
  # Focal-loss hyper-parameters (debug-print path only).
  alpha, gamma, eps = 1, 2, 1e-6

  # Train the network
  logging.info('===> Start training')
  best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

  if config.resume:
    checkpoint_fn = config.resume + '/weights.pth'
    if osp.isfile(checkpoint_fn):
      logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
      state = torch.load(checkpoint_fn)
      curr_iter = state['iteration'] + 1
      epoch = state['epoch']
      model.load_state_dict(state['state_dict'])
      if config.resume_optimizer:
        scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
        optimizer.load_state_dict(state['optimizer'])
      if 'best_val' in state:
        best_val_miou = state['best_val']
        best_val_iter = state['best_val_iter']
      logging.info("=> loaded checkpoint '{}' (epoch {})".format(
          checkpoint_fn, state['epoch']))
    else:
      raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

  data_iter = data_loader.__iter__()
  while is_training:
    print(
        "********************************** epoch N° {0} ************************"
        .format(epoch))
    for iteration in range(len(data_loader) // config.iter_size):
      print("####### Iteration N° {0}".format(iteration))
      optimizer.zero_grad()
      data_time, batch_loss = 0, 0
      iter_timer.tic()

      for sub_iter in range(config.iter_size):
        print("------------- Sub_iteration N° {0}".format(sub_iter))
        # Get training data
        data_timer.tic()
        # BUGFIX: Py3 iterator protocol (`.next()` is Python 2 only).
        # `features` renamed from `input` to avoid shadowing the builtin.
        coords, features, target = next(data_iter)
        print("len of coords : {0}".format(len(coords)))

        # For some networks, making the network invariant to even, odd coords
        # is important.
        # NOTE(review): siblings jitter coords[:, 1:] (skipping the batch
        # column); here columns 0..2 are jittered — confirm the coord layout.
        coords[:, :3] += (torch.rand(3) * 100).type_as(coords)

        # Preprocess input
        color = features[:, :3].int()
        if config.normalize_color:
          features[:, :3] = features[:, :3] / 255. - 0.5
        sinput = SparseTensor(features, coords).to(device)
        data_time += data_timer.toc(False)

        # Feed forward
        inputs = (sinput, ) if config.wrapper_type == 'None' else (
            sinput, coords, color)
        soutput = model(*inputs)

        # The output of the network is not sorted
        target = target.long().to(device)
        print("count of classes : {0}".format(
            np.unique(target.cpu().numpy(), return_counts=True)))
        print("target : {0}\ntarget_len : {1}".format(target, len(target)))
        print("target [0]: {0}".format(target[0]))

        # Focal loss is computed for the debug prints only; the optimizer
        # uses the plain cross-entropy `loss` below.
        input_soft = nn.functional.softmax(soutput.F, dim=1) + eps
        print("input_soft[0] : {0}".format(input_soft[0]))
        focal_weight = torch.pow(-input_soft + 1., gamma)
        print("focal_weight : {0}\nweight[0] : {1}".format(
            focal_weight, focal_weight[0]))
        focal_loss = (-alpha * focal_weight * torch.log(input_soft)).mean()

        loss = criterion(soutput.F, target.long())
        print("focal_loss :{0}\nloss : {1}".format(focal_loss, loss))

        # Compute and accumulate gradient
        loss /= config.iter_size
        batch_loss += loss.item()
        print("batch_loss : {0}".format(batch_loss))
        loss.backward()

      # Update number of steps
      optimizer.step()
      scheduler.step()

      data_time_avg.update(data_time)
      iter_time_avg.update(iter_timer.toc(False))

      pred = get_prediction(data_loader.dataset, soutput.F, target)
      score = precision_at_one(pred, target)
      losses.update(batch_loss, target.size(0))
      scores.update(score, target.size(0))

      if curr_iter >= config.max_iter:
        is_training = False
        break

      if curr_iter % config.stat_freq == 0 or curr_iter == 1:
        lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
        debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
            epoch, curr_iter, len(data_loader) // config.iter_size,
            losses.avg, lrs)
        debug_str += "Score {:.3f}\tData time: {:.4f}, Total iter time: {:.4f}".format(
            scores.avg, data_time_avg.avg, iter_time_avg.avg)
        logging.info(debug_str)
        # Reset timers
        data_time_avg.reset()
        iter_time_avg.reset()
        # Write logs
        writer.add_scalar('training/loss', losses.avg, curr_iter)
        writer.add_scalar('training/precision_at_1', scores.avg, curr_iter)
        writer.add_scalar('training/learning_rate', scheduler.get_lr()[0],
                          curr_iter)
        losses.reset()
        scores.reset()

      # Save current status, save before val to prevent occational mem overflow
      if curr_iter % config.save_freq == 0:
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter)

      # Validation
      if curr_iter % config.val_freq == 0:
        val_miou, val_losses = validate(model, val_data_loader, writer,
                                        curr_iter, config, transform_data_fn,
                                        epoch)
        if val_miou > best_val_miou:
          best_val_miou = val_miou
          best_val_iter = curr_iter
          checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                     best_val_iter, "best_val")
          logging.info("Current best mIoU: {:.3f} at iter {}".format(
              best_val_miou, best_val_iter))
        # Recover back
        model.train()

      if curr_iter % config.empty_cache_freq == 0:
        # Clear cache
        torch.cuda.empty_cache()

      # Keep the last accumulated batch loss seen this epoch.
      batch_losses[epoch] = batch_loss
      # End of iteration
      curr_iter += 1

    # Append this epoch's loss to the plain-text log (the `with` block closes
    # the file; the original's extra close() was redundant).
    with open(config.log_dir + "/train_loss.txt", 'a') as train_loss_log:
      train_loss_log.write('{0}, {1}\n'.format(batch_losses[epoch], epoch))
    epoch += 1

  # Explicit memory cleanup
  if hasattr(data_iter, 'cleanup'):
    data_iter.cleanup()

  # Save the final model
  checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
             best_val_iter)
  val_miou = validate(model, val_data_loader, writer, curr_iter, config,
                      transform_data_fn, epoch)[0]
  if val_miou > best_val_miou:
    best_val_miou = val_miou
    best_val_iter = curr_iter
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter, "best_val")
  logging.info("Current best mIoU: {:.3f} at iter {}".format(
      best_val_miou, best_val_iter))
def train_worker(gpu, num_devices, NetClass, data_loader, val_data_loader, config, transform_data_fn=None):
  """Per-process DDP training worker (one process per GPU).

  Initializes the NCCL process group, rebuilds the dataloader with a
  distributed-friendly sampler, constructs the model from ``NetClass``,
  wraps it in DistributedDataParallel + synchronized batch norm, and runs
  the training loop.  Only rank 0 logs, checkpoints, and validates.

  Args:
    gpu: GPU index for this worker (also used as the process rank).
    num_devices: world size.
    NetClass: model class to instantiate.
    data_loader: template training loader (dataset/collate/etc. reused).
    val_data_loader: validation loader.
    config: experiment configuration namespace (mutated: val_freq,
      normalize_color, xyz_input may be overridden per dataset).
    transform_data_fn: forwarded to ``validate``.
  """
  if gpu is not None:
    print("Use GPU: {} for training".format(gpu))
    rank = gpu
  # NOTE(review): if gpu is None, `rank` is never bound and the
  # init_process_group call below raises NameError — confirm callers always
  # pass a GPU index.
  addr = 23491  # fixed TCP port for single-node rendezvous
  dist.init_process_group(backend="nccl",
                          init_method="tcp://127.0.0.1:{}".format(addr),
                          world_size=num_devices,
                          rank=rank)

  # replace with DistributedSampler
  if config.multiprocess:
    from lib.dataloader_dist import InfSampler
    sampler = InfSampler(data_loader.dataset)
    data_loader = DataLoader(dataset=data_loader.dataset,
                             num_workers=data_loader.num_workers,
                             batch_size=data_loader.batch_size,
                             collate_fn=data_loader.collate_fn,
                             worker_init_fn=data_loader.worker_init_fn,
                             sampler=sampler)

  # Input feature width comes from the dataset when declared, else RGB (3).
  if data_loader.dataset.NUM_IN_CHANNEL is not None:
    num_in_channel = data_loader.dataset.NUM_IN_CHANNEL
  else:
    num_in_channel = 3
  num_labels = data_loader.dataset.NUM_LABELS

  # load model — constructor signature differs per architecture family.
  if config.pure_point:
    model = NetClass(num_class=config.num_labels,
                     N=config.num_points,
                     normal_channel=config.num_in_channel)
  else:
    if config.model == 'MixedTransformer':
      model = NetClass(config,
                       num_class=num_labels,
                       N=config.num_points,
                       normal_channel=num_in_channel)
    elif config.model == 'MinkowskiVoxelTransformer':
      model = NetClass(config, num_in_channel, num_labels)
    elif config.model == 'MinkowskiTransformerNet':
      model = NetClass(config, num_in_channel, num_labels)
    elif "Res" in config.model:
      model = NetClass(num_in_channel, num_labels, config)
    else:
      model = NetClass(num_in_channel, num_labels, config)

  if config.weights == 'modelzoo':
    model.preload_modelzoo()
  elif config.weights.lower() != 'none':
    state = torch.load(config.weights)
    # delete the keys containing the attn since it raises size mismatch
    # (note: 'state' '_dict' is implicit string concatenation == 'state_dict')
    d = {k: v for k, v in state['state' '_dict'].items() if 'map' not in k}
    if config.weights_for_inner_model:
      model.model.load_state_dict(d)
    else:
      if config.lenient_weight_loading:
        matched_weights = load_state_with_same_shape(model,
                                                     state['state_dict'])
        model_dict = model.state_dict()
        model_dict.update(matched_weights)
        model.load_state_dict(model_dict)
      else:
        model.load_state_dict(d, strict=False)

  torch.cuda.set_device(gpu)
  model.cuda(gpu)
  # use model with DDP
  model = torch.nn.parallel.DistributedDataParallel(
      model, device_ids=[gpu], find_unused_parameters=False)
  # Synchronized batch norm
  model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model)

  # Set up the train flag for batch normalization
  model.train()

  # Configuration
  data_timer, iter_timer = Timer(), Timer()
  data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
  regs, losses, scores = AverageMeter(), AverageMeter(), AverageMeter()

  optimizer = initialize_optimizer(model.parameters(), config)
  scheduler = initialize_scheduler(optimizer, config)
  criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

  # Train the network
  if rank == 0:
    setup_logger(config)
  logging.info('===> Start training')
  best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

  if config.resume:
    # Test loaded ckpt first
    v_loss, v_score, v_mAP, v_mIoU = test(model, val_data_loader, config)
    checkpoint_fn = config.resume + '/weights.pth'
    if osp.isfile(checkpoint_fn):
      logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
      state = torch.load(checkpoint_fn)
      curr_iter = state['iteration'] + 1
      epoch = state['epoch']
      # We skip attention maps because the shape won't match: the voxel
      # number differs between runs, e.g. copying a param with shape
      # (23385, 8, 4) to (43529, 8, 4).
      d = {
          k: v
          for k, v in state['state_dict'].items() if 'map' not in k
      }
      # handle those attn maps we don't load from saved dict
      for k in model.state_dict().keys():
        if k in d.keys():
          continue
        d[k] = model.state_dict()[k]
      model.load_state_dict(d)
      if config.resume_optimizer:
        scheduler = initialize_scheduler(optimizer,
                                         config,
                                         last_step=curr_iter)
        optimizer.load_state_dict(state['optimizer'])
      if 'best_val' in state:
        best_val_miou = state['best_val']
        best_val_iter = state['best_val_iter']
      logging.info("=> loaded checkpoint '{}' (epoch {})".format(
          checkpoint_fn, state['epoch']))
    else:
      raise ValueError(
          "=> no checkpoint found at '{}'".format(checkpoint_fn))

  data_iter = data_loader.__iter__()
  device = gpu  # multitrain fed in the device

  # Per-dataset class counts and input-preprocessing overrides; LiDAR
  # datasets validate less often during early training.
  if config.dataset == "SemanticKITTI":
    num_class = 19
    config.normalize_color = False
    config.xyz_input = False
    val_freq_ = config.val_freq
    config.val_freq = config.val_freq * 10  # origianl val_freq_
  elif config.dataset == 'S3DIS':
    num_class = 13
    config.normalize_color = False
    config.xyz_input = False
    val_freq_ = config.val_freq
  elif config.dataset == "Nuscenes":
    num_class = 16
    config.normalize_color = False
    config.xyz_input = False
    val_freq_ = config.val_freq
    config.val_freq = config.val_freq * 50
  else:
    val_freq_ = config.val_freq
    num_class = 20

  while is_training:
    total_correct_class = torch.zeros(num_class, device=device)
    total_iou_deno_class = torch.zeros(num_class, device=device)

    for iteration in range(len(data_loader) // config.iter_size):
      optimizer.zero_grad()
      data_time, batch_loss = 0, 0
      iter_timer.tic()

      if curr_iter >= config.max_iter:
        # if curr_iter >= max(config.max_iter, config.epochs*(len(data_loader) // config.iter_size):
        is_training = False
        break
      elif curr_iter >= config.max_iter * (2 / 3):
        config.val_freq = val_freq_ * 2  # valid more freq on lower half

      for sub_iter in range(config.iter_size):
        # Get training data
        data_timer.tic()
        if config.return_transformation:
          coords, input, target, _, _, pointcloud, transformation = data_iter.next(
          )
        else:
          coords, input, target, _, _ = data_iter.next(
          )  # ignore unique_map and inverse_map

        if config.use_aux:
          assert target.shape[1] == 2
          aux = target[:, 1]
          target = target[:, 0]
        else:
          aux = None

        # For some networks, making the network invariant to even, odd coords is important
        coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

        # Preprocess input
        if config.normalize_color:
          input[:, :3] = input[:, :3] / input[:, :3].max() - 0.5
          coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5
        # NOTE(review): coords_norm is only assigned when normalize_color is
        # True; with xyz_input=True and normalize_color=False the next line
        # raises NameError — confirm those flags are never combined that way.
        # cat xyz into the rgb feature
        if config.xyz_input:
          input = torch.cat([coords_norm, input], dim=1)
        sinput = SparseTensor(input, coords, device=device)
        data_time += data_timer.toc(False)

        # model.initialize_coords(*init_args)
        if aux is not None:
          soutput = model(sinput, aux)
        elif config.enable_point_branch:
          soutput = model(sinput,
                          iter_=curr_iter / config.max_iter,
                          enable_point_branch=True)
        else:
          soutput = model(
              sinput, iter_=curr_iter / config.max_iter
          )  # feed in the progress of training for annealing inside the model

        # The output of the network is not sorted
        target = target.view(-1).long().to(device)
        loss = criterion(soutput.F, target.long())

        # ====== other loss regs =====
        # NOTE(review): hasattr with a dotted name ('module.block1') is
        # always False, so this whole regularizer section never runs;
        # presumably hasattr(model.module, 'block1') was intended, and the
        # inner model.block1 / model.named_children() references would also
        # need .module under DDP — TODO confirm before enabling.
        cur_loss = torch.tensor([0.], device=device)
        if hasattr(model, 'module.block1'):
          cur_loss = torch.tensor([0.], device=device)
          if hasattr(model.module.block1[0], 'vq_loss'):
            if model.block1[0].vq_loss is not None:
              cur_loss = torch.tensor([0.], device=device)
              for n, m in model.named_children():
                if 'block' in n:
                  cur_loss += m[
                      0].vq_loss  # m is the nn.Sequential obj, m[0] is the TRBlock
              logging.info('Cur Loss: {}, Cur vq_loss: {}'.format(
                  loss, cur_loss))
              loss += cur_loss
          if hasattr(model.module.block1[0], 'diverse_loss'):
            if model.block1[0].diverse_loss is not None:
              cur_loss = torch.tensor([0.], device=device)
              for n, m in model.named_children():
                if 'block' in n:
                  cur_loss += m[
                      0].diverse_loss  # m is the nn.Sequential obj, m[0] is the TRBlock
              logging.info('Cur Loss: {}, Cur diverse _loss: {}'.format(
                  loss, cur_loss))
              loss += cur_loss
          if hasattr(model.module.block1[0], 'label_reg'):
            if model.block1[0].label_reg is not None:
              cur_loss = torch.tensor([0.], device=device)
              for n, m in model.named_children():
                if 'block' in n:
                  cur_loss += m[
                      0].label_reg  # m is the nn.Sequential obj, m[0] is the TRBlock
              # logging.info('Cur Loss: {}, Cur diverse _loss: {}'.format(loss, cur_loss))
              loss += cur_loss

        # Compute and accumulate gradient
        loss /= config.iter_size
        batch_loss += loss.item()
        if not config.use_sam:
          loss.backward()
        else:
          # SAM needs gradient sync deferred to the second step.
          with model.no_sync():
            loss.backward()

      # Update number of steps
      if not config.use_sam:
        optimizer.step()
      else:
        # SAM two-step update: perturb, re-forward, then the real step.
        # NOTE(review): `starget` is not defined anywhere in this function —
        # this branch raises NameError as written; confirm intended variable
        # (aux?) before using use_sam.
        optimizer.first_step(zero_grad=True)
        soutput = model(sinput,
                        iter_=curr_iter / config.max_iter,
                        aux=starget)
        criterion(soutput.F, target.long()).backward()
        optimizer.second_step(zero_grad=True)

      # Linear LR warmup overrides the scheduler until lr_warmup iters.
      if config.lr_warmup is None:
        scheduler.step()
      else:
        if curr_iter >= config.lr_warmup:
          scheduler.step()
        else:
          for g in optimizer.param_groups:
            g['lr'] = config.lr * (iteration + 1) / config.lr_warmup

      # CLEAR CACHE!
      torch.cuda.empty_cache()

      data_time_avg.update(data_time)
      iter_time_avg.update(iter_timer.toc(False))

      pred = get_prediction(data_loader.dataset, soutput.F, target)
      score = precision_at_one(pred, target, ignore_label=-1)
      regs.update(cur_loss.item(), target.size(0))
      losses.update(batch_loss, target.size(0))
      scores.update(score, target.size(0))

      # calc the train-iou
      for l in range(num_class):
        total_correct_class[l] += ((pred == l) & (target == l)).sum()
        total_iou_deno_class[l] += (((pred == l) & (target != -1)) |
                                    (target == l)).sum()

      if curr_iter % config.stat_freq == 0 or curr_iter == 1:
        lrs = ', '.join(
            ['{:.3e}'.format(g['lr']) for g in optimizer.param_groups])
        IoU = ((total_correct_class) /
               (total_iou_deno_class + 1e-6)).mean() * 100.
        debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
            epoch, curr_iter,
            len(data_loader) // config.iter_size, losses.avg, lrs)
        debug_str += "Score {:.3f}\tIoU {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
            scores.avg, IoU.item(), data_time_avg.avg, iter_time_avg.avg)
        if regs.avg > 0:
          debug_str += "\n Additional Reg Loss {:.3f}".format(regs.avg)
        if rank == 0:
          logging.info(debug_str)
        # Reset timers
        data_time_avg.reset()
        iter_time_avg.reset()
        # Write logs
        losses.reset()
        scores.reset()

      # only save status on the 1st gpu
      if rank == 0:
        # Save current status, save before val to prevent occational mem overflow
        if curr_iter % config.save_freq == 0:
          checkpoint(model,
                     optimizer,
                     epoch,
                     curr_iter,
                     config,
                     best_val_miou,
                     best_val_iter,
                     save_inter=True)
        # Validation
        if curr_iter % config.val_freq == 0:
          val_miou = validate(model, val_data_loader, None, curr_iter, config,
                              transform_data_fn
                              )  # feedin None for SummaryWriter args
          if val_miou > best_val_miou:
            best_val_miou = val_miou
            best_val_iter = curr_iter
            checkpoint(model,
                       optimizer,
                       epoch,
                       curr_iter,
                       config,
                       best_val_miou,
                       best_val_iter,
                       "best_val",
                       save_inter=True)
          if rank == 0:
            logging.info("Current best mIoU: {:.3f} at iter {}".format(
                best_val_miou, best_val_iter))
          # Recover back
          model.train()

      # End of iteration
      curr_iter += 1

    IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
    if rank == 0:
      logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))
    epoch += 1

  # Explicit memory cleanup
  if hasattr(data_iter, 'cleanup'):
    data_iter.cleanup()

  # Save the final model
  if rank == 0:
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    v_loss, v_score, v_mAP, val_mIoU = test(model, val_data_loader, config)
  # NOTE(review): the final test result is bound to `val_mIoU` but `val_miou`
  # is compared below — the fresh result is never used, and on ranks where no
  # validation ran `val_miou` is unbound (NameError before the rank check
  # short-circuits) — TODO confirm and unify the names.
  if val_miou > best_val_miou and rank == 0:
    best_val_miou = val_miou
    best_val_iter = curr_iter
    logging.info("Final best miou: {} at iter {} ".format(
        val_miou, curr_iter))
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
  """Standard single-GPU sparse-voxel training loop.

  Trains ``model`` with cross-entropy until ``config.max_iter``, accumulating
  gradients over ``config.iter_size`` sub-iterations, logging stats to
  TensorBoard every ``config.stat_freq`` iterations, checkpointing every
  ``config.save_freq``, and validating every ``config.val_freq`` (tracking the
  best mIoU).  A final checkpoint and validation pass run after the loop.

  Args:
    model: network to train.
    data_loader: training loader yielding ``(coords, feats, target)`` (plus
      pointcloud/transformation when ``config.return_transformation``).
    val_data_loader: validation loader passed to ``validate``.
    config: experiment configuration namespace.
    transform_data_fn: forwarded to ``validate``.
  """
  device = get_torch_device(config.is_cuda)
  # Set up the train flag for batch normalization
  model.train()

  # Configuration.  BUGFIX: the original constructed SummaryWriter twice,
  # leaking the first writer's event file; construct it once.
  writer = SummaryWriter(log_dir=config.log_dir)
  data_timer, iter_timer = Timer(), Timer()
  data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
  losses, scores = AverageMeter(), AverageMeter()

  optimizer = initialize_optimizer(model.parameters(), config)
  scheduler = initialize_scheduler(optimizer, config)
  criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

  # Train the network
  logging.info('===> Start training')
  best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

  if config.resume:
    checkpoint_fn = config.resume + '/weights.pth'
    if osp.isfile(checkpoint_fn):
      logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
      state = torch.load(checkpoint_fn)
      curr_iter = state['iteration'] + 1
      epoch = state['epoch']
      model.load_state_dict(state['state_dict'])
      if config.resume_optimizer:
        scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
        optimizer.load_state_dict(state['optimizer'])
      if 'best_val' in state:
        best_val_miou = state['best_val']
        best_val_iter = state['best_val_iter']
      logging.info("=> loaded checkpoint '{}' (epoch {})".format(
          checkpoint_fn, state['epoch']))
    else:
      raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

  data_iter = data_loader.__iter__()
  while is_training:
    for iteration in range(len(data_loader) // config.iter_size):
      optimizer.zero_grad()
      data_time, batch_loss = 0, 0
      iter_timer.tic()

      for sub_iter in range(config.iter_size):
        # Get training data
        data_timer.tic()
        # BUGFIX: Py3 iterator protocol (`.next()` is Python 2 only).
        # `features` renamed from `input` to avoid shadowing the builtin.
        if config.return_transformation:
          coords, features, target, pointcloud, transformation = next(data_iter)
        else:
          coords, features, target = next(data_iter)

        # For some networks, making the network invariant to even, odd coords
        # is important (column 0 is the batch index).
        coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

        # Preprocess input
        if config.normalize_color:
          features[:, :3] = features[:, :3] / 255. - 0.5
        sinput = SparseTensor(features, coords).to(device)
        data_time += data_timer.toc(False)

        soutput = model(sinput)
        # The output of the network is not sorted
        target = target.long().to(device)
        loss = criterion(soutput.F, target.long())

        # Compute and accumulate gradient
        loss /= config.iter_size
        batch_loss += loss.item()
        loss.backward()

      # Update number of steps
      optimizer.step()
      scheduler.step()

      data_time_avg.update(data_time)
      iter_time_avg.update(iter_timer.toc(False))

      pred = get_prediction(data_loader.dataset, soutput.F, target)
      score = precision_at_one(pred, target)
      losses.update(batch_loss, target.size(0))
      scores.update(score, target.size(0))

      if curr_iter >= config.max_iter:
        is_training = False
        break

      if curr_iter % config.stat_freq == 0 or curr_iter == 1:
        lrs = ', '.join(['{:.3e}'.format(x) for x in scheduler.get_lr()])
        debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
            epoch, curr_iter, len(data_loader) // config.iter_size,
            losses.avg, lrs)
        debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
            scores.avg, data_time_avg.avg, iter_time_avg.avg)
        logging.info(debug_str)
        # Reset timers
        data_time_avg.reset()
        iter_time_avg.reset()
        # Write logs
        writer.add_scalar('training/loss', losses.avg, curr_iter)
        writer.add_scalar('training/precision_at_1', scores.avg, curr_iter)
        writer.add_scalar('training/learning_rate', scheduler.get_lr()[0],
                          curr_iter)
        losses.reset()
        scores.reset()

      # Save current status, save before val to prevent occational mem overflow
      if curr_iter % config.save_freq == 0:
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter)

      # Validation
      if curr_iter % config.val_freq == 0:
        val_miou = validate(model, val_data_loader, writer, curr_iter, config,
                            transform_data_fn)
        if val_miou > best_val_miou:
          best_val_miou = val_miou
          best_val_iter = curr_iter
          checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                     best_val_iter, "best_val")
          logging.info("Current best mIoU: {:.3f} at iter {}".format(
              best_val_miou, best_val_iter))
        # Recover back
        model.train()

      # End of iteration
      curr_iter += 1

    epoch += 1

  # Explicit memory cleanup
  if hasattr(data_iter, 'cleanup'):
    data_iter.cleanup()

  # Save the final model
  checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
             best_val_iter)
  val_miou = validate(model, val_data_loader, writer, curr_iter, config,
                      transform_data_fn)
  if val_miou > best_val_miou:
    best_val_miou = val_miou
    best_val_iter = curr_iter
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter, "best_val")
  logging.info("Current best mIoU: {:.3f} at iter {}".format(
      best_val_miou, best_val_iter))
logger.update('elbo', -elbo.item()) perf = mcc(z.cpu().numpy(), z_est.cpu().detach().numpy()) logger.update('perf', perf) if it % args.log_freq == 0: logger.log() writer.add_scalar('data/performance', logger.get_last('perf'), it) writer.add_scalar('data/elbo', logger.get_last('elbo'), it) scheduler.step(logger.get_last('elbo')) if it % int(args.max_iter / 5) == 0 and not args.no_log: checkpoint(TORCH_CHECKPOINT_FOLDER, exp_id, it, model, optimizer, logger.get_last('elbo'), logger.get_last('perf')) eet = time.time() print('epoch {} done in: {}s;\tloss: {};\tperf: {}'.format( int(it / len(train_loader)) + 1, eet - est, logger.get_last('elbo'), logger.get_last('perf'))) et = time.time() print('training time: {}s'.format(et - ste)) writer.close() if not args.no_log: logger.add_metadata(**metadata) logger.save_to_json() logger.save_to_npz()