Example 1
def test_setup_logger(capsys, dirname):

    trainer = Engine(lambda e, b: None)
    evaluator = Engine(lambda e, b: None)

    fp = os.path.join(dirname, "log")
    assert len(trainer.logger.handlers) == 0
    trainer.logger.addHandler(logging.NullHandler())
    trainer.logger.addHandler(logging.NullHandler())
    trainer.logger.addHandler(logging.NullHandler())

    trainer.logger = setup_logger("trainer", filepath=fp)
    evaluator.logger = setup_logger("evaluator", filepath=fp)

    assert len(trainer.logger.handlers) == 2
    assert len(evaluator.logger.handlers) == 2

    @trainer.on(Events.EPOCH_COMPLETED)
    def _(_):
        evaluator.run([0, 1, 2])

    trainer.run([0, 1, 2, 3, 4, 5], max_epochs=5)

    captured = capsys.readouterr()
    err = captured.err.split("\n")

    with open(fp, "r") as h:
        data = h.readlines()

    for source in [err, data]:
        assert "trainer INFO: Engine run starting with max_epochs=5." in source[
            0]
        assert "evaluator INFO: Engine run starting with max_epochs=1." in source[
            1]
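The handler-count assertions above follow from how ignite.utils.setup_logger behaves: it attaches a stream handler (stderr by default) and, when a filepath is given, a file handler as well. A minimal sketch, assuming an arbitrary log path:

import logging
from ignite.utils import setup_logger

logger = setup_logger("demo", level=logging.INFO, filepath="/tmp/demo.log")
assert len(logger.handlers) == 2  # one StreamHandler (stderr) + one FileHandler
logger.info("hello")  # appears on stderr and in /tmp/demo.log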
Example 2
def create_evaluator(model, args, name="evaluator"):
    @torch.no_grad()
    def eval_step(engine, batch):
        model.eval()

        images, targets = batch
        images = convert_tensor(images, device=args.device, non_blocking=False)

        with torch_num_threads(1):
            outputs = model(images)

        outputs = convert_tensor(outputs, device="cpu")

        # Store results in engine state.
        results = {
            target["image_id"].item(): output
            for target, output in zip(targets, outputs)
        }
        engine.state.result = results

        return outputs

    evaluator = Engine(eval_step)

    # Configure default engine output logging.
    evaluator.logger = setup_logger(name)

    return evaluator
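A hypothetical usage sketch for the evaluator above (model, args, and val_loader are illustrative names): eval_step stores its per-image outputs on the engine state, so they can be read back after a run.

evaluator = create_evaluator(model, args)
evaluator.run(val_loader)
results = evaluator.state.result  # {image_id: output, ...} from the last processed batch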
Example 3
def test_log_metrics(capsys):
    engine = Engine(lambda e, b: None)
    engine.logger = setup_logger(format="%(message)s")
    engine.run(list(range(100)), max_epochs=2)
    log_metrics(engine, "train")
    captured = capsys.readouterr()
    assert captured.err.split("\n")[-2] == "train [2/200]: {}"
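The assertion pins down the message shape: after max_epochs=2 over 100 iterations per epoch, the final logged line is "train [2/200]: {}". A minimal log_metrics sketch consistent with that output (an assumed shape, not necessarily the project's actual helper):

def log_metrics(engine, tag):
    # With the "%(message)s" format above this prints e.g. "train [2/200]: {}".
    state = engine.state
    engine.logger.info(f"{tag} [{state.epoch}/{state.iteration}]: {state.metrics}")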
Example 4
def create_trainer(model, optimizer, criterion, train_sampler, config, logger):
    prepare_batch = config.prepare_batch
    device = config.device

    # Setup trainer
    accumulation_steps = getattr(config, "accumulation_steps", 1)
    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

    def train_update_function(engine, batch):

        model.train()

        x, y = prepare_batch(batch, device=device, non_blocking=True)
        y_pred = model(x)
        y_pred = model_output_transform(y_pred)
        loss = criterion(y_pred, y) / accumulation_steps

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return {
            "supervised batch loss": loss.item(),
        }

    output_names = getattr(config, "output_names", ["supervised batch loss"])
    lr_scheduler = config.lr_scheduler

    trainer = Engine(train_update_function)
    trainer.logger = logger

    to_save = {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "trainer": trainer,
        "amp": amp
    }

    save_every_iters = getattr(config, "save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        output_path=config.output_path.as_posix(),
        lr_scheduler=lr_scheduler,
        with_gpu_stats=True,
        output_names=output_names,
        with_pbars=False,
    )

    common.ProgressBar(persist=False).attach(trainer, metric_names="all")

    return trainer
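Example 4 assumes the model and optimizer were already wrapped by NVIDIA apex, since the update function calls amp.scale_loss and the checkpoint dict stores the amp module (which exposes state_dict()). A hedged initialization sketch; the opt_level value is illustrative:

from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level="O1")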
Example 5
def create_trainer(model, optimizer, args):
    def train_step(engine, batch):
        model.train()

        images, targets = convert_tensor(batch,
                                         device=args.device,
                                         non_blocking=False)
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # We only need the scalar values.
        loss_values = {k: v.item() for k, v in loss_dict.items()}

        return loss_values

    trainer = Engine(train_step)

    # Configure default engine output logging.
    trainer.logger = setup_logger("trainer")

    # Compute running averages of training metrics.
    metrics = training_metrics()
    for name, metric in metrics.items():
        running_average(metric).attach(trainer, name)

    return trainer
Example 6
def create_trainer(model, optimizer, criterion, train_sampler, config, logger):
    prepare_batch = config.prepare_batch
    device = config.device

    # Setup trainer
    accumulation_steps = getattr(config, "accumulation_steps", 1)
    model_output_transform = getattr(config, "model_output_transform", lambda x: x)

    def train_update_function(engine, batch):

        model.train()

        x, y = prepare_batch(batch, device=device, non_blocking=True)
        y_pred = model(x)
        y_pred = model_output_transform(y_pred)
        loss = criterion(y_pred, y)

        if isinstance(loss, Mapping):
            assert "supervised batch loss" in loss
            loss_dict = loss
            output = {k: v.item() for k, v in loss_dict.items()}
            loss = loss_dict["supervised batch loss"] / accumulation_steps
        else:
            output = {"supervised batch loss": loss.item()}

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return output

    output_names = getattr(config, "output_names", ["supervised batch loss",])
    lr_scheduler = config.lr_scheduler

    trainer = Engine(train_update_function)
    trainer.logger = logger

    to_save = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler, "trainer": trainer, "amp": amp}

    save_every_iters = getattr(config, "save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        with_gpu_stats=exp_tracking.has_mlflow,
        output_names=output_names,
        with_pbars=False,
    )

    if idist.get_rank() == 0:
        common.ProgressBar(persist=False).attach(trainer, metric_names="all")

    return trainer
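Example 6 accepts a criterion that returns either a plain loss tensor or a mapping that must contain a "supervised batch loss" entry. A hypothetical multi-term criterion compatible with that branch:

import torch.nn.functional as F

def multi_term_criterion(y_pred, y):
    ce = F.cross_entropy(y_pred, y)
    # The training step backpropagates only this key and logs every scalar entry.
    return {"supervised batch loss": ce, "cross entropy": ce.detach()}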
Example 7
def create_classification_training_loop(model: nn.Module,
                                        cfg: DictConfig,
                                        name: str,
                                        device="cpu") -> Engine:

    # Network
    model.to(device)

    # Optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=cfg.mode.train.learning_rate,
                                momentum=0.8)

    # Loss
    loss_fn = torch.nn.NLLLoss()

    def _update(engine, batch):

        ########################################################################
        # Modify the logic of your training
        ########################################################################
        model.train()
        optimizer.zero_grad()
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()

        # Anything you want to log or use to compute something to log must be returned in this dictionary
        update_dict = {
            "nll": loss.item(),
            "nll_2": loss.item() + 0.5,
            "y_pred": y_pred,
            "x": x,
            "y": y,
        }

        return update_dict

    engine = Engine(_update)

    # Required to set up logging
    engine.logger = setup_logger(name=name)

    # (Optional) Specify training metrics. "output_transform" used to select items from "update_dict" needed by metrics
    # Collecting metrics over training set is not recommended
    # https://pytorch.org/ignite/metrics.html#ignite.metrics.Loss

    return engine
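A hedged sketch of the optional metric attachment the closing comment refers to, mirroring the output_transform pattern used in Example 8 below (model and cfg are illustrative; as noted above, collecting metrics over the training set is discouraged):

import torch
from ignite.metrics import Accuracy, Loss

engine = create_classification_training_loop(model, cfg, name="train")
loss_fn = torch.nn.NLLLoss()
train_metrics = {
    "accuracy": Accuracy(output_transform=lambda d: (d["y_pred"], d["y"])),
    "nll": Loss(loss_fn, output_transform=lambda d: (d["y_pred"], d["y"])),
}
for metric_name, metric in train_metrics.items():
    metric.attach(engine, metric_name)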
Example 8
def create_classification_evaluation_loop(model: nn.Module,
                                          cfg: DictConfig,
                                          name: str,
                                          device="cpu") -> Engine:

    # Loss
    loss_fn = torch.nn.NLLLoss()

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_pred = model(x)

        # Anything you want to log must be returned in this dictionary
        infer_dict = {"y_pred": y_pred, "y": y}

        return infer_dict

    engine = Engine(_inference)

    # Required to set up logging
    engine.logger = setup_logger(name=name)

    # Specify evaluation metrics. "output_transform" used to select items from "infer_dict" needed by metrics
    # https://pytorch.org/ignite/metrics.html#ignite.metrics.
    metrics = {
        "accuracy":
        Accuracy(output_transform=lambda infer_dict:
                 (infer_dict["y_pred"], infer_dict["y"])),
        "nll":
        Loss(
            loss_fn,
            output_transform=lambda infer_dict:
            (infer_dict["y_pred"], infer_dict["y"]),
        ),
    }

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
Example 9
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    cutmix_beta = config["cutmix_beta"]
    cutmix_prob = config["cutmix_prob"]
    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            r = torch.rand(1).item()
            if cutmix_beta > 0 and r < cutmix_prob:
                output, loss = utils.cutmix_forward(model, x, criterion, y,
                                                    cutmix_beta)
            else:
                output = model(x)
                loss = criterion(output, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()

        if idist.backend() == "horovod":
            optimizer.synchronize()
            with optimizer.skip_synchronize():
                scaler.step(optimizer)
                scaler.update()
        else:
            scaler.step(optimizer)
            scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    if config["with_pbar"] and idist.get_rank() == 0:
        ProgressBar().attach(trainer)

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example 10
def create_regression_evaluation_loop(model: nn.Module,
                                      cfg: DictConfig,
                                      name: str,
                                      device="cpu") -> Engine:

    # Loss
    loss_fn = torch.nn.MSELoss()

    mape_metric = Mape(output_transform=lambda infer_dict:
                       (infer_dict["y_pred"], infer_dict["y"]))

    def _inference(engine, batch):

        model.eval()

        with torch.no_grad():

            x, y = batch
            x, y = x.to(device), y.to(device)
            y_pred = model(x)

            factor = 1777.0

            y = y * factor
            y_pred = y_pred * factor
            mse_val = loss_fn(y, y_pred).item()

        infer_dict = {
            "loss": mse_val,
            "y_pred": y_pred / factor,
            "x": x / factor,
            "y": y / factor,
            "y_pred_times_factor": y_pred,
            "y_times_factor": y,
            "ypred_first": np.vstack([y[0], y_pred[0]]),
        }

        return infer_dict

    engine = Engine(_inference)

    engine.logger = setup_logger(name=name)

    metrics = {
        "loss":
        Loss(
            loss_fn,
            output_transform=lambda infer_dict: (
                infer_dict["y_pred_times_factor"],
                infer_dict["y_times_factor"],
            ),
        ),
        "mape":
        mape_metric,
    }

    for name, metric in metrics.items():
        print("attaching metric: " + name)
        metric.attach(engine, name)

    return engine
Example 11
def create_regression_training_loop(model: nn.Module,
                                    cfg: DictConfig,
                                    name: str,
                                    device="cpu") -> Engine:

    # Network
    model.to(device)

    # Optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=cfg.mode.train.learning_rate)

    # Loss
    loss_fn = torch.nn.MSELoss()

    mape_metric = Mape(output_transform=lambda infer_dict:
                       (infer_dict["y_pred"], infer_dict["y"]))

    def _update(engine, batch):

        model.train()

        optimizer.zero_grad()
        x, y = batch
        x, y = x.to(device), y.to(device)

        y_pred = model(x)

        loss = loss_fn(10 * y_pred, 10 * y)

        loss.backward()
        optimizer.step()

        factor = 1777.0

        mse_val = loss_fn(factor * y_pred, factor * y).item()
        y_hat = y[0].cpu().detach().numpy() * factor
        y_pred_hat = y_pred[0].cpu().detach().numpy() * factor

        # Anything you want to log must be returned in this dictionary
        update_dict = {
            "loss": mse_val,
            "y_pred": y_pred * factor,
            "ypred_first": [y_hat, y_pred_hat],
            "y": y * factor,
        }

        return update_dict

    engine = Engine(_update)

    # Required to set up logging
    engine.logger = setup_logger(name=name)

    metrics = {"mape": mape_metric}

    for name, metric in metrics.items():
        print("attaching metric:" + name)
        metric.attach(engine, name)

    return engine
Example 12
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()
        y_pred = model(x)

        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Fetching loss.item() every iteration can slow training down on XLA, so only fetch it every log_every_iters iterations.
        if (config["log_every_iters"] > 0 and
            (engine.state.iteration - 1) % config["log_every_iters"] == 0):
            batch_loss = loss.item()
            engine.state.saved_batch_loss = batch_loss
        else:
            batch_loss = engine.state.saved_batch_loss

        return {"batch loss": batch_loss}

    trainer = Engine(train_step)
    trainer.state.saved_batch_loss = -1.0
    trainer.state_dict_user_keys.append("saved_batch_loss")
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = ["batch loss"]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example 13
def main(args):
    # TODO(thomasjo): Make this configurable?
    ensure_reproducibility(seed=42)

    # TODO(thomasjo): Make this configurable?
    bands = slice(0, 115)

    dataset = prepare_dataset(args.data_dir, bands)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             pin_memory=True,
                                             num_workers=1)

    # Grab a batch of images that will be used for visualizing epoch results.
    test_batch, _ = next(iter(dataloader))
    test_batch = test_batch[:16]
    test_batch = test_batch.to(device=args.device)

    input_size = test_batch.shape[1:]

    model = Autoencoder(input_size).to(device=args.device)
    optimizer = optim.Adam(model.parameters())
    criterion = nn.BCELoss()

    torchsummary.summary(model,
                         input_size=input_size,
                         batch_size=args.batch_size,
                         device=str(args.device))

    # Create timestamped output directory.
    timestamp = datetime.utcnow().strftime("%Y-%m-%d-%H%M")
    args.output_dir = args.output_dir / timestamp
    args.output_dir.mkdir(parents=True, exist_ok=True)

    def train_step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = prepare_batch(batch, args.device)
        x_hat = model(x)
        loss = criterion(x_hat, x)

        loss.backward()
        optimizer.step()

        return x, x_hat, loss.item()

    trainer = Engine(train_step)
    trainer.logger = setup_logger("trainer")

    @trainer.on(Events.ITERATION_COMPLETED(every=10))
    def log_epoch_metrics(engine: Engine):
        _, _, loss = engine.state.output
        engine.logger.info("Epoch [{}] Iteration [{}/{}] Loss: {:.4f}".format(
            engine.state.epoch,
            engine.state.iteration,
            engine.state.max_epochs * engine.state.epoch_length,
            loss,
        ))

    # Configure model checkpoints.
    checkpoint_handler = ModelCheckpoint(str(args.output_dir),
                                         filename_prefix="ckpt",
                                         n_saved=None)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                              {"model": model})

    # Visualize training progress using test patches.
    @trainer.on(Events.EPOCH_COMPLETED)
    def visualize_reconstruction(engine: Engine):
        x, x_hat = test_batch, model(test_batch)
        fig, (ax1, ax2) = plt.subplots(1, 2, dpi=300)
        plot_image_grid(ax1, x, band=50, nrow=4)
        plot_image_grid(ax2, x_hat, band=50, nrow=4)
        fig.savefig(args.output_dir / f"epoch-{engine.state.epoch}.png",
                    dpi=300)

    # Start model optimization.
    trainer.run(dataloader, max_epochs=50)
Example 14
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]
        labels = batch["label"].view(-1, 1)

        if input_ids.device != device:
            input_ids = input_ids.to(device,
                                     non_blocking=True,
                                     dtype=torch.long)
            attention_mask = attention_mask.to(device,
                                               non_blocking=True,
                                               dtype=torch.long)
            token_type_ids = token_type_ids.to(device,
                                               non_blocking=True,
                                               dtype=torch.long)
            labels = labels.to(device, non_blocking=True, dtype=torch.float)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(input_ids, attention_mask, token_type_ids)
            loss = criterion(y_pred, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]
    if config["log_every_iters"] == 0:
        # Disable logging training metrics:
        metric_names = None
        config["log_every_iters"] = 15

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=utils.get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        log_every_iters=config["log_every_iters"],
        with_pbars=not config["with_clearml"],
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example 15
                                                              127:128])
    loss.backward()
    optimizer.step()
    return loss.item()


# Create Trainer or Evaluators
trainer = Engine(backprop_step)
train_evaluator = create_supervised_evaluator(model,
                                              metrics=metrics,
                                              device=device)
validation_evaluator = create_supervised_evaluator(model,
                                                   metrics=metrics,
                                                   device=device)

trainer.logger = setup_logger("Trainer")
train_evaluator.logger = setup_logger("Train Evaluator")
validation_evaluator.logger = setup_logger("Validation Evaluator")


# Tensorboard Logger setup below based on pytorch ignite example
# https://github.com/pytorch/ignite/blob/master/examples/contrib/mnist/mnist_with_tensorboard_logger.py
@trainer.on(Events.EPOCH_COMPLETED)
def compute_metrics(engine):
    """Callback to compute metrics on the train and validation data."""
    train_evaluator.run(finetuning_loader)
    validation_evaluator.run(test_loader)
    scheduler.step(validation_evaluator.state.metrics['loss'])


def score_function(engine):
Example 16
def run():
    writer = SummaryWriter()

    CUDA = Config.device
    model = Retriever()
    print(f'Initializing model on {CUDA}')
    model.to(CUDA)
    optimizer = torch.optim.Adam(model.parameters(), lr=Config.LR)
    loss_fn = torch.nn.L1Loss().to(CUDA)
    print(f'Creating sentence transformer')
    encoder = SentenceTransformer(Config.sentence_transformer).to(CUDA)
    for parameter in encoder.parameters():
        parameter.requires_grad = False
    print(f'Loading data')
    if os.path.exists('_full_dump'):
        with open('_full_dump', 'rb') as pin:
            train_loader, train_utts, val_loader, val_utts = pickle.load(pin)
    else:
        data = load_data(Config.data_source)
        train_loader, train_utts, val_loader, val_utts = get_loaders(data, encoder, Config.batch_size)
    
        with open('_full_dump', 'wb') as pout:
            pickle.dump((train_loader, train_utts, val_loader, val_utts), pout, protocol=-1)


    def train_step(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, not_ys, y = batch
        yhat = model(x[0])
        loss = loss_fn(yhat, y)
        gains = loss_fn(not_ys[0], yhat) * Config.negative_weight
        loss -= gains

        loss.backward()
        optimizer.step()
        return loss.item()
    
    def eval_step(engine, batch):
        model.eval()
        with torch.no_grad():
            x, _, y = batch
            yhat = model(x[0])
            return yhat, y
    
    trainer = Engine(train_step)
    trainer.logger = setup_logger('trainer')

    evaluator = Engine(eval_step)
    evaluator.logger = setup_logger('evaluator')
    
    latent_space = BallTree(numpy.array(list(train_utts.keys())))

    l1 = Loss(loss_fn)

    recall = RecallAt(latent_space)

    recall.attach(evaluator, 'recall')
    l1.attach(evaluator, 'l1')
    
    @trainer.on(Events.ITERATION_COMPLETED(every=1000))
    def log_training(engine):
        batch_loss = engine.state.output
        lr = optimizer.param_groups[0]['lr']
        e = engine.state.epoch
        n = engine.state.max_epochs
        i = engine.state.iteration
        print("Epoch {}/{} : {} - batch loss: {}, lr: {}".format(e, n, i, batch_loss, lr))
        writer.add_scalar('Training/loss', batch_loss, i)
    
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        print(f"Training Results - Epoch: {engine.state.epoch} " 
              f" L1: {metrics['l1']:.2f} "
              f" R@1: {metrics['r1']:.2f} "
              f" R@3: {metrics['r3']:.2f} "
              f" R@10: {metrics['r10']:.2f} ")

        for metric, value in metrics.items():
            writer.add_scalar(f'Training/{metric}', value, engine.state.epoch)
        
    #@trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        print(f"Validation Results - Epoch: {engine.state.epoch} "
              f"L1: {metrics['l1']:.2f} " 
              f" R@10: {metrics['r10']:.2f} ")
        for metric, value in metrics.items():
            writer.add_scalar(f'Validation/{metric}', value, engine.state.epoch)
 
    trainer.run(train_loader, max_epochs=Config.max_epochs)

    torch.save(model.state_dict(), Config.checkpoint)
    print(f'Saved checkpoint at {Config.checkpoint}')
    interact(model, encoder, latent_space, train_utts)
Example 17
def run(config):
    train_loader = get_instance(utils, 'dataloader', config, 'train')
    val_loader = get_instance(utils, 'dataloader', config, 'val')

    model = get_instance(models, 'arch', config)

    model = init_model(model, train_loader)
    model, device = ModelPrepper(model, config).out

    loss_fn = get_instance(nn, 'loss_fn', config)

    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = get_instance(torch.optim, 'optimizer', config,
                             trainable_params)

    writer = create_summary_writer(config, model, train_loader)
    batch_size = config['dataloader']['args']['batch_size']

    if config['mode'] == 'eval' or config['resume']:
        model.load_state_dict(torch.load(config['ckpt_path']))

    epoch_length = int(ceil(len(train_loader) / batch_size))
    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(initial=0,
                leave=False,
                total=epoch_length,
                desc=desc.format(0))

    def process_batch(engine, batch):
        inputs, outputs = func(batch)
        model.train()
        model.zero_grad()
        optimizer.zero_grad()
        preds = model(inputs)
        loss = loss_fn(preds, outputs.to(device))

        a = list(model.parameters())[0].clone()

        loss.backward()
        optimizer.step()

        # check if training is happening
        b = list(model.parameters())[0].clone()
        try:
            assert not torch.allclose(a.data,
                                      b.data), 'Model not updating anymore'
        except AssertionError:
            plot_grad_flow(model.named_parameters())

        return loss.item()

    def predict_on_batch(engine, batch):
        inputs, outputs = func(batch)
        model.eval()
        with torch.no_grad():
            y_pred = model(inputs)

        return inputs, y_pred, outputs.to(device)

    trainer = Engine(process_batch)
    trainer.logger = setup_logger("trainer")
    evaluator = Engine(predict_on_batch)
    evaluator.logger = setup_logger("evaluator")

    if config['task'] == 'actionpred':
        Accuracy(output_transform=lambda x: (x[1], x[2])).attach(
            evaluator, 'val_acc')

    if config['task'] == 'gazepred':
        MeanSquaredError(output_transform=lambda x: (x[1], x[2])).attach(
            evaluator, 'val_MSE')

    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    training_saver = ModelCheckpoint(config['checkpoint_dir'],
                                     filename_prefix='checkpoint_' +
                                     config['task'],
                                     n_saved=1,
                                     atomic=True,
                                     save_as_state_dict=True,
                                     create_dir=True,
                                     require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, training_saver,
                              {'model': model})

    @trainer.on(Events.ITERATION_COMPLETED)
    def tb_log(engine):
        pbar.desc = desc.format(engine.state.output)
        pbar.update(1)
        writer.add_scalar('training/avg_loss', engine.state.metrics['loss'],
                          engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_trainer_logs(engine):
        pbar.refresh()

        avg_loss = engine.state.metrics['loss']
        tqdm.write('Trainer Results - Epoch {} - Avg loss: {:.2f} \n'.format(
            engine.state.epoch, avg_loss))
        viz_param(writer=writer, model=model, global_step=engine.state.epoch)

        pbar.n = pbar.last_print_n = 0

    @evaluator.on(Events.EPOCH_COMPLETED)
    def print_result(engine):
        try:
            print('Evaluator Results - Accuracy {} \n'.format(
                engine.state.metrics['val_acc']))
        except KeyError:
            print('Evaluator Results - MSE {} \n'.format(
                engine.state.metrics['val_MSE']))

    @evaluator.on(Events.ITERATION_COMPLETED)
    def viz_outputs(engine):
        visualize_outputs(writer=writer,
                          state=engine.state,
                          task=config['task'])

    if config['mode'] == 'train':
        trainer.run(train_loader,
                    max_epochs=config['epochs'],
                    epoch_length=epoch_length)

    pbar.close()

    evaluator.run(val_loader,
                  max_epochs=1,
                  epoch_length=int(ceil(len(val_loader) / batch_size)))

    writer.flush()
    writer.close()
Example 18
def create_trainer(model, optimizer, criterion, train_sampler, config, logger, with_clearml):
    device = config.device
    prepare_batch = data.prepare_image_mask

    # Setup trainer
    accumulation_steps = config.get("accumulation_steps", 1)
    model_output_transform = config.get("model_output_transform", lambda x: x)

    with_amp = config.get("with_amp", True)
    scaler = GradScaler(enabled=with_amp)

    def forward_pass(batch):
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=True)
        with autocast(enabled=with_amp):
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)
            loss = criterion(y_pred, y) / accumulation_steps
        return loss

    def amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        if engine.state.iteration % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    def hvd_amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        optimizer.synchronize()
        with optimizer.skip_synchronize():
            scaler.step(optimizer)
            scaler.update()
        optimizer.zero_grad()

    if idist.backend() == "horovod" and with_amp:
        backward_pass = hvd_amp_backward_pass
    else:
        backward_pass = amp_backward_pass

    def training_step(engine, batch):
        loss = forward_pass(batch)
        output = {"supervised batch loss": loss.item()}
        backward_pass(engine, loss)
        return output

    trainer = Engine(training_step)
    trainer.logger = logger

    output_names = [
        "supervised batch loss",
    ]
    lr_scheduler = config.lr_scheduler

    to_save = {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "trainer": trainer,
        "amp": scaler,
    }

    save_every_iters = config.get("save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        save_handler=utils.get_save_handler(config.output_path.as_posix(), with_clearml),
        lr_scheduler=lr_scheduler,
        output_names=output_names,
        with_pbars=not with_clearml,
        log_every_iters=1,
    )

    resume_from = config.get("resume_from", None)
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example 19
def create_supervised_trainer(model, optimizer, criterion, lr_scheduler,
                              train_sampler, config, logger):
    device = idist.device()

    def _update(engine, batch):

        model.train()

        # x, y = batch[0], batch[1]
        (imgs, targets) = batch

        # if imgs.device != device:
        #    imgs = imgs.to(device, non_blocking=True)
        #    target = target.to(device, non_blocking=True)

        # model.train()
        # (imgs, targets) = batch
        imgs = imgs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # targets = [target.to(device, non_blocking=True) for target in targets
        #            ]  #if torch.cuda.device_count() >= 1 else targets

        outputs = model(imgs)
        # print(outputs.shape)
        # print(targets.shape)
        loss = criterion(outputs, targets)

        # dist_metrics = [reduce_metric_dict(me) for me in _metrics]

        # Compute gradient
        optimizer.zero_grad()
        # loss = sum(total_loss)
        loss.backward()
        optimizer.step()

        # Fetching loss.item() every iteration can slow training down on XLA, so only fetch it every log_every_iters iterations.
        acc1, acc5 = utils.accuracy(outputs, targets, topk=(1, 5))
        if config["log_every_iters"] > 0 and (
                engine.state.iteration - 1) % config["log_every_iters"] == 0:
            batch_loss = loss.item()
            engine.state.saved_batch_loss = batch_loss
        else:
            batch_loss = engine.state.saved_batch_loss
        '''
        if idist.get_rank() == 0:
            print(acc1)
            print(acc5)
            print(batch_loss)
        '''
        return {
            "batch loss": batch_loss,
        }

    trainer = Engine(_update)
    trainer.state.saved_batch_loss = -1.0
    trainer.state_dict_user_keys.append("saved_batch_loss")
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix())
        logger.info("Resume from a checkpoint: {}".format(
            checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example 20
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(x)
            loss = criterion(y_pred, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example 21
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Fetching loss.item() every iteration can slow training down on XLA, so only fetch it every log_every_iters iterations.
        if config["log_every_iters"] > 0 and (
                engine.state.iteration - 1) % config["log_every_iters"] == 0:
            batch_loss = loss.item()
            engine.state.saved_batch_loss = batch_loss
        else:
            batch_loss = engine.state.saved_batch_loss

        return {
            "batch loss": batch_loss,
        }

    trainer = Engine(train_step)
    trainer.state.saved_batch_loss = -1.0
    trainer.state_dict_user_keys.append("saved_batch_loss")
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix())
        logger.info("Resume from a checkpoint: {}".format(
            checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example 22
def create_trainer(
    train_step,
    output_names,
    model,
    ema_model,
    optimizer,
    lr_scheduler,
    supervised_train_loader,
    test_loader,
    cfg,
    logger,
    cta=None,
    unsup_train_loader=None,
    cta_probe_loader=None,
):

    trainer = Engine(train_step)
    trainer.logger = logger

    output_path = os.getcwd()

    to_save = {
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    if cta is not None:
        to_save["cta"] = cta

    common.setup_common_training_handlers(
        trainer,
        train_sampler=supervised_train_loader.sampler,
        to_save=to_save,
        save_every_iters=cfg.solver.checkpoint_every,
        output_path=output_path,
        output_names=output_names,
        lr_scheduler=lr_scheduler,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    ProgressBar(persist=False).attach(
        trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED
    )

    unsupervised_train_loader_iter = None
    if unsup_train_loader is not None:
        unsupervised_train_loader_iter = cycle(unsup_train_loader)

    cta_probe_loader_iter = None
    if cta_probe_loader is not None:
        cta_probe_loader_iter = cycle(cta_probe_loader)

    # Setup handler to prepare data batches
    @trainer.on(Events.ITERATION_STARTED)
    def prepare_batch(e):
        sup_batch = e.state.batch
        e.state.batch = {
            "sup_batch": sup_batch,
        }
        if unsupervised_train_loader_iter is not None:
            unsup_batch = next(unsupervised_train_loader_iter)
            e.state.batch["unsup_batch"] = unsup_batch

        if cta_probe_loader_iter is not None:
            cta_probe_batch = next(cta_probe_loader_iter)
            cta_probe_batch["policy"] = [
                deserialize(p) for p in cta_probe_batch["policy"]
            ]
            e.state.batch["cta_probe_batch"] = cta_probe_batch

    # Setup handler to update EMA model
    @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
    def update_ema_model(ema_decay):
        # EMA on parameters
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)

    # Setup handlers for debugging
    if cfg.debug:

        @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
        @idist.one_rank_only()
        def log_weights_norms():
            wn = []
            ema_wn = []
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                wn.append(torch.mean(param.data))
                ema_wn.append(torch.mean(ema_param.data))

            msg = "\n\nWeights norms"
            msg += "\n- Raw model: {}".format(
                to_list_str(torch.tensor(wn[:10] + wn[-10:]))
            )
            msg += "\n- EMA model: {}\n".format(
                to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))
            )
            logger.info(msg)

            rmn = []
            rvar = []
            ema_rmn = []
            ema_rvar = []
            for m1, m2 in zip(model.modules(), ema_model.modules()):
                if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
                    rmn.append(torch.mean(m1.running_mean))
                    rvar.append(torch.mean(m1.running_var))
                    ema_rmn.append(torch.mean(m2.running_mean))
                    ema_rvar.append(torch.mean(m2.running_var))

            msg = "\n\nBN buffers"
            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
            logger.info(msg)

        # TODO: Need to inspect a bug
        # if idist.get_rank() == 0:
        #     from ignite.contrib.handlers import ProgressBar
        #
        #     profiler = BasicTimeProfiler()
        #     profiler.attach(trainer)
        #
        #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
        #     def log_profiling(_):
        #         results = profiler.get_results()
        #         profiler.print_results(results)

    # Setup validation engine
    metrics = {
        "accuracy": Accuracy(),
    }

    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
        metrics.update({
            "precision": Precision(average=False),
            "recall": Recall(average=False),
        })

    eval_kwargs = dict(
        metrics=metrics,
        prepare_batch=sup_prepare_batch,
        device=idist.device(),
        non_blocking=True,
    )

    evaluator = create_supervised_evaluator(model, **eval_kwargs)
    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)

    def log_results(epoch, max_epochs, metrics, ema_metrics):
        msg1 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()]
        )
        msg2 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()]
        )
        logger.info(
            "\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)
        )
        if cta is not None:
            logger.info("\n" + stats(cta))

    @trainer.on(
        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every)
        | Events.STARTED
        | Events.COMPLETED
    )
    def run_evaluation():
        evaluator.run(test_loader)
        ema_evaluator.run(test_loader)
        log_results(
            trainer.state.epoch,
            trainer.state.max_epochs,
            evaluator.state.metrics,
            ema_evaluator.state.metrics,
        )

    # setup TB logging
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            output_path,
            trainer,
            optimizers=optimizer,
            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
            log_every_iters=15,
        )
        if cfg.online_exp_tracking.wandb:
            from ignite.contrib.handlers import WandBLogger

            wb_dir = Path("/tmp/output-fixmatch-wandb")
            if not wb_dir.exists():
                wb_dir.mkdir()

            _ = WandBLogger(
                project="fixmatch-pytorch",
                name=cfg.name,
                config=cfg,
                sync_tensorboard=True,
                dir=wb_dir.as_posix(),
                reinit=True,
            )

    resume_from = cfg.solver.resume_from
    if resume_from is not None:
        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
        if len(resume_from) > 0:
            # get latest
            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)
            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
                checkpoint_fp.as_posix()
            )
            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
            checkpoint = torch.load(checkpoint_fp.as_posix())
            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.COMPLETED)
    def release_all_resources():
        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter

        if idist.get_rank() == 0:
            tb_logger.close()

        if unsupervised_train_loader_iter is not None:
            unsupervised_train_loader_iter = None

        if cta_probe_loader_iter is not None:
            cta_probe_loader_iter = None

    return trainer
Example 23
def create_trainer(loader, model, opt, loss_fn, device, args):

    def _update(engine, batch):
        model.train()

        x = batch['x'].to(engine.state.device, non_blocking=True)
        y = batch['y'].to(engine.state.device, non_blocking=True)
        m = batch['m'].to(engine.state.device, non_blocking=True)
        opt.zero_grad()
        y_pred = model(x)

        softmax = nn.Softmax(dim=1)
        masked_loss = softmax(y_pred)
        #masked_loss = y_pred*m
        loss = loss_fn(masked_loss, y)
        if m.sum().item() / m.numel() > 0.7:
            loss.backward()
            opt.step()
        masked_loss = (masked_loss>0.5).float()
        acc = accuracy_segmentation(masked_loss[:,1,:,:,:],y[:,1,:,:,:])

        return {
            'x': x.detach(),
            'y': y.detach(),
            'm': m.detach(),
            'y_pred': y_pred.detach(),
            'loss': loss.item(),
            'acc' : acc
        }

    def _inference(engine, batch):
        model.eval()

        with th.no_grad():
            x = batch['x'].to(engine.state.device, non_blocking=True)
            y = batch['y'].to(engine.state.device, non_blocking=True)
            m = batch['m'].to(engine.state.device, non_blocking=True)

            y_pred = model(x)
            
            softmax = nn.Softmax(dim=1)
            masked_loss = softmax(y_pred)
            #masked_loss = y_pred*m
            loss = loss_fn(masked_loss, y)
            masked_loss = (masked_loss[-3:]>0.5).float()
            acc = accuracy_segmentation(masked_loss[:,1,:,:,:],y[:,1,:,:,:])

        return {
            'x': x.detach(),
            'y': y.detach(),
            'm': m.detach(),
            'y_pred': y_pred.detach(),
            'loss': loss.item(),
            'acc' : acc
        }


    #wandb.watch(model, log ='all')

    trainer = Engine(_update)
    evaluator = Engine(_inference)

    profiler = BasicTimeProfiler()
    profiler.attach(trainer)
    logdir = args.logdir
    save_ = (not args.devrun) and (not args.nosave)

    # initialize trainer state
    trainer.state.device = device
    trainer.state.hparams = args
    trainer.state.save = save_
    trainer.state.logdir = logdir

    trainer.state.df = defaultdict(dict)
    trainer.state.metrics = dict()
    trainer.state.val_metrics = dict()
    trainer.state.best_metrics = defaultdict(list)
    trainer.state.gradnorm = defaultdict(dict)

    # initialize evaluator state
    evaluator.logger = setup_logger('evaluator')
    evaluator.state.device = device
    evaluator.state.df = defaultdict(dict)
    evaluator.state.metrics = dict()

    pbar = ProgressBar(persist=True)
    ebar = ProgressBar(persist=False)

    pbar.attach(trainer, ['loss'])
    ebar.attach(evaluator, ['loss'])

    pbar.attach(trainer,['acc'])
    ebar.attach(evaluator,['acc'])

    # model summary
    if args.model_summary:
        trainer.add_event_handler(
            Events.STARTED,
            print_model_summary, model
        )

    # terminate on nan
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED,
        TerminateOnNan(lambda x: x['loss'])
    )

    # metrics
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED,
        _metrics
    )

    evaluator.add_event_handler(
        Events.ITERATION_COMPLETED,
        _metrics
    )

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        _metrics_mean
    )

    evaluator.add_event_handler(
        Events.COMPLETED,
        _metrics_mean
    )

    trainer.add_event_handler(
        #Events.STARTED | Events.EPOCH_COMPLETED,
        Events.EPOCH_COMPLETED,
        _evaluate, evaluator, loader
    )

    # logging
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        _log_metrics
    )

    # early stopping
    if args.early_stopping > 0:
        es_p = args.early_stopping
        es_s = lambda engine: -engine.state.metrics['loss']
        evaluator.add_event_handler(
            Events.COMPLETED,
            EarlyStopping(patience=es_p, score_function=es_s, trainer=trainer)
        )

    # lr schedulers
    if args.epoch_length is None:
        el = len(loader['train'])
    else:
        el = args.epoch_length

    if args.lr_scheduler is not None:
        lr_sched = create_lr_scheduler(opt, args, num_steps=el)

        if args.lr_scheduler != 'plateau':
            def _sched_fun(engine):
                lr_sched.step()
        else:
            def _sched_fun(engine):
                e = engine.state.epoch
                v = engine.state.val_metrics[e]['nmse']
                lr_sched.step(v)

        if args.lr_scheduler == 'linearcycle':
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_sched)
        else:
            trainer.add_event_handler(Events.EPOCH_COMPLETED, _sched_fun)

    # FIXME: warmup is modifying opt base_lr -> must create last
    if args.lr_warmup > 0:
        wsched = create_lr_scheduler(opt, args, 'warmup', num_steps=el)
        wsts = wsched.total_steps
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(event_filter=lambda _, i: i <= wsts),
            lambda _: wsched.step()
        )

    # saving
    if save_:
        to_save = {
            'model': model,
            'optimizer': opt,
            'trainer': trainer,
            'evaluator': evaluator
        }

        trainer.add_event_handler(
            Events.EPOCH_COMPLETED,
            Checkpoint(to_save, DiskSaver(logdir), n_saved=3)
        )

        # handler = Checkpoint(
        #     {'model': model},
        #     DiskSaver(logdir),
        #     n_saved = 3,
        #     filename_prefix = 'best',
        #     score_function = lambda engine: -engine.state.metrics['nmae'],
        #     score_name = 'val_nmae',
        # )

        # evaluator.add_event_handler(
        #     Events.COMPLETED,
        #     handler
        # )

        # handler = Checkpoint(
        #     {'model': model},
        #     DiskSaver(logdir),
        #     n_saved = 3,
        #     filename_prefix = 'best',
        #     score_function = lambda engine: -engine.state.metrics['nmse'],
        #     score_name = 'val_nmse',
        # )

        # evaluator.add_event_handler(
        #     Events.COMPLETED,
        #     handler
        # )

        # handler = Checkpoint(
        #     {'model': model},
        #     DiskSaver(logdir),
        #     n_saved = 3,
        #     filename_prefix = 'best',
        #     score_function = lambda engine: engine.state.metrics['R2'],
        #     score_name = 'val_R2',
        # )

        # evaluator.add_event_handler(
        #     Events.COMPLETED,
        #     handler
        # )

        trainer.add_event_handler(
            Events.EPOCH_COMPLETED,
            _save_metrics
        )

        # timer
        trainer.add_event_handler(
            Events.COMPLETED | Events.TERMINATE,
            lambda _: profiler.write_results(logdir + '/time.csv')
        )

    return trainer