Example #1
# Torch / standard-library imports used by this snippet; the project-specific
# helpers (make_config, make_model, resume_from_snapshot, test, logging,
# log_info) are assumed to be imported from the project's own modules.
import torch
import torch.distributed as distributed
from glob import glob


def main(args):
    global min_loss

    # Initialize multi-processing
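    # init_method='env://' expects MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE
    # in the environment, as set by a launcher such as torch.distributed.launch
    # (which also supplies the --local_rank argument used below).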
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id, device = args.local_rank, torch.device(args.local_rank)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Load configuration
    config = make_config(args)

    # Initialize logging
    if rank == 0:
        logging.init(args.resume, "testing")
    summary = None  # no TensorBoard writer is used during testing

    # Load model
    model, _, output_dim = make_model(args, config)

    # Init GPU stuff
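    # cudnn.benchmark lets cuDNN auto-tune convolution algorithms, which speeds
    # things up when input sizes stay constant between iterations.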
    torch.backends.cudnn.benchmark = config["general"].getboolean(
        "cudnn_benchmark")
    model = model.cuda(device)

    # Load snapshots to evaluate
    log_info("Loading snapshots from %s", args.resume)

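    # Collect every saved snapshot (model*.pth.tar) in reverse lexicographic
    # order and evaluate each one in turn.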
    models = sorted(glob(args.resume + '/model*.pth.tar'), reverse=True)

    print(models)
    for model_i in models:

        snapshot = resume_from_snapshot(model, model_i, ["body", "ret_head"])

        epoch = snapshot["training_meta"]["epoch"]

        log_info("Evaluation epoch %d", epoch)

        test(args,
             config,
             model,
             rank=rank,
             world_size=world_size,
             output_dim=output_dim,
             device=device)

    log_info("Evaluation Done ..... ")
Example #2
# Torch / standard-library imports used by this snippet; the project-specific
# helpers (make_config, make_dir, make_dataloader, make_model, make_optimizer,
# resume_from_snapshot, pre_train_from_snapshots, save_snapshot, train,
# validate, test, AverageMeter, logging, log_info, log_debug) are assumed to be
# imported from the project's own modules.
import shutil
import sys
from os import path

import torch
import torch.distributed as distributed
from torch.nn.parallel import DistributedDataParallel
from torch.utils import tensorboard


def main(args):

    # Initialize multi-processing
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id, device = args.local_rank, torch.device(args.local_rank)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Load configuration
    config = make_config(args)

    # Experiment Path
    exp_dir = make_dir(config, args.directory)

    # Initialize logging
    if rank == 0:
        logging.init(exp_dir, "training" if not args.eval else "eval")
        summary = tensorboard.SummaryWriter(args.directory)
    else:
        summary = None

    body_config = config["body"]
    optimizer_config = config["optimizer"]

    # Load data
    train_dataloader, val_dataloader = make_dataloader(args, config, rank,
                                                       world_size)

    # Initialize model
    if body_config.getboolean("pretrained"):
        log_debug("Use pre-trained model %s", body_config.get("arch"))
    else:
        log_debug("Initialize model to train from scratch %s".body_config.get(
            "arch"))

    # Load model
    model, output_dim = make_model(args, config)
    print(model)

    # Resume / Pre_Train
    if args.resume:
        assert not args.pre_train, "resume and pre_train are mutually exclusive"
        log_debug("Loading snapshot from %s", args.resume)
        snapshot = resume_from_snapshot(
            model, args.resume,
            ["body", "local_head_coarse", "local_head_fine"])
    elif args.pre_train:
        assert not args.resume, "resume and pre_train are mutually exclusive"
        log_debug("Loading pre-trained model from %s", args.pre_train)
        pre_train_from_snapshots(
            model, args.pre_train,
            ["body", "local_head_coarse", "local_head_fine"])
    else:
        #assert not args.eval, "--resume is needed in eval mode"
        snapshot = None

    # Init GPU stuff
    torch.backends.cudnn.benchmark = config["general"].getboolean(
        "cudnn_benchmark")
    model = DistributedDataParallel(model.cuda(device),
                                    device_ids=[device_id],
                                    output_device=device_id,
                                    find_unused_parameters=True)

    # Create optimizer & scheduler
    optimizer, scheduler, parameters, batch_update, total_epochs = make_optimizer(
        model, config, epoch_length=len(train_dataloader))
    if args.resume:
        optimizer.load_state_dict(snapshot["state_dict"]["optimizer"])

    # Training loop
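    # The running meters use an exponential moving average whose effective
    # window is roughly one epoch (momentum = 1 - 1 / len(train_dataloader)).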
    momentum = 1. - 1. / len(train_dataloader)
    meters = {
        "loss": AverageMeter((), momentum),
        "epipolar_loss": AverageMeter((), momentum),
        "consistency_loss": AverageMeter((), momentum),
    }

    if args.resume:
        start_epoch = snapshot["training_meta"]["epoch"] + 1
        best_score = snapshot["training_meta"]["best_score"]
        global_step = snapshot["training_meta"]["global_step"]

        for name, meter in meters.items():
            meter.load_state_dict(snapshot["state_dict"][name + "_meter"])
        del snapshot
    else:
        start_epoch = 0
        best_score = {
            "val": 1000.0,
            "test": 0.0,
        }
        global_step = 0

    # Optional: evaluation only:
    if args.eval:
        log_info("Evaluation epoch %d", start_epoch - 1)

        test(args,
             config,
             model,
             rank=rank,
             world_size=world_size,
             output_dim=output_dim,
             device=device)

        log_info("Evaluation Done ..... ")

        sys.exit(0)

    for epoch in range(start_epoch, total_epochs):

        log_info("Starting epoch %d", epoch + 1)

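        # When the scheduler updates per batch (batch_update), it is stepped
        # inside train(); otherwise it is stepped once per epoch here.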
        if not batch_update:
            scheduler.step(epoch)

        score = {}

        # Run training
        global_step = train(
            model,
            config,
            train_dataloader,
            optimizer,
            scheduler,
            meters,
            summary=summary,
            batch_update=batch_update,
            log_interval=config["general"].getint("log_interval"),
            epoch=epoch,
            num_epochs=total_epochs,
            global_step=global_step,
            output_dim=output_dim,
            world_size=world_size,
            rank=rank,
            device=device,
            loss_weights=optimizer_config.getstruct("loss_weights"))

        # Save snapshot (only on rank 0)
        if rank == 0:
            snapshot_file = path.join(exp_dir,
                                      "model_{}.pth.tar".format(epoch))

            log_debug("Saving snapshot to %s", snapshot_file)

            meters_out_dict = {
                k + "_meter": v.state_dict()
                for k, v in meters.items()
            }

            save_snapshot(
                snapshot_file,
                config,
                epoch,
                0,
                best_score,
                global_step,
                body=model.module.body.state_dict(),
                local_head_coarse=model.module.local_head_coarse.state_dict(),
                local_head_fine=model.module.local_head_fine.state_dict(),
                optimizer=optimizer.state_dict(),
                **meters_out_dict)

        # Run validation
        if (epoch + 1) % config["general"].getint("val_interval") == 0:
            log_info("Validating epoch %d", epoch + 1)

            score['val'] = validate(
                model,
                config,
                val_dataloader,
                summary=summary,
                batch_update=batch_update,
                log_interval=config["general"].getint("log_interval"),
                epoch=epoch,
                num_epochs=total_epochs,
                global_step=global_step,
                output_dim=output_dim,
                world_size=world_size,
                rank=rank,
                device=device,
                loss_weights=optimizer_config.getstruct("loss_weights"))

        # Run Test
        if (epoch + 1) % config["general"].getint("test_interval") == 0:
            log_info("Testing epoch %d", epoch + 1)

            score['test'] = test(args,
                                 config,
                                 model,
                                 rank=rank,
                                 world_size=world_size,
                                 output_dim=output_dim,
                                 device=device)

            # Update the score on the last saved snapshot
            if rank == 0:
                snapshot = torch.load(snapshot_file, map_location="cpu")
                snapshot["training_meta"]["last_score"] = score
                torch.save(snapshot, snapshot_file)
                del snapshot

            if score['test'] > best_score['test']:
                best_score = score
                if rank == 0:
                    shutil.copy(snapshot_file,
                                path.join(exp_dir, "test_model_best.pth.tar"))