# Example 1
def train(model, epoch, max_epoch, data_loader, optimizer, criterion, device,
          print_iter_period, logger, tb_writer):
    """Run one training epoch over ``data_loader``.

    Updates the model with one optimizer step per batch, accumulates
    loss/top-1/top-5 metrics, writes TensorBoard scalars on rank 0 only,
    and emits a progress line every ``print_iter_period`` iterations
    (and on the final iteration).
    """
    meters = MetricLogger()
    total_iters = len(data_loader)
    model.train()
    tic = time.time()

    # 1-based iteration counter, matching the logging convention.
    for step, (images, targets) in enumerate(data_loader, start=1):
        images = images.to(device)
        targets = targets.to(device)

        outputs = model(images)
        loss = criterion(outputs, targets)
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))

        # Weight per-batch metrics by the batch size.
        meters.update(size=images.size(0), loss=loss, top1=acc1, top5=acc5)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        now = time.time()
        meters.update(size=1, time=now - tic)
        tic = now

        # ETA from the running average batch time.
        eta_seconds = meters.time.avg * (total_iters - step)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        tb_idx = epoch * total_iters + step
        if get_rank() == 0:
            tb_writer.add_scalar('train/loss', loss.item(), tb_idx)
            tb_writer.add_scalars('train/acc', {
                'acc1': acc1.item(),
                'acc5': acc5.item()
            }, tb_idx)

        if step % print_iter_period == 0 or step == total_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=step,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
# Example 2
def main():
    """Entry point: parse args, load config, set up (optionally
    distributed) environment, run evaluation, and save/print accuracy."""
    parser = argparse.ArgumentParser(
        description="PyTorch Classification Training.")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    # Injected by torch.distributed.launch to identify the local GPU.
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # BUG FIX: `is not ""` compared identity, not value (and is a
    # SyntaxWarning on modern CPython); use value comparison instead.
    if cfg.MODEL.DEVICE == "cuda" and cfg.CUDA_VISIBLE_DEVICES != "":
        os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CUDA_VISIBLE_DEVICES

    # WORLD_SIZE is set by the distributed launcher; absent -> single GPU.
    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    logger = setup_logger("Classification", "", get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + get_pretty_env_info())

    acc = run_test(cfg, args.local_rank, distributed)
    save_dict_data(acc, os.path.join(cfg.OUTPUT_DIR, "acc.txt"))
    print_dict_data(acc)
# Example 3
def main():
    """Entry point: parse args, load config, set up (optionally
    distributed) environment, train, and optionally test the model."""
    parser = argparse.ArgumentParser(
        description="PyTorch Classification Training.")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    # Injected by torch.distributed.launch to identify the local GPU.
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # BUG FIX: `is not ""` compared identity, not value (and is a
    # SyntaxWarning on modern CPython); use value comparison instead.
    if cfg.MODEL.DEVICE == "cuda" and cfg.CUDA_VISIBLE_DEVICES != "":
        os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CUDA_VISIBLE_DEVICES

    # WORLD_SIZE is set by the distributed launcher; absent -> single GPU.
    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    # Create the TensorBoard writer on rank 0 only.
    # BUG FIX: tb_writer was previously unbound on non-zero ranks, raising
    # UnboundLocalError when passed to run_train; default it to None.
    output_dir = cfg.OUTPUT_DIR
    tb_dir = os.path.join(output_dir, 'tb_log')
    tb_writer = None
    if get_rank() == 0 and output_dir:
        mkdir(output_dir)
        tb_writer = SummaryWriter(tb_dir)

    logger = setup_logger("Classification", output_dir, get_rank())

    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + get_pretty_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    output_config_path = os.path.join(cfg.OUTPUT_DIR, "config.yaml")
    logger.info("Saving config into: {}".format(output_config_path))
    save_config(cfg, output_config_path)

    model = run_train(cfg, args.local_rank, distributed, tb_writer)

    if not args.skip_test:
        acc = run_test(cfg, args.local_rank, distributed, model)
        save_dict_data(acc, os.path.join(cfg.OUTPUT_DIR, "acc.txt"))
        print_dict_data(acc)
