Example 1
def validate(model, config, dataloader, **varargs):

    # Read the dataloader configuration
    data_config = config["dataloader"]

    # Switch to eval mode
    model.eval()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])
    # dataloader.dataset.update()

    loss_weights = varargs["loss_weights"]

    loss_meter = AverageMeter(())
    data_time_meter = AverageMeter(())
    batch_time_meter = AverageMeter(())

    data_time = time.time()

    for it, batch in enumerate(dataloader):
        with torch.no_grad():

            # Upload batch
            batch = {
                k: batch[k].cuda(device=varargs["device"], non_blocking=True)
                for k in NETWORK_TRAIN_INPUTS
            }

            data_time_meter.update(torch.tensor(time.time() - data_time))

            batch_time = time.time()

            # Run network
            losses, _ = model(**batch,
                              do_loss=True,
                              do_prediction=True,
                              do_augmentation=True)

            losses = OrderedDict((k, v.mean()) for k, v in losses.items())
            losses = all_reduce_losses(losses)
            loss = sum(w * l for w, l in zip(loss_weights, losses.values()))

            # Update meters
            loss_meter.update(loss.cpu())
            batch_time_meter.update(torch.tensor(time.time() - batch_time))

            del loss, losses, batch

        # Log batch
        if varargs["summary"] is not None and (it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                None, "val", varargs["global_step"], varargs["epoch"] + 1,
                varargs["num_epochs"], it + 1, len(dataloader),
                OrderedDict([("loss", loss_meter),
                             ("data_time", data_time_meter),
                             ("batch_time", batch_time_meter)]))

        data_time = time.time()

    return loss_meter.mean
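
All four examples rely on an AverageMeter helper that is not shown here. A minimal sketch of the interface they assume (a shape argument, an optional momentum, an update() taking a tensor, and mean/momentum attributes); the real class in the source project may differ:

import torch

class AverageMeter:
    """Running, optionally momentum-smoothed, mean of tensors."""

    def __init__(self, shape=(), momentum=1.0):
        self.sum = torch.zeros(shape, dtype=torch.float64)
        self.count = 0.0
        self.momentum = momentum

    def update(self, value):
        # With momentum < 1, older samples decay exponentially
        self.sum = self.momentum * self.sum + value.to(torch.float64)
        self.count = self.momentum * self.count + 1.0

    @property
    def mean(self):
        return self.sum / max(self.count, 1.0)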
Example 2
def validate(model, config, dataloader, **varargs):

    # Create tuples for validation
    data_config = config["dataloader"]

    distributed.barrier()

    avg_neg_distance = dataloader.dataset.create_epoch_tuples(
        model,
        log_info,
        log_debug,
        output_dim=varargs["output_dim"],
        world_size=varargs["world_size"],
        rank=varargs["rank"],
        device=varargs["device"],
        data_config=data_config)
    distributed.barrier()

    # Switch to eval mode
    model.eval()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])

    loss_weights = varargs["loss_weights"]

    loss_meter = AverageMeter(())
    data_time_meter = AverageMeter(())
    batch_time_meter = AverageMeter(())

    data_time = time.time()

    for it, batch in enumerate(dataloader):
        with torch.no_grad():

            # Upload batch
            for k in NETWORK_INPUTS:
                if isinstance(batch[k][0], PackedSequence):
                    batch[k] = [
                        item.cuda(device=varargs["device"], non_blocking=True)
                        for item in batch[k]
                    ]
                else:
                    batch[k] = batch[k].cuda(device=varargs["device"],
                                             non_blocking=True)

            data_time_meter.update(torch.tensor(time.time() - data_time))

            batch_time = time.time()

            # Run network
            losses, _ = model(**batch, do_loss=True, do_prediction=True)

            losses = OrderedDict((k, v.mean()) for k, v in losses.items())
            losses = all_reduce_losses(losses)
            loss = sum(w * l for w, l in zip(loss_weights, losses.values()))

            # Update meters
            loss_meter.update(loss.cpu())
            batch_time_meter.update(torch.tensor(time.time() - batch_time))

            del loss, losses, batch

        # Log batch
        if varargs["summary"] is not None and (it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                None, "val", varargs["global_step"], varargs["epoch"] + 1,
                varargs["num_epochs"], it + 1, len(dataloader),
                OrderedDict([("loss", loss_meter),
                             ("data_time", data_time_meter),
                             ("batch_time", batch_time_meter)]))

        data_time = time.time()

    return loss_meter.mean
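
Both validate examples average per-worker losses through all_reduce_losses. A minimal sketch, assuming torch.distributed is initialized and that the helper averages each loss over the world size (the actual reduction in the source project may differ):

from collections import OrderedDict

import torch.distributed as dist

def all_reduce_losses(losses):
    # Sum each loss over all workers, then divide by the world size
    world_size = dist.get_world_size()
    reduced = OrderedDict()
    for name, value in losses.items():
        value = value.clone()
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        reduced[name] = value / world_size
    return reduced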
Example 3
def train(model, config, dataloader, optimizer, scheduler, meters, **varargs):

    # Read the dataloader configuration
    data_config = config["dataloader"]

    # Switch to train mode
    model.train()

    dataloader.batch_sampler.set_epoch(varargs["epoch"])
    # dataloader.dataset.update()

    optimizer.zero_grad()

    global_step = varargs["global_step"]
    loss_weights = varargs["loss_weights"]

    data_time_meter = AverageMeter((), meters["loss"].momentum)
    batch_time_meter = AverageMeter((), meters["loss"].momentum)

    data_time = time.time()

    # Anomaly detection helps localize NaNs in backward, but slows training
    torch.autograd.set_detect_anomaly(True)

    for it, batch in enumerate(dataloader):
        # Upload batch
        batch = {
            k: batch[k].cuda(device=varargs["device"], non_blocking=True)
            for k in NETWORK_TRAIN_INPUTS
        }

        # Measure data loading time
        data_time_meter.update(torch.tensor(time.time() - data_time))

        # Update scheduler
        global_step += 1
        if varargs["batch_update"]:
            scheduler.step(global_step)

        batch_time = time.time()

        # Run network
        optimizer.zero_grad()
        losses, _ = model(**batch, do_loss=True, do_augmentation=True)
        distributed.barrier()
        losses = OrderedDict((k, v.mean()) for k, v in losses.items())
        losses["loss"] = sum(w * l
                             for w, l in zip(loss_weights, losses.values()))

        losses["loss"].backward()
        optimizer.step()

        # Gather from all workers
        losses = all_reduce_losses(losses)

        # Update meters
        with torch.no_grad():
            for loss_name, loss_value in losses.items():
                meters[loss_name].update(loss_value.cpu())

                # Fail fast if a loss diverges instead of blocking on stdin
                if torch.isnan(loss_value).any():
                    raise RuntimeError("NaN detected in {}".format(loss_name))

        batch_time_meter.update(torch.tensor(time.time() - batch_time))

        # Clean-up
        del batch, losses

        # Log
        if varargs["summary"] is not None and (it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                varargs["summary"], "train", global_step, varargs["epoch"] + 1,
                varargs["num_epochs"], it + 1, len(dataloader),
                OrderedDict([("lr_body", scheduler.get_lr()[0] * 1e6),
                             ("loss", meters["loss"]),
                             ("epipolar_loss", meters["epipolar_loss"]),
                             ("consistency_loss", meters["consistency_loss"]),
                             ("data_time", data_time_meter),
                             ("batch_time", batch_time_meter)]))

        data_time = time.time()

    return global_step
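
Progress is reported through a project-specific logging.iteration helper. A minimal sketch consistent with the call sites above; treating summary as a TensorBoard-style writer with add_scalar is an assumption:

def iteration(summary, phase, global_step, epoch, num_epochs, step, num_steps, values):
    def scalar(v):
        # Meters expose a mean attribute; plain numbers pass through
        m = getattr(v, "mean", None)
        return float(m) if m is not None and not callable(m) else float(v)

    line = ", ".join("{}={:.4f}".format(k, scalar(v)) for k, v in values.items())
    print("{} ep {}/{} it {}/{}: {}".format(phase, epoch, num_epochs, step, num_steps, line))
    if summary is not None:
        for k, v in values.items():
            summary.add_scalar("{}/{}".format(phase, k), scalar(v), global_step)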
Example 4
def train(model, config, dataloader, optimizer, scheduler, meters, **varargs):

    # Create tuples for training
    data_config = config["dataloader"]

    distributed.barrier()

    avg_neg_distance = dataloader.dataset.create_epoch_tuples(
        model,
        log_info,
        log_debug,
        output_dim=varargs["output_dim"],
        world_size=varargs["world_size"],
        rank=varargs["rank"],
        device=varargs["device"],
        data_config=data_config)
    distributed.barrier()

    # switch to train mode
    model.train()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])
    optimizer.zero_grad()
    global_step = varargs["global_step"]
    loss_weights = varargs["loss_weights"]

    data_time_meter = AverageMeter((), meters["loss"].momentum)
    batch_time_meter = AverageMeter((), meters["loss"].momentum)

    data_time = time.time()

    for it, batch in enumerate(dataloader):

        # Upload batch
        for k in NETWORK_INPUTS:
            if isinstance(batch[k][0], PackedSequence):
                batch[k] = [
                    item.cuda(device=varargs["device"], non_blocking=True)
                    for item in batch[k]
                ]
            else:
                batch[k] = batch[k].cuda(device=varargs["device"],
                                         non_blocking=True)

        # Measure data loading time
        data_time_meter.update(torch.tensor(time.time() - data_time))

        # Update scheduler
        global_step += 1
        if varargs["batch_update"]:
            scheduler.step(global_step)

        batch_time = time.time()

        # Run network
        losses, _ = model(**batch, do_loss=True, do_augmentaton=True)
        losses, _ = model(**batch, do_loss=True, do_augmentation=True)
        distributed.barrier()

        losses = OrderedDict((k, v.mean()) for k, v in losses.items())
        losses["loss"] = sum(w * l
                             for w, l in zip(loss_weights, losses.values()))

        losses["loss"].backward()

        optimizer.step()
        optimizer.zero_grad()

        if (it + 1) % 5 == 0:
            optimizer.step()
            optimizer.zero_grad()

        # Gather from all workers
        losses = all_reduce_losses(losses)

        # Update meters
        with torch.no_grad():
            for loss_name, loss_value in losses.items():
                meters[loss_name].update(loss_value.cpu())

        batch_time_meter.update(torch.tensor(time.time() - batch_time))

        # Clean-up
        del batch, losses

        # Log
        if varargs["summary"] is not None and (it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                varargs["summary"], "train", global_step, varargs["epoch"] + 1,
                varargs["num_epochs"], it + 1, len(dataloader),
                OrderedDict([("lr_body", scheduler.get_lr()[0] * 1e6),
                             ("lr_ret", scheduler.get_lr()[1] * 1e6),
                             ("loss", meters["loss"]),
                             ("ret_loss", meters["ret_loss"]),
                             ("data_time", data_time_meter),
                             ("batch_time", batch_time_meter)]))

        data_time = time.time()

    return global_step
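
Example 4's per-batch update can also be relaxed into gradient accumulation, where several batches contribute to a single optimizer step. A minimal sketch of that pattern; the helper name and accum_steps=5 are illustrative, not from the source:

def backward_and_maybe_step(loss, optimizer, it, accum_steps=5):
    # Scale so the accumulated gradient matches a large-batch average
    (loss / accum_steps).backward()
    # Apply the accumulated update once per window of accum_steps batches
    if (it + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()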