Example #1
def main(args):

    # Initialize multi-processing
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id, device = args.local_rank, torch.device(args.local_rank)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Load configuration
    config = make_config(args)

    # Experiment Path
    exp_dir = make_dir(config, args.directory)

    # Initialize logging
    if rank == 0:
        logging.init(exp_dir, "training" if not args.eval else "eval")
        summary = tensorboard.SummaryWriter(args.directory)
    else:
        summary = None

    body_config = config["body"]
    optimizer_config = config["optimizer"]

    # Load data
    train_dataloader, val_dataloader = make_dataloader(args, config, rank,
                                                       world_size)

    # Initialize model
    if body_config.getboolean("pretrained"):
        log_debug("Use pre-trained model %s", body_config.get("arch"))
    else:
        log_debug("Initialize model to train from scratch %s".body_config.get(
            "arch"))

    # Load model
    model, output_dim = make_model(args, config)
    print(model)

    # Resume / pre-train
    if args.resume:
        assert not args.pre_train, "resume and pre_train are mutually exclusive"
        log_debug("Loading snapshot from %s", args.resume)
        snapshot = resume_from_snapshot(
            model, args.resume,
            ["body", "local_head_coarse", "local_head_fine"])
    elif args.pre_train:
        assert not args.resume, "resume and pre_train are mutually exclusive"
        log_debug("Loading pre-trained model from %s", args.pre_train)
        pre_train_from_snapshots(
            model, args.pre_train,
            ["body", "local_head_coarse", "local_head_fine"])
    else:
        #assert not args.eval, "--resume is needed in eval mode"
        snapshot = None

    # Init GPU stuff
    torch.backends.cudnn.benchmark = config["general"].getboolean(
        "cudnn_benchmark")
    model = DistributedDataParallel(model.cuda(device),
                                    device_ids=[device_id],
                                    output_device=device_id,
                                    find_unused_parameters=True)

    # Create optimizer & scheduler
    optimizer, scheduler, parameters, batch_update, total_epochs = make_optimizer(
        model, config, epoch_length=len(train_dataloader))
    if args.resume:
        optimizer.load_state_dict(snapshot["state_dict"]["optimizer"])

    # Training loop
    momentum = 1. - 1. / len(train_dataloader)
    meters = {
        "loss": AverageMeter((), momentum),
        "epipolar_loss": AverageMeter((), momentum),
        "consistency_loss": AverageMeter((), momentum),
    }

    if args.resume:
        start_epoch = snapshot["training_meta"]["epoch"] + 1
        best_score = snapshot["training_meta"]["best_score"]
        global_step = snapshot["training_meta"]["global_step"]

        for name, meter in meters.items():
            meter.load_state_dict(snapshot["state_dict"][name + "_meter"])
        del snapshot
    else:
        start_epoch = 0
        best_score = {
            "val": 1000.0,
            "test": 0.0,
        }
        global_step = 0

    # Optional: evaluation only
    if args.eval:
        log_info("Evaluation epoch %d", start_epoch - 1)

        test(args,
             config,
             model,
             rank=rank,
             world_size=world_size,
             output_dim=output_dim,
             device=device)

        log_info("Evaluation Done ..... ")

        exit(0)

    for epoch in range(start_epoch, total_epochs):

        log_info("Starting epoch %d", epoch + 1)

        if not batch_update:
            scheduler.step(epoch)

        score = {}

        # Run training
        global_step = train(
            model,
            config,
            train_dataloader,
            optimizer,
            scheduler,
            meters,
            summary=summary,
            batch_update=batch_update,
            log_interval=config["general"].getint("log_interval"),
            epoch=epoch,
            num_epochs=total_epochs,
            global_step=global_step,
            output_dim=output_dim,
            world_size=world_size,
            rank=rank,
            device=device,
            loss_weights=optimizer_config.getstruct("loss_weights"))

        # Save snapshot (only on rank 0)
        if rank == 0:
            snapshot_file = path.join(exp_dir,
                                      "model_{}.pth.tar".format(epoch))

            log_debug("Saving snapshot to %s", snapshot_file)

            meters_out_dict = {
                k + "_meter": v.state_dict()
                for k, v in meters.items()
            }

            save_snapshot(
                snapshot_file,
                config,
                epoch,
                0,
                best_score,
                global_step,
                body=model.module.body.state_dict(),
                local_head_coarse=model.module.local_head_coarse.state_dict(),
                local_head_fine=model.module.local_head_fine.state_dict(),
                optimizer=optimizer.state_dict(),
                **meters_out_dict)

        # Run validation
        if (epoch + 1) % config["general"].getint("val_interval") == 0:
            log_info("Validating epoch %d", epoch + 1)

            score['val'] = validate(
                model,
                config,
                val_dataloader,
                summary=summary,
                batch_update=batch_update,
                log_interval=config["general"].getint("log_interval"),
                epoch=epoch,
                num_epochs=total_epochs,
                global_step=global_step,
                output_dim=output_dim,
                world_size=world_size,
                rank=rank,
                device=device,
                loss_weights=optimizer_config.getstruct("loss_weights"))

        # Run Test
        if (epoch + 1) % config["general"].getint("test_interval") == 0:
            log_info("Testing epoch %d", epoch + 1)

            score['test'] = test(args,
                                 config,
                                 model,
                                 rank=rank,
                                 world_size=world_size,
                                 output_dim=output_dim,
                                 device=device)

            # Update the score on the last saved snapshot
            if rank == 0:
                snapshot = torch.load(snapshot_file, map_location="cpu")
                snapshot["training_meta"]["last_score"] = score
                torch.save(snapshot, snapshot_file)
                del snapshot

            if score['test'] > best_score['test']:
                best_score = score
                if rank == 0:
                    shutil.copy(snapshot_file,
                                path.join(exp_dir, "test_model_best.pth.tar"))
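
For context, a minimal entry-point sketch (an assumption, not taken from the original repository): main expects to be launched through torch.distributed.launch or torchrun, which set the environment variables consumed by init_method='env://', together with an argument parser supplying the attributes used above (local_rank, directory, resume, pre_train, eval). The exact CLI of the original project may differ.

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Distributed training entry point")
    parser.add_argument("--local_rank", type=int, default=0)  # filled in by the launcher
    parser.add_argument("--directory", required=True)         # experiment / log directory
    parser.add_argument("--resume", default=None)             # snapshot to resume from
    parser.add_argument("--pre_train", default=None)          # pre-trained snapshot(s)
    parser.add_argument("--eval", action="store_true")        # evaluation-only mode
    main(parser.parse_args())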
Example #2
def validate(model, config, dataloader, **varargs):

    # create tuples for validation
    data_config = config["dataloader"]

    # Switch to eval mode
    model.eval()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])
    # dataloader.dataset.update()

    loss_weights = varargs["loss_weights"]

    loss_meter = AverageMeter(())
    data_time_meter = AverageMeter(())
    batch_time_meter = AverageMeter(())

    data_time = time.time()

    for it, batch in enumerate(dataloader):
        with torch.no_grad():

            # Upload batch
            batch = {
                k: batch[k].cuda(device=varargs["device"], non_blocking=True)
                for k in NETWORK_TRAIN_INPUTS
            }

            data_time_meter.update(torch.tensor(time.time() - data_time))

            batch_time = time.time()

            # Run network
            losses, _ = model(**batch,
                              do_loss=True,
                              do_prediction=True,
                              do_augmentation=True)

            losses = OrderedDict((k, v.mean()) for k, v in losses.items())
            losses = all_reduce_losses(losses)
            loss = sum(w * l for w, l in zip(loss_weights, losses.values()))

            # Update meters
            loss_meter.update(loss.cpu())
            batch_time_meter.update(torch.tensor(time.time() - batch_time))

            del loss, losses, batch

        # Log batch
        if varargs["summary"] is not None and (
                it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                None, "val", varargs["global_step"], varargs["epoch"] + 1,
                varargs["num_epochs"], it + 1, len(dataloader),
                OrderedDict([("loss", loss_meter),
                             ("data_time", data_time_meter),
                             ("batch_time", batch_time_meter)]))

        data_time = time.time()

    return loss_meter.mean
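
The all_reduce_losses helper used above is not shown in these examples; a plausible sketch (an assumption, not the original implementation) averages each scalar loss across workers so that every rank logs the same value.

from collections import OrderedDict

import torch.distributed as distributed


def all_reduce_losses(losses):
    # Average each loss tensor over all workers (stand-in for the repo's own helper).
    world_size = distributed.get_world_size()
    reduced = OrderedDict()
    for name, value in losses.items():
        value = value.clone()
        distributed.all_reduce(value, op=distributed.ReduceOp.SUM)
        reduced[name] = value / world_size
    return reduced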
Example #3
def validate(model, config, dataloader, **varargs):

    # create tuples for validation
    data_config = config["dataloader"]

    distributed.barrier()

    avg_neg_distance = dataloader.dataset.create_epoch_tuples(
        model,
        log_info,
        log_debug,
        output_dim=varargs["output_dim"],
        world_size=varargs["world_size"],
        rank=varargs["rank"],
        device=varargs["device"],
        data_config=data_config)
    distributed.barrier()

    # Switch to eval mode
    model.eval()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])

    loss_weights = varargs["loss_weights"]

    loss_meter = AverageMeter(())
    data_time_meter = AverageMeter(())
    batch_time_meter = AverageMeter(())

    data_time = time.time()

    for it, batch in enumerate(dataloader):
        with torch.no_grad():

            # Upload batch
            for k in NETWORK_INPUTS:
                if isinstance(batch[k][0], PackedSequence):
                    batch[k] = [
                        item.cuda(device=varargs["device"], non_blocking=True)
                        for item in batch[k]
                    ]
                else:
                    batch[k] = batch[k].cuda(device=varargs["device"],
                                             non_blocking=True)

            data_time_meter.update(torch.tensor(time.time() - data_time))

            batch_time = time.time()

            # Run network
            losses, _ = model(**batch, do_loss=True, do_prediction=True)

            losses = OrderedDict((k, v.mean()) for k, v in losses.items())
            losses = all_reduce_losses(losses)
            loss = sum(w * l for w, l in zip(loss_weights, losses.values()))

            # Update meters
            loss_meter.update(loss.cpu())
            batch_time_meter.update(torch.tensor(time.time() - batch_time))

            del loss, losses, batch

        # Log batch
        if varargs["summary"] is not None and (
                it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                None, "val", varargs["global_step"], varargs["epoch"] + 1,
                varargs["num_epochs"], it + 1, len(dataloader),
                OrderedDict([("loss", loss_meter),
                             ("data_time", data_time_meter),
                             ("batch_time", batch_time_meter)]))

        data_time = time.time()

    return loss_meter.mean
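
These examples also rely on an AverageMeter with update(), a mean property, a momentum attribute, and state_dict()/load_state_dict(). A minimal sketch of such a meter, assuming an exponentially decayed running average (not necessarily the original class):

import torch


class AverageMeter:
    # Running (optionally exponentially decayed) average over tensors of a fixed shape.
    def __init__(self, shape=(), momentum=1.0):
        self.momentum = momentum
        self.sum = torch.zeros(shape)
        self.count = 0.0

    def update(self, value):
        self.sum = self.momentum * self.sum + value.detach().cpu()
        self.count = self.momentum * self.count + 1.0

    @property
    def mean(self):
        return self.sum / max(self.count, 1.0)

    def state_dict(self):
        return {"sum": self.sum, "count": self.count}

    def load_state_dict(self, state):
        self.sum, self.count = state["sum"], state["count"]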
Example #4
def train(model, config, dataloader, optimizer, scheduler, meters, **varargs):

    # Create tuples for training
    data_config = config["dataloader"]

    # Switch to train mode
    model.train()

    dataloader.batch_sampler.set_epoch(varargs["epoch"])
    # dataloader.dataset.update()

    optimizer.zero_grad()

    global_step = varargs["global_step"]
    loss_weights = varargs["loss_weights"]

    data_time_meter = AverageMeter((), meters["loss"].momentum)
    batch_time_meter = AverageMeter((), meters["loss"].momentum)

    data_time = time.time()

    torch.autograd.set_detect_anomaly(True)

    for it, batch in enumerate(dataloader):
        # Upload batch

        batch = {
            k: batch[k].cuda(device=varargs["device"], non_blocking=True)
            for k in NETWORK_TRAIN_INPUTS
        }

        # Measure data loading time
        data_time_meter.update(torch.tensor(time.time() - data_time))

        # Update scheduler
        global_step += 1
        if varargs["batch_update"]:
            scheduler.step(global_step)

        batch_time = time.time()

        # Run network
        optimizer.zero_grad()
        losses, _ = model(**batch, do_loss=True, do_augmentaton=True)
        distributed.barrier()
        losses = OrderedDict((k, v.mean()) for k, v in losses.items())
        losses["loss"] = sum(w * l
                             for w, l in zip(loss_weights, losses.values()))

        losses["loss"].backward()
        optimizer.step()

        # Gather from all workers
        losses = all_reduce_losses(losses)

        # Update meters
        with torch.no_grad():
            for loss_name, loss_value in losses.items():

                meters[loss_name].update(loss_value.cpu())

                if torch.isnan(loss_value).any():
                    input()

        batch_time_meter.update(torch.tensor(time.time() - batch_time))

        # Clean-up
        del batch, losses

        # Log
        if varargs["summary"] is not None and (
                it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                varargs["summary"], "train", global_step, varargs["epoch"] + 1,
                varargs["num_epochs"], it + 1, len(dataloader),
                OrderedDict([("lr_body", scheduler.get_lr()[0] * 1e6),
                             ("loss", meters["loss"]),
                             ("epipolar_loss", meters["epipolar_loss"]),
                             ("consistency_loss", meters["consistency_loss"]),
                             ("data_time", data_time_meter),
                             ("batch_time", batch_time_meter)]))

        data_time = time.time()

    return global_step
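
The isnan/input() pause in the loop above only works in a single interactive process; in a multi-worker run a more robust pattern (a hedged alternative, not part of the original code) is to fail fast when a loss becomes non-finite:

import torch


def check_finite(losses, step):
    # Raise instead of pausing: every worker aborts the run on a non-finite loss.
    for name, value in losses.items():
        if not torch.isfinite(value).all():
            raise FloatingPointError(
                "non-finite value in loss '{}' at step {}".format(name, step))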
Example #5
def train(model, config, dataloader, optimizer, scheduler, meters, **varargs):

    # Create tuples for training
    data_config = config["dataloader"]

    distributed.barrier()

    avg_neg_distance = dataloader.dataset.create_epoch_tuples(
        model,
        log_info,
        log_debug,
        output_dim=varargs["output_dim"],
        world_size=varargs["world_size"],
        rank=varargs["rank"],
        device=varargs["device"],
        data_config=data_config)
    distributed.barrier()

    # switch to train mode
    model.train()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])
    optimizer.zero_grad()
    global_step = varargs["global_step"]
    loss_weights = varargs["loss_weights"]

    data_time_meter = AverageMeter((), meters["loss"].momentum)
    batch_time_meter = AverageMeter((), meters["loss"].momentum)

    data_time = time.time()

    for it, batch in enumerate(dataloader):

        # Upload batch
        for k in NETWORK_INPUTS:
            if isinstance(batch[k][0], PackedSequence):
                batch[k] = [
                    item.cuda(device=varargs["device"], non_blocking=True)
                    for item in batch[k]
                ]
            else:
                batch[k] = batch[k].cuda(device=varargs["device"],
                                         non_blocking=True)

        # Measure data loading time
        data_time_meter.update(torch.tensor(time.time() - data_time))

        # Update scheduler
        global_step += 1
        if varargs["batch_update"]:
            scheduler.step(global_step)

        batch_time = time.time()

        # Run network
        losses, _ = model(**batch, do_loss=True, do_augmentaton=True)
        distributed.barrier()

        losses = OrderedDict((k, v.mean()) for k, v in losses.items())
        losses["loss"] = sum(w * l
                             for w, l in zip(loss_weights, losses.values()))

        losses["loss"].backward()

        optimizer.step()
        optimizer.zero_grad()

        if (it + 1) % 5 == 0:
            optimizer.step()
            optimizer.zero_grad()

        # Gather from all workers
        losses = all_reduce_losses(losses)

        # Update meters
        with torch.no_grad():
            for loss_name, loss_value in losses.items():
                meters[loss_name].update(loss_value.cpu())

        batch_time_meter.update(torch.tensor(time.time() - batch_time))

        # Clean-up
        del batch, losses

        # Log
        if varargs["summary"] is not None and (
                it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                varargs["summary"], "train", global_step, varargs["epoch"] + 1,
                varargs["num_epochs"], it + 1, len(dataloader),
                OrderedDict([("lr_body", scheduler.get_lr()[0] * 1e6),
                             ("lr_ret", scheduler.get_lr()[1] * 1e6),
                             ("loss", meters["loss"]),
                             ("ret_loss", meters["ret_loss"]),
                             ("data_time", data_time_meter),
                             ("batch_time", batch_time_meter)]))

        data_time = time.time()

    return global_step
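
The step/zero_grad calls in this variant fire on every iteration and again every fifth one. For comparison, a conventional gradient-accumulation loop (a sketch under the same assumptions about the model's call signature, not the original schedule) steps only once per accumulation window and scales the loss accordingly:

def train_with_accumulation(model, dataloader, optimizer, loss_weights, accum_steps=5):
    # Hypothetical helper: accumulate gradients over accum_steps batches, then step once.
    optimizer.zero_grad()
    for it, batch in enumerate(dataloader):
        losses, _ = model(**batch, do_loss=True)
        loss = sum(w * l for w, l in zip(loss_weights, losses.values()))
        (loss / accum_steps).backward()
        if (it + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()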