# Example 4
def run_train(
    cfg,
    local_rank,
    distributed,
    tb_writer,
):
    """Build model/optimizer/criterion, optionally wrap in DDP, resume
    from a checkpoint if available, and run the training loop.

    Args:
        cfg: frozen config node with SOLVER/MODEL/OUTPUT_DIR etc.
        local_rank: this process's GPU index (distributed launch).
        distributed: whether multi-process training is active.
        tb_writer: TensorBoard SummaryWriter; only used on rank 0.

    Returns:
        The trained model (possibly DDP-wrapped).
    """
    logger = logging.getLogger("Classification.trainer")

    model = build_classification_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    criterion = make_criterion(cfg, device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=False,
        )

    checkpoint = load_checkpoint_from_cfg(cfg, model, optimizer)

    # Resume from the checkpointed epoch when present; otherwise start fresh.
    start_epoch = checkpoint[
        "epoch"] if checkpoint is not None and "epoch" in checkpoint else 0

    # Only rank 0 writes checkpoints and TensorBoard scalars.
    save_to_disk = get_rank() == 0

    max_epoch = cfg.SOLVER.MAX_EPOCH
    save_epoch_period = cfg.SAVE_EPOCH_PERIOD
    test_epoch_period = cfg.TEST_EPOCH_PERIOD  # fixed "peroid" typo
    print_iter_period = cfg.PRINT_ITER_PERIOD

    train_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
    )

    # Validation loader is only needed when periodic testing is enabled.
    if test_epoch_period > 0:
        val_loader = make_data_loader(
            cfg,
            is_train=False,
            is_distributed=distributed,
        )
    else:
        val_loader = None

    time_meter = AverageMeter("epoch_time")
    start_training_time = time.time()
    end = time.time()

    logger.info("Start training")
    for epoch in range(start_epoch, max_epoch):
        logger.info("Epoch {}".format(epoch + 1))

        adjust_learning_rate(cfg, optimizer, epoch)

        train(
            model,
            epoch,
            max_epoch,
            train_loader,
            optimizer,
            criterion,
            device,
            print_iter_period,
            logger,
            tb_writer,
        )

        if save_to_disk and ((epoch + 1) % save_epoch_period == 0 or
                             (epoch + 1) == max_epoch):

            state = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            is_final = (epoch + 1) == max_epoch
            save_checkpoint(state, cfg.OUTPUT_DIR, epoch + 1, is_final)

        if val_loader is not None and (epoch + 1) % test_epoch_period == 0:
            acc = inference(model, val_loader, device)

            if acc is not None:
                logger.info("Top1 accuracy: {}. Top5 accuracy: {}.".format(
                    acc["top1"], acc["top5"]))

                # BUG FIX: previously written on every rank, but the writer
                # may only exist on rank 0 (see train()'s rank-0 guard).
                if save_to_disk:
                    tb_writer.add_scalar('Test Accuracy', acc["top1"],
                                         epoch + 1)

        epoch_time = time.time() - end
        end = time.time()

        time_meter.update(epoch_time)
        eta_seconds = time_meter.avg * (max_epoch - epoch - 1)
        epoch_string = str(datetime.timedelta(seconds=int(epoch_time)))
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        logger.info("Epoch time-consuming: {}. Eta: {}.\n".format(
            epoch_string, eta_string))

    synchronize()
    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    # BUG FIX: average over epochs actually trained in this run, not
    # max_epoch, so resumed runs report a correct s/epoch figure.
    trained_epochs = max(max_epoch - start_epoch, 1)
    logger.info("Total training time: {} ({:.4f} s / epoch)".format(
        total_time_str, total_training_time / trained_epochs))

    return model