import datetime
import time

import torch
import torch.optim as optim

# Checkpointer, MetricLogger, TransferNetMetrics, do_eval and the global
# `device` are assumed to be provided by the surrounding project; they are not
# defined here. Two variants of do_test/do_train appear below (an episodic
# few-shot version and a transfer-learning version); they presumably come from
# separate scripts, since the second definition would otherwise shadow the first.


def do_test(cfg, model, dataloader, logger, task, load_ckpt):
    """Evaluate an episodic (few-shot) model and log accumulated top-1 accuracy."""
    if load_ckpt is not None:
        checkpointer = Checkpointer(model, save_dir=cfg.OUTPUT_DIR, logger=logger,
                                    monitor_unit='episode')
        checkpointer.load_checkpoint(load_ckpt)

    val_metrics = model.metric_evaluator
    model.eval()
    num_images = 0
    meters = MetricLogger(delimiter=" ")

    logger.info('Start testing...')
    start_testing_time = time.time()
    end = time.time()
    for iteration, data in enumerate(dataloader):
        data_time = time.time() - end
        inputs, labels = torch.cat(data['images']).to(device), data['labels'].to(device)
        logits = model(inputs)
        val_metrics.accumulated_update(logits, labels)
        num_images += logits.shape[0]

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (len(dataloader) - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        if iteration % 50 == 0 and iteration > 0:
            logger.info('eta: {}, iter: {}/{}'.format(eta_string, iteration, len(dataloader)))

    val_metrics.gather_results()
    logger.info('num of images: {}'.format(num_images))
    logger.info('{} top1 acc: {:.4f}'.format(
        task, val_metrics.accumulated_topk_corrects['top1_acc']))

    total = time.time() - start_testing_time
    total_time_str = str(datetime.timedelta(seconds=total))
    logger.info("Total testing time: {}".format(total_time_str))
    return val_metrics
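# ---------------------------------------------------------------------------
# Illustration only: the metric evaluator used above is defined elsewhere in
# the project. The standalone sketch below shows the general pattern of
# accumulating top-k accuracy over batches; the name `TopkAccumulator` and its
# interface are hypothetical, not this project's API.
# ---------------------------------------------------------------------------
class TopkAccumulator:
    def __init__(self, topk=(1,)):
        self.topk = topk
        self.correct = {k: 0 for k in topk}
        self.total = 0
        self.accumulated_topk_corrects = {}

    def accumulated_update(self, logits, labels):
        # logits: (N, C) class scores; labels: (N,) ground-truth class indices
        maxk = max(self.topk)
        _, pred = logits.topk(maxk, dim=1)          # (N, maxk) predicted class indices
        hits = pred.eq(labels.view(-1, 1))          # (N, maxk) boolean matches
        for k in self.topk:
            self.correct[k] += hits[:, :k].any(dim=1).sum().item()
        self.total += labels.size(0)

    def gather_results(self):
        # report accuracies as percentages over everything accumulated so far
        for k in self.topk:
            self.accumulated_topk_corrects['top{}_acc'.format(k)] = \
                100.0 * self.correct[k] / max(self.total, 1)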
def do_test(cfg, model, dataloader, logger, task, load_ckpt):
    checkpointer = Checkpointer(model, save_dir=cfg.OUTPUT_DIR, logger=logger)
    checkpointer.load_checkpoint(load_ckpt)

    val_metrics = TransferNetMetrics(cfg)
    model.eval()
    num_images = 0

    logger.info('Start testing...')
    for iteration, data in enumerate(dataloader):
        inputs, targets = data['image'].to(device), data['label_articleType'].to(device)
        cls_scores = model(inputs, targets)
        val_metrics.accumulated_update(cls_scores, targets)
        num_images += len(targets)

    val_metrics.gather_results()
    logger.info('num of images: {}'.format(num_images))
    logger.info('{} top1/5 acc: {:.4f}/{:.4f}, mean class acc: {:.4f}'.format(
        task,
        val_metrics.accumulated_topk_corrects['top1_acc'],
        val_metrics.accumulated_topk_corrects['top5_acc'],
        val_metrics.mean_class_accuracy))
    return val_metrics
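# ---------------------------------------------------------------------------
# Illustration only: TransferNetMetrics is defined elsewhere in the project.
# The hypothetical helper below sketches how a mean (per-)class accuracy, the
# quantity logged above, can be computed from accumulated predictions; only
# classes that actually appear in `targets` contribute to the average.
# ---------------------------------------------------------------------------
def mean_class_accuracy(preds, targets, num_classes):
    per_class_correct = torch.zeros(num_classes)
    per_class_total = torch.zeros(num_classes)
    for c in range(num_classes):
        mask = targets == c
        per_class_total[c] = mask.sum()
        per_class_correct[c] = (preds[mask] == c).sum()
    valid = per_class_total > 0                      # skip classes with no samples
    return (per_class_correct[valid] / per_class_total[valid]).mean().item()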
def do_train(cfg, model, train_dataloader, logger, load_ckpt=None):
    """Episode-based training loop (no validation); checkpoints by episode count."""
    # define optimizer
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=cfg.TRAIN.LR_BASE,
                              momentum=cfg.TRAIN.MOMENTUM,
                              weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        raise NotImplementedError

    # define learning rate scheduler
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, cfg.TRAIN.LR_DECAY)

    checkpointer = Checkpointer(model, optimizer, lr_scheduler, cfg.OUTPUT_DIR,
                                logger, monitor_unit='episode')
    training_args = {}
    # training_args['iteration'] = 1
    training_args['episode'] = 1
    if load_ckpt:
        checkpointer.load_checkpoint(load_ckpt, strict=False)
    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load()
        training_args.update(extra_checkpoint_data)
    start_episode = training_args['episode']
    episode = training_args['episode']

    meters = MetricLogger(delimiter=" ")
    end = time.time()
    start_training_time = time.time()
    model.train()

    break_while = False
    while not break_while:
        for inner_iter, data in enumerate(train_dataloader):
            training_args['episode'] = episode
            data_time = time.time() - end

            # each batch is one episode; the query labels are implied by the
            # episode layout, so only the logits are passed to the evaluators
            # targets = torch.cat(data['labels']).to(device)
            inputs = torch.cat(data['images']).to(device)
            logits = model(inputs)
            losses = model.loss_evaluator(logits)
            metrics = model.metric_evaluator(logits)
            total_loss = sum(loss for loss in losses.values())
            meters.update(loss=total_loss, **losses, **metrics)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if inner_iter % cfg.TRAIN.PRINT_PERIOD == 0:
                eta_seconds = meters.time.global_avg * (cfg.TRAIN.MAX_EPISODE - episode)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "episode: {ep}/{max_ep}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        ep=episode,
                        max_ep=cfg.TRAIN.MAX_EPISODE,
                        iter=inner_iter,
                        max_iter=len(train_dataloader),
                        meters=str(meters),
                        lr=optimizer.param_groups[-1]["lr"],
                        memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                        if torch.cuda.is_available() else 0.,
                    ))

            if episode % cfg.TRAIN.LR_DECAY_EPISODE == 0:
                # step first so the logged value is the decayed learning rate
                lr_scheduler.step()
                logger.info("lr decayed to {:.4f}".format(optimizer.param_groups[-1]["lr"]))
            if episode == cfg.TRAIN.MAX_EPISODE:
                break_while = True
                checkpointer.save("model_{:06d}".format(episode), **training_args)
                break
            if episode % cfg.TRAIN.SAVE_CKPT_EPISODE == 0:
                checkpointer.save("model_{:06d}".format(episode), **training_args)
            episode += 1

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / episode)".format(
        total_time_str,
        total_training_time / (episode - start_episode if episode > start_episode else 1)))
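# ---------------------------------------------------------------------------
# Illustration only: in the episodic loop above, the loss/metric evaluators
# receive logits alone because the query labels of an N-way episode are fully
# determined by how the episode is laid out. The hypothetical sketch below
# shows one common convention (N classes, Q query images per class, queries
# grouped by class); the actual layout used by this project may differ.
# ---------------------------------------------------------------------------
def episode_cross_entropy(logits, n_way, n_query):
    # logits: (n_way * n_query, n_way) scores for the query images of one episode
    targets = torch.arange(n_way, device=logits.device).repeat_interleave(n_query)
    return torch.nn.functional.cross_entropy(logits, targets)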
def do_train(cfg, model, train_dataloader, val_dataloader, logger, load_ckpt=None):
    """Epoch-based training with periodic validation, best-checkpoint saving and
    patience-based early stopping."""
    # define optimizer
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=cfg.TRAIN.LR_BASE,
                              momentum=cfg.TRAIN.MOMENTUM,
                              weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        raise NotImplementedError

    # define learning rate scheduler
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, cfg.TRAIN.LR_DECAY)

    checkpointer = Checkpointer(model, optimizer, lr_scheduler, cfg.OUTPUT_DIR, logger)
    training_args = {}
    # training_args['iteration'] = 1
    training_args['epoch'] = 1
    training_args['val_best'] = 0.
    if load_ckpt:
        checkpointer.load_checkpoint(load_ckpt, strict=False)
    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load()
        training_args.update(extra_checkpoint_data)
    # start_iter = training_args['iteration']
    start_epoch = training_args['epoch']
    checkpointer.current_val_best = training_args['val_best']

    meters = MetricLogger(delimiter=" ")
    end = time.time()
    start_training_time = time.time()
    for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH + 1):
        training_args['epoch'] = epoch
        model.train()
        for inner_iter, data in enumerate(train_dataloader):
            # training_args['iteration'] = iteration
            data_time = time.time() - end

            inputs, targets = data['image'].to(device), data['label_articleType'].to(device)
            cls_scores = model(inputs, targets)
            losses = model.loss_evaluator(cls_scores, targets)
            metrics = model.metric_evaluator(cls_scores, targets)
            total_loss = sum(loss for loss in losses.values())
            meters.update(loss=total_loss, **losses, **metrics)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if inner_iter % cfg.TRAIN.PRINT_PERIOD == 0:
                eta_seconds = meters.time.global_avg * (
                    len(train_dataloader) * cfg.TRAIN.MAX_EPOCH
                    - (epoch - 1) * len(train_dataloader) - inner_iter)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "epoch: {ep}/{max_ep} (iter: {iter}/{max_iter})",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        ep=epoch,
                        max_ep=cfg.TRAIN.MAX_EPOCH,
                        iter=inner_iter,
                        max_iter=len(train_dataloader),
                        meters=str(meters),
                        lr=optimizer.param_groups[-1]["lr"],
                        memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                        if torch.cuda.is_available() else 0.,
                    ))

        if epoch % cfg.TRAIN.VAL_EPOCH == 0:
            logger.info('start evaluating at epoch {}'.format(epoch))
            val_metrics = do_eval(cfg, model, val_dataloader, logger, 'validation')
            if val_metrics.mean_class_accuracy > checkpointer.current_val_best:
                # new best model: save it and reset the patience counter
                checkpointer.current_val_best = val_metrics.mean_class_accuracy
                training_args['val_best'] = checkpointer.current_val_best
                checkpointer.save("model_{:04d}_val_{:.4f}".format(
                    epoch, checkpointer.current_val_best), **training_args)
                checkpointer.patience = 0
            else:
                checkpointer.patience += 1
                logger.info('current patience: {}/{}'.format(
                    checkpointer.patience, cfg.TRAIN.PATIENCE))

        if (epoch == cfg.TRAIN.MAX_EPOCH or epoch % cfg.TRAIN.SAVE_CKPT_EPOCH == 0
                or checkpointer.patience == cfg.TRAIN.PATIENCE):
            checkpointer.save("model_{:04d}".format(epoch), **training_args)
        if checkpointer.patience == cfg.TRAIN.PATIENCE:
            logger.info('Max patience triggered. Early terminate training')
            break
        if epoch % cfg.TRAIN.LR_DECAY_EPOCH == 0:
            # step first so the logged value is the decayed learning rate
            lr_scheduler.step()
            logger.info("lr decayed to {:.4f}".format(optimizer.param_groups[-1]["lr"]))

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / epoch)".format(
        total_time_str,
        total_training_time / (epoch - start_epoch if epoch > start_epoch else 1)))
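# ---------------------------------------------------------------------------
# Illustration only: the patience-based early stopping above is tracked on the
# project's Checkpointer. The hypothetical class below isolates that logic as
# a standalone sketch, assuming a higher validation score is better.
# ---------------------------------------------------------------------------
class EarlyStopping:
    def __init__(self, max_patience):
        self.max_patience = max_patience
        self.best = float('-inf')
        self.patience = 0

    def step(self, val_score):
        """Record one validation result; return True when training should stop."""
        if val_score > self.best:
            self.best = val_score      # new best: reset the patience counter
            self.patience = 0
        else:
            self.patience += 1         # no improvement this validation round
        return self.patience >= self.max_patience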