def checkpoint(
        train_ctx: Context,
        save_step: Optional[int] = 20,
) -> None:
    """Save training checkpoints.

    Writes 'model_<epoch>.pt' every `save_step` epochs, 'model_latest.pt' on
    every call, and 'model_<metric>.pt' whenever the current context is marked
    as the best so far (`train_ctx.is_best`).
    """

    save_dir = os.path.abspath(train_ctx.save_dir)
    try:
        os.makedirs(save_dir)
    except OSError as exception:
        if exception.errno != errno.EEXIST:
            raise

    # helper function: serialize the current training context
    def save_current_train_ctx(save_name):
        save_path = os.path.join(save_dir, save_name)
        torch.save(
            dict(epoch_idx=train_ctx.epoch_idx + 1,
                 batch_idx=train_ctx.batch_idx + 1,
                 model=train_ctx.model.state_dict(),
                 optimizer=train_ctx.optimizer.state_dict()),
            save_path)
        train_ctx.logger.info('checkpoint created at %s' % save_path)

    # checkpoint conditions
    if save_step > 0 and (train_ctx.epoch_idx + 1) % save_step == 0:
        save_current_train_ctx('model_%d.pt' % train_ctx.epoch_idx)
    save_current_train_ctx('model_latest.pt')
    if train_ctx.is_best:
        save_current_train_ctx('model_%s.pt' % train_ctx.eva_metrics)
        train_ctx.is_best = False
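# Usage note (illustrative sketch only, not part of this module's API): `checkpoint` is
# meant to be registered as a hook of `network_trainer` below, e.g. under the
# 'save_checkpoints' hook type that the trainer fires at the end of each epoch in the
# 'joint_train' path. The wiring and the output directory below are assumptions:
#
#     hooks = {'save_checkpoints': [lambda ctx: checkpoint(ctx, save_step=20)]}
#     network_trainer(data_loaders=..., model=..., criterion1=..., optimizer=...,
#                     hooks=hooks, path_history='./runs/exp1')
#
# With save_step=20 the trainer would then write 'model_<epoch>.pt' every 20 epochs,
# 'model_latest.pt' every epoch, and 'model_<metric>.pt' whenever ctx.is_best is set.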
def network_trainer(data_loaders: Tuple[List[Tuple[str, data.DataLoader]],
                                         List[Tuple[str, data.DataLoader]]],
                    model: nn.Module,
                    criterion1: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor],
                    optimizer: Callable[[Iterable], optim.Optimizer],
                    criterion2: Optional[Callable[
                        [Tensor, Tensor, Tensor, Tensor], Tensor]] = None,
                    parameter: Optional[Callable] = None,
                    meters: Optional[Dict[str, Callable[[Context], Any]]] = None,
                    hooks: Optional[Dict[str, List[Callable[[Context], None]]]] = None,
                    train_method: Optional[str] = 'sep',
                    path_history: Optional[str] = './',
                    resume: Optional[str] = None,
                    log_path: Optional[str] = None,
                    max_epoch: int = 200,
                    test_interval: int = 1,
                    device: str = 'cuda',
                    use_data_parallel: bool = True,
                    use_cudnn_benchmark: bool = True,
                    epoch_stage1: int = 30,
                    random_seed: int = 999) -> Context:
    """Network trainer.

    data_loaders: [train set, optional[validation set]]
    criterion1: loss for stage 1
    criterion2: loss for stage 2
    train_method: 'sep' trains stage 1 and stage 2 separately;
        'joint_train' trains both stages jointly, i.e. in an end-to-end manner
    epoch_stage1: number of epochs in the first stage;
        we use 20 for PASCAL and 30 for COCO
    """

    torch.manual_seed(random_seed)

    # setup training logger
    logger = logging.getLogger('nest.network_trainer')
    logger.handlers = []
    logger.setLevel(logging.DEBUG)
    # log to screen
    screen_handler = TqdmHandler()
    screen_handler.setFormatter(logging.Formatter('[%(asctime)s] %(message)s'))
    logger.addHandler(screen_handler)
    # log to file
    if log_path is not None:
        # create directory first
        try:
            os.makedirs(os.path.dirname(log_path))
        except OSError as exception:
            if exception.errno != errno.EEXIST:
                raise
        file_handler = logging.FileHandler(log_path, encoding='utf8')
        file_handler.setFormatter(
            logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s'))
        logger.addHandler(file_handler)

    # determine which progress bar to use
    def run_in_notebook():
        try:
            return get_ipython().__class__.__name__.startswith('ZMQ')
        except NameError:
            pass
        return False

    progress_bar = tqdm_notebook if run_in_notebook() else tqdm

    # setup device
    device = torch.device(device)
    if device.type == 'cuda':
        assert torch.cuda.is_available(), 'CUDA is not available.'
        torch.backends.cudnn.benchmark = use_cudnn_benchmark

    # loaders for train and test splits
    train_loaders, test_loaders = data_loaders

    # setup model
    model = model.to(device)
    # multi-gpu support
    if device.type == 'cuda' and use_data_parallel:
        model = nn.DataParallel(model)

    # setup optimizer
    params = model.parameters() if parameter is None else parameter(model)
    optimizer = optimizer(params)

    # resume checkpoint
    start_epoch_idx = 0
    start_batch_idx = 0
    if resume is not None:
        logger.info('loading checkpoint "%s"' % resume)
        checkpoint = torch.load(resume)
        start_epoch_idx = checkpoint['epoch_idx']
        start_batch_idx = checkpoint['batch_idx']
        model_dict = model.state_dict()
        trained_dict = {
            k: v
            for k, v in checkpoint['model'].items() if k in model_dict
        }
        model_dict.update(trained_dict)
        model.load_state_dict(model_dict)
        # optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info('checkpoint loaded (epoch %d)' % start_epoch_idx)

    # create training context
    ctx = Context(
        split='train',
        is_train=True,
        model=model,
        optimizer=optimizer,
        max_epoch=max_epoch,
        epoch_idx=start_epoch_idx,
        batch_idx=start_batch_idx,
        input=Tensor(),
        output=Tensor(),
        output2=Tensor(),
        output3=Tensor(),
        target=Tensor(),
        target1=Tensor(),
        loss=Tensor(),
        metrics=dict(),
        state_dicts=[],
        eva_metrics='',
        save_dir=path_history,
        is_best=False,
        logger=logger)

    # helper func for executing hooks
    def run_hooks(hook_type):
        if isinstance(hooks, dict) and hook_type in hooks:
            for hook in hooks.get(hook_type):
                hook(ctx)

    # helper func for processing a dataset split (single-stage training)
    def process(split, data_loader, is_train):
        ctx.max_batch = len(data_loader)
        ctx.split = split
        ctx.is_train = is_train
        run_hooks('on_start_split')

        # set model status
        if is_train:
            model.train()
            gc.collect()
        else:
            model.eval()

        # iterate over batches
        for batch_idx, (input, target, _) in enumerate(
                progress_bar(data_loader,
                             ascii=True,
                             desc=split,
                             unit='batch',
                             leave=False)):
            if batch_idx < ctx.batch_idx:
                continue
            # prepare a batch of data
            ctx.batch_idx = batch_idx
            if isinstance(input, (list, tuple)):
                ctx.input = [
                    v.to(device) if torch.is_tensor(v) else v for v in input
                ]
            elif isinstance(input, dict):
                ctx.input = {
                    k: v.to(device) if torch.is_tensor(v) else v
                    for k, v in input.items()
                }
            else:
                ctx.input = input.to(device)
            ctx.target = target.to(device)
            run_hooks('on_start_batch')

            # compute output and loss (criterion1 is the stage-1 loss)
            with torch.set_grad_enabled(is_train):
                ctx.output, ctx.output2, ctx.output3 = ctx.model(
                    ctx.input, ctx.target)
                ctx.loss = criterion1(ctx.output, ctx.target, ctx.output2,
                                      ctx.output3)

            # measure performance
            if meters is not None:
                ctx.metrics.update({
                    split + '_' + k: v(ctx)
                    for k, v in meters.items() if v is not None
                })

            # update model parameters
            if is_train:
                optimizer.zero_grad()
                ctx.loss.backward()
                optimizer.step()
            run_hooks('on_end_batch')

        ctx.batch_idx = 0
        run_hooks('on_end_split')

    def mrmse(non_zero, count_pred, count_gt):
        # compute mRMSE, optionally restricted to classes with non-zero GT counts
        nzero_mask = torch.ones(count_gt.size())
        if non_zero == 1:
            nzero_mask = torch.zeros(count_gt.size())
            nzero_mask[count_gt != 0] = 1
        mrmse = torch.pow(count_pred - count_gt, 2)
        mrmse = torch.mul(mrmse, nzero_mask)
        mrmse = torch.sum(mrmse, 0)
        nzero = torch.sum(nzero_mask, 0)
        mrmse = torch.div(mrmse, nzero)
        mrmse = torch.sqrt(mrmse)
        mrmse = torch.mean(mrmse)
        return mrmse

    def rel_mrmse(non_zero, count_pred, count_gt):
        # compute relative mRMSE
        nzero_mask = torch.ones(count_gt.size())
        if non_zero == 1:
            nzero_mask = torch.zeros(count_gt.size())
            nzero_mask[count_gt != 0] = 1
        num = torch.pow(count_pred - count_gt, 2)
        denom = count_gt.clone()
        denom = denom + 1
        rel_mrmse = torch.div(num, denom)
        rel_mrmse = torch.mul(rel_mrmse, nzero_mask)
        rel_mrmse = torch.sum(rel_mrmse, 0)
        nzero = torch.sum(nzero_mask, 0)
        rel_mrmse = torch.div(rel_mrmse, nzero)
        rel_mrmse = torch.sqrt(rel_mrmse)
        rel_mrmse = torch.mean(rel_mrmse)
        return rel_mrmse

    # helper func for processing a dataset split when training the two stages jointly
    def process2(split, data_loader, is_train, criterion):
        ctx.max_batch = len(data_loader)
        ctx.split = split
        ctx.is_train = is_train
        run_hooks('on_start_split')

        # set model status
        if is_train:
            model.train()
            gc.collect()
        else:
            model.eval()

        counting_pred = []
        counting_gt = []
        # iterate over batches
        for batch_idx, (input, target, target1) in enumerate(
                progress_bar(data_loader,
                             ascii=True,
                             desc=split,
                             unit='batch',
                             leave=False)):
            if batch_idx < ctx.batch_idx:
                continue
            # prepare a batch of data
            ctx.batch_idx = batch_idx
            if isinstance(input, (list, tuple)):
                ctx.input = [
                    v.to(device) if torch.is_tensor(v) else v for v in input
                ]
            elif isinstance(input, dict):
                ctx.input = {
                    k: v.to(device) if torch.is_tensor(v) else v
                    for k, v in input.items()
                }
            else:
                ctx.input = input.to(device)
            ctx.target = target.to(device)
            ctx.target1 = target1.to(device)
            run_hooks('on_start_batch')

            # compute output and loss
            with torch.set_grad_enabled(is_train):
                ctx.output, ctx.output2, ctx.output3 = ctx.model(
                    ctx.input, ctx.target)
                if is_train:
                    ctx.loss = criterion(ctx.output, ctx.target, ctx.output2,
                                         ctx.output3)

            # measure performance
            if meters is not None:
                ctx.metrics.update({
                    split + '_' + k: v(ctx)
                    for k, v in meters.items() if v is not None
                })

            # update model parameters if training; otherwise accumulate counting predictions
            if is_train:
                optimizer.zero_grad()
                ctx.loss.backward()
                optimizer.step()
            else:
                confidence = ctx.output
                class_response_map1 = ctx.output2
                confidence = confidence.cpu().detach().numpy()
                # per-class count read off the class response map via global average pooling
                count_one = F.adaptive_avg_pool2d(
                    class_response_map1,
                    1).squeeze(2).squeeze(2).detach().cpu().numpy()[0]
                # binarise the confidence and mask the per-class counts
                confidence[confidence < 0] = 0
                confidence = confidence[0]
                confidence[confidence > 0] = 1
                counting_pred.append(np.round(confidence * count_one))
                counting_gt.append(target.detach().cpu().numpy()[0])
            run_hooks('on_end_batch')

        ctx.batch_idx = 0
        run_hooks('on_end_split')

        if not is_train:
            counting_pred = np.array(counting_pred)
            counting_gt = np.array(counting_gt)
            return [
                mrmse(1, torch.from_numpy(counting_pred).float(),
                      torch.from_numpy(counting_gt).float()),
                rel_mrmse(1, torch.from_numpy(counting_pred).float(),
                          torch.from_numpy(counting_gt).float()),
                mrmse(0, torch.from_numpy(counting_pred).float(),
                      torch.from_numpy(counting_gt).float()),
                rel_mrmse(0, torch.from_numpy(counting_pred).float(),
                          torch.from_numpy(counting_gt).float())
            ]
        return None

    # trainer processing
    run_hooks('on_start')
    ori_lr = []
    for param_group in optimizer.param_groups:
        ori_lr.append(param_group['lr'])
    history = {'best_val_epoch': [-1] * 4, 'best_val_result': [np.inf] * 4}
    eva_metrics_list = ['mrmse_nz', 'rmrmse_nz', 'mrmse', 'rmrmse']

    for epoch_idx in progress_bar(range(ctx.epoch_idx, max_epoch),
                                  ascii=True,
                                  unit='epoch'):
        ctx.epoch_idx = epoch_idx
        run_hooks('on_start_epoch')
        adjust_learning_rate(optimizer, epoch_idx, ori_lr)
        for param_group in optimizer.param_groups:
            print('learning rate:', param_group['lr'])

        if train_method == 'sep':
            for split, loader in train_loaders:
                process(split, loader, True)
            # testing
            if epoch_idx % test_interval == 0:
                for split, loader in test_loaders:
                    process(split, loader, False)
            run_hooks('on_end_epoch')

        elif train_method == 'joint_train':
            assert criterion2 is not None, 'criterion2 not provided'
            # training is divided into two stages: stage 1 uses criterion1, stage 2 uses criterion2
            if epoch_idx <= epoch_stage1 - 1:
                print('stage 1 of the training: using criterion1')
                for split, loader in train_loaders:
                    process2(split, loader, True, criterion1)
            else:
                print('stage 2 of the training: using criterion2')
                for split, loader in train_loaders:
                    process2(split, loader, True, criterion2)

            # validation
            if len(test_loaders) == 0:
                print('no validation during training')
            else:
                print('validation start')
                for split, loader in test_loaders:
                    if epoch_idx <= epoch_stage1 - 1:
                        results = process2(split, loader, False, criterion1)
                    else:
                        results = process2(split, loader, False, criterion2)
                    print('mrmse_nz: %f, rmrmse_nz: %f, mrmse: %f, rmrmse: %f' %
                          (results[0], results[1], results[2], results[3]))
                    # track the best validation result for each counting metric
                    for i in range(len(results)):
                        if history['best_val_result'][i] > results[i]:
                            history['best_val_epoch'][i] = epoch_idx
                            history['best_val_result'][i] = float(
                                results[i].cpu().numpy())
                            ctx.eva_metrics = eva_metrics_list[i]
                            ctx.is_best = True

            run_hooks('on_end_epoch')
            run_hooks('save_checkpoints')
            save_json(path_history + '/history.json', history)
            print()
            print('--------------------------------------------------------',
                  end='\n\n')

    run_hooks('on_end')
    return ctx
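# Optional, self-contained sketch (not part of the trainer itself) illustrating, on random
# toy tensors, the counting post-processing used in the validation branch of `process2`:
# per-class counts are read off the class response map by global average pooling, masked
# by a binarised confidence vector, and rounded. All shapes below are illustrative
# assumptions, not values taken from the project.
if __name__ == '__main__':
    import numpy as np
    import torch
    import torch.nn.functional as F

    batch, num_classes, h, w = 1, 5, 7, 7
    confidence = torch.randn(batch, num_classes)                # stand-in for ctx.output
    class_response_map = torch.rand(batch, num_classes, h, w)   # stand-in for ctx.output2

    # per-class count estimate from the response map (global average pooling)
    count_one = F.adaptive_avg_pool2d(class_response_map,
                                      1).squeeze(2).squeeze(2).numpy()[0]
    # binarise the confidence: negative scores -> 0, positive scores -> 1
    conf = confidence.numpy()
    conf[conf < 0] = 0
    conf = conf[0]
    conf[conf > 0] = 1
    counting_pred = np.round(conf * count_one)
    print('predicted per-class counts:', counting_pred)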