def train(model, epoch, max_epoch, data_loader, optimizer, criterion, device,
          print_iter_period, logger, tb_writer):
    """Train the model for one epoch, logging progress and TensorBoard scalars."""
    meters = MetricLogger()
    max_iter = len(data_loader)
    model.train()
    end = time.time()
    for iteration, (images, targets) in enumerate(data_loader):
        iteration = iteration + 1
        images = images.to(device)
        targets = targets.to(device)

        # Forward pass, loss, and top-1/top-5 accuracy for this batch.
        outputs = model(images)
        loss = criterion(outputs, targets)
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        meters.update(
            size=images.size(0),
            loss=loss,
            top1=acc1,
            top5=acc5,
        )

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Per-iteration timing and ETA for the remainder of the epoch.
        batch_time = time.time() - end
        end = time.time()
        meters.update(size=1, time=batch_time)
        eta_seconds = meters.time.avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        # TensorBoard scalars are written by the main process only.
        tb_idx = epoch * max_iter + iteration
        if get_rank() == 0:
            tb_writer.add_scalar('train/loss', loss.item(), tb_idx)
            tb_writer.add_scalars('train/acc', {
                'acc1': acc1.item(),
                'acc5': acc5.item()
            }, tb_idx)

        if iteration % print_iter_period == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
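# The `accuracy` helper called in `train` is defined elsewhere in the repo.
# As a rough reference, a minimal sketch of the standard top-k accuracy
# computation it presumably follows is given below; the actual implementation
# in this codebase may differ in its return type or scaling.
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for the given values of k (sketch)."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the top-k predictions per sample, transposed to (maxk, batch).
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res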
# Evaluation-only entry point: builds the config, sets up distributed state,
# and runs the test pass (presumably the main() of the test script).
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch Classification Training.")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    if cfg.MODEL.DEVICE == "cuda" and cfg.CUDA_VISIBLE_DEVICES != "":
        os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CUDA_VISIBLE_DEVICES

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    logger = setup_logger("Classification", "", get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + get_pretty_env_info())

    acc = run_test(cfg, args.local_rank, distributed)
    save_dict_data(acc, os.path.join(cfg.OUTPUT_DIR, "acc.txt"))
    print_dict_data(acc)
# Training entry point: parses the config, sets up distributed training and
# TensorBoard, runs training, and optionally evaluates the final model.
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch Classification Training.")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    if cfg.MODEL.DEVICE == "cuda" and cfg.CUDA_VISIBLE_DEVICES != "":
        os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CUDA_VISIBLE_DEVICES

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    # Create the TensorBoard writer on the main process only.
    output_dir = cfg.OUTPUT_DIR
    tb_dir = os.path.join(output_dir, 'tb_log')
    tb_writer = None  # non-main ranks never write TensorBoard logs
    if get_rank() == 0 and output_dir:
        mkdir(output_dir)
        tb_writer = SummaryWriter(tb_dir)

    logger = setup_logger("Classification", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + get_pretty_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    output_config_path = os.path.join(cfg.OUTPUT_DIR, "config.yaml")
    logger.info("Saving config into: {}".format(output_config_path))
    save_config(cfg, output_config_path)

    model = run_train(cfg, args.local_rank, distributed, tb_writer)

    if not args.skip_test:
        acc = run_test(cfg, args.local_rank, distributed, model)
        save_dict_data(acc, os.path.join(cfg.OUTPUT_DIR, "acc.txt"))
        print_dict_data(acc)
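# Hedged usage note: the `--local_rank` flag together with the WORLD_SIZE
# environment variable and init_method="env://" matches the convention of the
# (legacy) `torch.distributed.launch` helper, so multi-GPU training is
# presumably started along these lines (script and config paths below are
# placeholders, not names taken from this repo):
#
#   python -m torch.distributed.launch --nproc_per_node=4 \
#       train.py --config-file configs/my_config.yaml
#
# A single-GPU run can invoke the script directly; WORLD_SIZE is then unset
# and the distributed branch above is skipped.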
def run_train(
    cfg,
    local_rank,
    distributed,
    tb_writer,
):
    logger = logging.getLogger("Classification.trainer")

    # Build the model, optimizer, and loss on the target device.
    model = build_classification_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    criterion = make_criterion(cfg, device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=False,
        )

    # Resume from a checkpoint if one is configured.
    checkpoint = load_checkpoint_from_cfg(cfg, model, optimizer)
    start_epoch = checkpoint[
        "epoch"] if checkpoint is not None and "epoch" in checkpoint else 0

    save_to_disk = get_rank() == 0
    max_epoch = cfg.SOLVER.MAX_EPOCH
    save_epoch_period = cfg.SAVE_EPOCH_PERIOD
    test_epoch_period = cfg.TEST_EPOCH_PERIOD
    print_iter_period = cfg.PRINT_ITER_PERIOD

    train_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
    )
    if test_epoch_period > 0:
        val_loader = make_data_loader(
            cfg,
            is_train=False,
            is_distributed=distributed,
        )
    else:
        val_loader = None

    time_meter = AverageMeter("epoch_time")
    start_training_time = time.time()
    end = time.time()

    logger.info("Start training")
    for epoch in range(start_epoch, max_epoch):
        logger.info("Epoch {}".format(epoch + 1))
        adjust_learning_rate(cfg, optimizer, epoch)
        train(
            model,
            epoch,
            max_epoch,
            train_loader,
            optimizer,
            criterion,
            device,
            print_iter_period,
            logger,
            tb_writer,
        )

        # Periodically save a checkpoint from the main process.
        if save_to_disk and ((epoch + 1) % save_epoch_period == 0 or
                             (epoch + 1) == max_epoch):
            state = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            is_final = (epoch + 1) == max_epoch
            save_checkpoint(state, cfg.OUTPUT_DIR, epoch + 1, is_final)

        # Periodically evaluate on the validation set.
        if val_loader is not None and (epoch + 1) % test_epoch_period == 0:
            acc = inference(model, val_loader, device)
            if acc is not None:
                logger.info("Top1 accuracy: {}. Top5 accuracy: {}.".format(
                    acc["top1"], acc["top5"]))
                tb_writer.add_scalar('Test Accuracy', acc["top1"], epoch + 1)

        # Per-epoch timing and ETA for the remaining epochs.
        epoch_time = time.time() - end
        end = time.time()
        time_meter.update(epoch_time)
        eta_seconds = time_meter.avg * (max_epoch - epoch - 1)
        epoch_string = str(datetime.timedelta(seconds=int(epoch_time)))
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        logger.info("Epoch time: {}. Eta: {}.\n".format(
            epoch_string, eta_string))
        synchronize()

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / epoch)".format(
        total_time_str, total_training_time / max_epoch))
    return model
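# `AverageMeter` (used above to time epochs) is defined elsewhere in the repo.
# A minimal sketch consistent with how run_train uses it (constructed with a
# name, updated with a scalar, read via `.avg`) follows the common PyTorch
# recipe; the actual class in this codebase may carry extra functionality.
class AverageMeter:
    """Track the latest value, sum, count, and running average of a scalar (sketch)."""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # Record `val` observed `n` times and refresh the running average.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count > 0 else 0.0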