Example #1
def test_disksaver_wrong_input(dirname):

    with pytest.raises(ValueError, match=r"Directory path '\S+' is not found"):
        DiskSaver("/tmp/non-existing-folder", create_dir=False)

    previous_fname = os.path.join(dirname,
                                  '{}_{}_{}.pth'.format(_PREFIX, 'obj', 1))
    with open(previous_fname, 'w') as f:
        f.write("test")

    with pytest.raises(ValueError,
                       match=r"Files are already present in the directory"):
        DiskSaver(dirname, require_empty=True)
Example #2
    def _test_existance(atomic, _to_save, expected):

        saver = DiskSaver(dirname, atomic=atomic, create_dir=False, require_empty=False)
        fname = "test.pth"
        try:
            with warnings.catch_warnings():
                # Ignore torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type
                # DummyModel. It won't be checked for correctness upon loading.
                warnings.simplefilter("ignore", category=UserWarning)
                saver(_to_save, fname)
        except Exception:
            pass
        fp = os.path.join(saver.dirname, fname)
        assert os.path.exists(fp) == expected
        if expected:
            saver.remove(fname)
Example #3
def get_save_handler(output_path, with_clearml):
    if with_clearml:
        from ignite.contrib.handlers.clearml_logger import ClearMLSaver

        return ClearMLSaver(dirname=output_path)

    return DiskSaver(output_path)
Example #4
def save_best_model_by_val_score(output_path,
                                 evaluator,
                                 model,
                                 metric_name,
                                 n_saved=3,
                                 trainer=None,
                                 tag="val",
                                 **kwargs):
    """Method adds a handler to ``evaluator`` to save on a disk ``n_saved`` of best models based on the metric
    (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).

    Args:
        output_path (str): output path to indicate where to save best models
        evaluator (Engine): evaluation engine used to provide the score
        model (nn.Module): model to store
        metric_name (str): metric name to use for score evaluation. This metric should be present in
            `evaluator.state.metrics`.
        n_saved (int, optional): number of best models to store
        trainer (Engine, optional): trainer engine to fetch the epoch when saving the best model.
        tag (str, optional): score name prefix: `{tag}_{metric_name}`. By default, tag is "val".
        **kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.

    Returns:
        A :class:`~ignite.handlers.checkpoint.Checkpoint` handler.
    """
    return gen_save_best_models_by_val_score(
        save_handler=DiskSaver(dirname=output_path, require_empty=False),
        evaluator=evaluator,
        models=model,
        metric_name=metric_name,
        n_saved=n_saved,
        trainer=trainer,
        tag=tag,
        **kwargs,
    )
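
A minimal usage sketch of the helper above. It assumes trainer, evaluator, model and val_loader come from the reader's own setup and that an "accuracy" metric is attached to the evaluator; the output directory is likewise an arbitrary choice.

# Hypothetical wiring; all objects below are placeholders from the reader's own training script.
from ignite.engine import Events

best_ckpt = save_best_model_by_val_score(
    output_path="/tmp/checkpoints",  # assumed output directory
    evaluator=evaluator,
    model=model,
    metric_name="accuracy",
    n_saved=3,
    trainer=trainer,
    tag="val",
)

# Run validation after each training epoch so the handler can score fresh checkpoints.
trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))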
Example #5
def get_save_handler(config):
    if config["with_clearml"]:
        from ignite.contrib.handlers.clearml_logger import ClearMLSaver

        return ClearMLSaver(dirname=config["output_path"])

    return DiskSaver(config["output_path"], require_empty=False)
Example #6
def configure_checkpoint_saving(trainer, evaluator, model, optimizer, args):
    to_save = {"model": model, "optimizer": optimizer}
    save_handler = DiskSaver(str(args.output_dir),
                             create_dir=False,
                             require_empty=False)

    # Configure epoch checkpoints.
    interval = 1 if args.dev_mode else min(5, args.max_epochs)
    checkpoint = Checkpoint(
        to_save,
        save_handler,
        n_saved=None,
        global_step_transform=lambda *_: trainer.state.epoch)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=interval),
                              checkpoint, evaluator)

    # Configure "best score" checkpoints.
    metric_name = "accuracy"
    best_checkpoint = Checkpoint(
        to_save,
        save_handler,
        score_name=metric_name,
        score_function=lambda engine: engine.state.metrics[metric_name],
        filename_prefix="best")
    trainer.add_event_handler(Events.EPOCH_COMPLETED, best_checkpoint,
                              evaluator)
Example #7
def get_save_handler(config):
    if exp_tracking.has_trains:
        from ignite.contrib.handlers.trains_logger import TrainsSaver

        return TrainsSaver(dirname=config.output_path.as_posix())

    return DiskSaver(config.output_path.as_posix())
Example #8
def get_save_handler(config):
    if exp_tracking.has_clearml:
        from ignite.contrib.handlers.clearml_logger import ClearMLSaver

        return ClearMLSaver(dirname=config.output_path.as_posix())

    return DiskSaver(config.output_path.as_posix())
Example #9
    def _test(ext):
        previous_fname = os.path.join(dirname, '{}_{}_{}{}'.format(_PREFIX, 'obj', 1, ext))
        with open(previous_fname, 'w') as f:
            f.write("test")

        with pytest.raises(ValueError, match=r"with extension '.pth' or '.pth.tar' are already present"):
            DiskSaver(dirname, require_empty=True)
Example #10
def get_save_handler(config):
    if config["with_trains"]:
        from ignite.contrib.handlers.trains_logger import TrainsSaver

        return TrainsSaver(dirname=config["output_path"])

    return DiskSaver(config["output_path"], require_empty=False)
Example #11
    def _build_objects(acc_list):

        model = DummyModel().to(device)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.5)

        def update_fn(engine, batch):
            x = torch.rand((4, 1)).to(device)
            optim.zero_grad()
            y = model(x)
            loss = y.pow(2.0).sum()
            loss.backward()
            if idist.has_xla_support:
                import torch_xla.core.xla_model as xm

                xm.optimizer_step(optim, barrier=True)
            else:
                optim.step()
            lr_scheduler.step()

        trainer = Engine(update_fn)

        evaluator = Engine(lambda e, b: None)
        acc_iter = iter(acc_list)

        @evaluator.on(Events.EPOCH_COMPLETED)
        def setup_result():
            evaluator.state.metrics["accuracy"] = next(acc_iter)

        @trainer.on(Events.EPOCH_COMPLETED)
        def run_eval():
            evaluator.run([0, 1, 2])

        def score_function(engine):
            return engine.state.metrics["accuracy"]

        save_handler = DiskSaver(dirname, create_dir=True, require_empty=False)
        early_stop = EarlyStopping(score_function=score_function,
                                   patience=2,
                                   trainer=trainer)
        evaluator.add_event_handler(Events.COMPLETED, early_stop)

        checkpointer = Checkpoint(
            {
                "trainer": trainer,
                "model": model,
                "optim": optim,
                "lr_scheduler": lr_scheduler,
                "early_stop": early_stop,
            },
            save_handler,
            include_self=True,
            global_step_transform=global_step_from_engine(trainer),
        )
        evaluator.add_event_handler(Events.COMPLETED, checkpointer)

        return trainer, evaluator, model, optim, lr_scheduler, early_stop, checkpointer
Example #12
    def run(self, epochs: int = 1):
        trainer = self.trainer
        train_loader = self.dataloader["train"]
        val_loader = self.dataloader["val"]

        @trainer.on(Events.ITERATION_COMPLETED(every=self.log_interval))
        def log_training_loss(engine):
            length = len(train_loader)
            self.logger.info(f"Epoch[{engine.state.epoch}] "
                             f"Iteration[{engine.state.iteration}/{length}] "
                             f"Loss: {engine.state.output:.2f}")
            self.writer.add_scalar("training/loss", engine.state.output,
                                   engine.state.iteration)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_training_results(engine):
            self.evaluator.run(train_loader)
            metrics = self.evaluator.state.metrics
            avg_accuracy = metrics["accuracy"]
            avg_loss = metrics["loss"]
            self.logger.info(
                f"Training Results - Epoch: {engine.state.epoch}  "
                f"Avg accuracy: {avg_accuracy:.2f}"
                f"Avg loss: {avg_loss:.2f}")
            self.writer.add_scalar("training/avg_loss", avg_loss,
                                   engine.state.epoch)
            self.writer.add_scalar("training/avg_accuracy", avg_accuracy,
                                   engine.state.epoch)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            self.evaluator.run(val_loader)
            metrics = self.evaluator.state.metrics
            avg_accuracy = metrics["accuracy"]
            avg_loss = metrics["loss"]
            self.logger.info(
                f"Validation Results - Epoch: {engine.state.epoch} "
                f"Avg accuracy: {avg_accuracy:.2f} "
                f"Avg loss: {avg_loss:.2f}")
            self.writer.add_scalar("valdation/avg_loss", avg_loss,
                                   engine.state.epoch)
            self.writer.add_scalar("valdation/avg_accuracy", avg_accuracy,
                                   engine.state.epoch)

        objects_to_checkpoint = dict(model=self.model,
                                     optimizer=self.optimizer)
        training_checkpoint = Checkpoint(
            to_save=objects_to_checkpoint,
            save_handler=DiskSaver(self.log_dir, require_empty=False),
            n_saved=None,
            global_step_transform=lambda *_: trainer.state.epoch,
        )
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED(every=self.checkpoint_every),
            training_checkpoint,
        )
        trainer.run(train_loader, max_epochs=epochs)
Example #13
    def setup_checkpoint_saver(self, to_save):
        if self.hparams.checkpoint_params is not None:
            from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

            handler = Checkpoint(
                to_save,
                DiskSaver(self.hparams.checkpoint_params["save_dir"], require_empty=False),
                n_saved=self.hparams.checkpoint_params["n_saved"],
                filename_prefix=self.hparams.checkpoint_params["prefix_name"],
                score_function=self.score_function,
                score_name="score",
                global_step_transform=global_step_from_engine(self.trainer))

            self.evaluator.add_event_handler(Events.COMPLETED, handler)
Example #14
def create_supervised_trainer_skipgram(model,
                                       optimizer,
                                       prepare_batch,
                                       metrics={},
                                       device=None,
                                       log_dir='output/log/',
                                       checkpoint_dir='output/checkpoints/',
                                       checkpoint_every=None,
                                       tensorboard_every=50) -> Engine:
    def _prepare_batch(batch):

        return batch

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()

        batch = _prepare_batch(batch)
        batch_loss = model._loss(batch)
        loss = batch_loss.mean()

        loss.backward()
        optimizer.step()

        # The skip-gram update only produces a loss; there are no predictions/targets to report here.
        return {'loss': loss.item()}

    model.to(device)
    engine = Engine(_update)

    # Metrics
    RunningAverage(output_transform=lambda x: x['loss']).attach(
        engine, 'average_loss')

    # TQDM
    pbar = ProgressBar(persist=True, )
    pbar.attach(engine, ['average_loss'])

    # Checkpoint saving
    # to_save = {'model': model, 'optimizer': optimizer, 'engine': engine}
    final_checkpoint_handler = Checkpoint({'model': model},
                                          DiskSaver(checkpoint_dir,
                                                    create_dir=True),
                                          n_saved=None,
                                          filename_prefix='final')

    engine.add_event_handler(Events.COMPLETED, final_checkpoint_handler)

    @engine.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        metrics = engine.state.metrics
        print(f"Epoch results - Avg loss: {metrics['average_loss']:.6f},"
              f" Accuracy: {metrics['accuracy']:.6f},"
              f" Non-Pad-Accuracy: {metrics['non_pad_accuracy']:.6f}")

    return engine
Example #15
    def __init__(self,
                 model: nn.Module,
                 criterion: Callable,
                 optimizers: Union[Optimizer, Sequence[Optimizer]],
                 lr_schedulers: Union[_LRScheduler, Sequence[_LRScheduler]],
                 metrics: Dict[str, Metric],
                 test_metrics: Dict[str, Metric],
                 save_path: Union[Path, str] = ".",
                 fp16: bool = False,
                 lr_step_on_iter: bool = False,
                 device: Optional[str] = None):

        # Check Arguments
        if not isinstance(optimizers, Sequence):
            optimizers = [optimizers]
        if not isinstance(lr_schedulers, Sequence):
            lr_schedulers = [lr_schedulers]
        if device is None:
            device = 'cuda' if CUDA else 'cpu'
        save_path = fmt_path(save_path)
        model.to(device)

        if fp16:
            from apex import amp
            model, optimizers = amp.initialize(model,
                                               optimizers,
                                               opt_level="O1",
                                               verbosity=0)

        # Set Arguments

        self.model = model
        self.criterion = criterion
        self.optimizers = optimizers
        self.lr_schedulers = lr_schedulers
        self.metrics = metrics
        self.test_metrics = test_metrics
        self.save_path = save_path
        self.fp16 = fp16
        self.lr_step_on_iter = lr_step_on_iter
        self.device = device

        self.log_path = self.save_path / "runs"
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        self.writer = SummaryWriter(str(self.log_path / current_time))

        self.train_engine = self._create_train_engine()
        self.eval_engine = self._create_eval_engine()
        saver = DiskSaver(str(self.save_path),
                          create_dir=True,
                          require_empty=False)
        self.checkpoint_handler = Checkpoint(self.to_save(), saver)

        self._traier_state = TrainerState.INITIALIZED
Example #16
def test_setup_common_training_handlers_using_save_handler(dirname, capsys):

    save_handler = DiskSaver(dirname=dirname, require_empty=False)
    _test_setup_common_training_handlers(dirname=None, device="cpu", save_handler=save_handler)

    # Check epoch-wise pbar
    captured = capsys.readouterr()
    out = captured.err.split("\r")
    out = list(map(lambda x: x.strip(), out))
    out = list(filter(None, out))
    assert "Epoch" in out[-1] or "Epoch" in out[-2], f"{out[-2]}, {out[-1]}"
Example #17
def save_checkpoint(trainer, evaluator, to_save, score_function, save_dir,
                    n_saved, prefix_name):
    if save_dir is not None and score_function is not None:
        handler = Checkpoint(
            to_save,
            DiskSaver(save_dir, require_empty=False),
            n_saved=n_saved,
            filename_prefix=prefix_name,
            score_function=score_function,
            score_name="score",
            global_step_transform=global_step_from_engine(trainer))

        evaluator.add_event_handler(Events.COMPLETED, handler)
Example #18
def test_disksaver_wrong_input(dirname):

    with pytest.raises(ValueError, match=r"Directory path '\S+' is not found"):
        DiskSaver("/tmp/non-existing-folder", create_dir=False)

    def _test(ext):
        previous_fname = os.path.join(dirname, "{}_{}_{}{}".format(_PREFIX, "obj", 1, ext))
        with open(previous_fname, "w") as f:
            f.write("test")

        with pytest.raises(ValueError, match=r"with extension '.pt' are already present"):
            DiskSaver(dirname, require_empty=True)

    _test(".pt")
Example #19
def setup_checkpoints(trainer, obj_to_save, epoch_length, conf):
    # type: (Engine, Dict[str, Any], int, DictConfig) -> None
    cp = conf.checkpoints
    save_path = cp.get('save_dir', os.getcwd())
    logging.info("Saving checkpoints to {}".format(save_path))
    max_cp = max(int(cp.get('max_checkpoints', 1)), 1)
    save = DiskSaver(save_path, create_dir=True, require_empty=True)
    make_checkpoint = Checkpoint(obj_to_save, save, n_saved=max_cp)
    cp_iter = cp.interval_iteration
    cp_epoch = cp.interval_epoch
    if cp_iter > 0:
        save_event = Events.ITERATION_COMPLETED(every=cp_iter)
        trainer.add_event_handler(save_event, make_checkpoint)
    if cp_epoch > 0:
        if cp_iter < 1 or epoch_length % cp_iter:
            save_event = Events.EPOCH_COMPLETED(every=cp_epoch)
            trainer.add_event_handler(save_event, make_checkpoint)
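
A sketch of the configuration this helper expects, built with OmegaConf to match the DictConfig type hint; the values below are illustrative assumptions only.

# Illustrative config; keys mirror the cp.* lookups in setup_checkpoints above.
from omegaconf import OmegaConf

conf = OmegaConf.create({
    "checkpoints": {
        "save_dir": "./checkpoints",   # falls back to os.getcwd() when omitted
        "max_checkpoints": 3,          # clamped to at least 1 and used as n_saved
        "interval_iteration": 0,       # 0 disables iteration-based checkpoints
        "interval_epoch": 1,           # save at the end of every epoch
    }
})
# setup_checkpoints(trainer, {"model": model}, epoch_length=len(train_loader), conf=conf)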
Example #20
def save_best_model_by_val_score(output_path,
                                 evaluator,
                                 model,
                                 metric_name,
                                 n_saved=3,
                                 trainer=None,
                                 tag="val"):
    """Method adds a handler to `evaluator` to save best models based on the score (named by `metric_name`)
    provided by `evaluator`.

    Args:
        output_path (str): output path to indicate where to save best models
        evaluator (Engine): evaluation engine used to provide the score
        model (nn.Module): model to store
        metric_name (str): metric name to use for score evaluation. This metric should be present in
            `evaluator.state.metrics`.
        n_saved (int, optional): number of best models to store
        trainer (Engine, optional): trainer engine to fetch the epoch when saving the best model.
        tag (str, optional): score name prefix: `{tag}_{metric_name}`. By default, tag is "val".

    Returns:
        A :class:`~ignite.handlers.checkpoint.Checkpoint` handler.
    """
    global_step_transform = None
    if trainer is not None:
        global_step_transform = global_step_from_engine(trainer)

    best_model_handler = Checkpoint(
        {
            "model": model,
        },
        DiskSaver(dirname=output_path, require_empty=False),
        filename_prefix="best",
        n_saved=n_saved,
        global_step_transform=global_step_transform,
        score_name="{}_{}".format(tag, metric_name.lower()),
        score_function=get_default_score_fn(metric_name),
    )
    evaluator.add_event_handler(
        Events.COMPLETED,
        best_model_handler,
    )

    return best_model_handler
Example #21
def set_handlers(trainer: Engine, evaluator: Engine, valloader: DataLoader,
                 model: nn.Module, optimizer: optim.Optimizer,
                 args: Namespace) -> None:
    ROC_AUC(
        output_transform=lambda output: (output.logit, output.label)).attach(
            engine=evaluator, name='roc_auc')
    Accuracy(output_transform=lambda output: (
        (output.logit > 0).long(), output.label)).attach(engine=evaluator,
                                                         name='accuracy')
    Loss(loss_fn=nn.BCEWithLogitsLoss(),
         output_transform=lambda output:
         (output.logit, output.label.float())).attach(engine=evaluator,
                                                      name='loss')

    ProgressBar(persist=True, desc='Epoch').attach(
        engine=trainer, output_transform=lambda output: {'loss': output.loss})
    ProgressBar(persist=False, desc='Eval').attach(engine=evaluator)
    ProgressBar(persist=True, desc='Eval').attach(
        engine=evaluator,
        metric_names=['roc_auc', 'accuracy', 'loss'],
        event_name=Events.EPOCH_COMPLETED,
        closing_event_name=Events.COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED(every=args.evaluation_interval))
    def _evaluate(trainer: Engine):
        evaluator.run(valloader, max_epochs=1)

    evaluator.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=Checkpoint(
            to_save={
                'model': model,
                'optimizer': optimizer,
                'trainer': trainer
            },
            save_handler=DiskSaver(dirname=args.checkpoint_dir,
                                   atomic=True,
                                   create_dir=True,
                                   require_empty=False),
            filename_prefix='best',
            score_function=lambda engine: engine.state.metrics['roc_auc'],
            score_name='val_roc_auc',
            n_saved=1,
            global_step_transform=global_step_from_engine(trainer)))
Example #22
    def __init__(self,
                 model,
                 criterion,
                 optimizer_model,
                 optimizer_arch,
                 lr_scheduler,
                 metrics=None,
                 test_metrics=None,
                 save_path="checkpoints",
                 device=None):
        self.device = device or ('cuda' if CUDA else 'cpu')
        model.to(self.device)

        self.model = model
        self.criterion = criterion
        self.optimizer_model = optimizer_model
        self.optimizer_arch = optimizer_arch
        self.lr_scheduler = lr_scheduler
        self._output_transform = get(["y_pred", "y"])
        self.metrics = metrics or {
            "loss": TrainLoss(),
            "acc": Accuracy(self._output_transform),
        }
        self.test_metrics = test_metrics or {
            "loss": Loss(self.criterion, self._output_transform),
            "acc": Accuracy(self._output_transform),
        }
        self.save_path = save_path
        self._log_path = os.path.join(self.save_path, "runs")

        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        log_dir = os.path.join(self._log_path, current_time)
        self.writer = SummaryWriter(log_dir)

        self.train_engine = self._create_train_engine()
        self.eval_engine = self._create_eval_engine()
        self.checkpoint_handler = Checkpoint(
            self.to_save(),
            DiskSaver(self.save_path, create_dir=True, require_empty=False))
Example #23
    def save(self):
        train_engine = self._create_train_engine()
        eval_engine = self._create_eval_engine()

        train_engine.load_state_dict(self._train_engine_state)
        eval_engine.load_state_dict(self._eval_engine_state)

        def global_step_transform(engine, event_name):
            return engine.state.epoch

        saver = DiskSaver(str(self.save_path),
                          create_dir=True,
                          require_empty=False)
        to_save = {
            **self.to_save(), "train_engine": train_engine,
            "eval_engine": eval_engine
        }

        checkpoint_handler = Checkpoint(
            to_save,
            saver,
            n_saved=1,
            global_step_transform=global_step_transform)
        checkpoint_handler(train_engine)
Example #24
    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_results)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)

    #### Checkpoint

    # to_save = {'{}_{}'.format(p_name, m_name): model,
    #           'optimizer': optimizer,
    #           'lr_scheduler': scheduler
    #           }

    to_save = {'api_{}'.format(m_name): model}

    cp_handler = Checkpoint(to_save,
                            DiskSaver('../models/',
                                      create_dir=True,
                                      require_empty=False),
                            filename_prefix='best',
                            score_function=score_function_loss,
                            score_name='val_loss')

    validation_evaluator.add_event_handler(Events.COMPLETED, cp_handler)
    #trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), cp_handler)

    # checkpointer = ModelCheckpoint('../models/', '{}'.format(p_name), create_dir=True, save_as_state_dict=True, require_empty=False)

    # trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    trainer.run(train_iterator, max_epochs=4)
else:
    print('Running saved model...')
    #run_on_test(cur_model, p_name, m_name, pred_iterator)
Example #25
def main(args):
    fix_seeds()
    # if os.path.exists('./logs'):
    #     shutil.rmtree('./logs')
    # os.mkdir('./logs')
    # writer = SummaryWriter(log_dir='./logs')
    vis = visdom.Visdom()
    val_avg_loss_window = create_plot_window(vis,
                                             '#Epochs',
                                             'Loss',
                                             'Average Loss',
                                             legend=['Train', 'Val'])
    val_avg_accuracy_window = create_plot_window(vis,
                                                 '#Epochs',
                                                 'Accuracy',
                                                 'Average Accuracy',
                                                 legend=['Val'])
    size = (args.height, args.width)
    train_transform = transforms.Compose([
        transforms.Resize(size),
        # transforms.RandomResizedCrop(size=size, scale=(0.5, 1)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(10,
                                translate=(0.1, 0.1),
                                scale=(0.8, 1.2),
                                resample=PIL.Image.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    train_dataset = TextDataset(args.data_path,
                                'train.txt',
                                size=args.train_size,
                                transform=train_transform)
    val_dataset = TextDataset(args.data_path,
                              'val.txt',
                              size=args.val_size,
                              transform=val_transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            num_workers=args.workers,
                            shuffle=False)

    model = models.resnet18(pretrained=False)
    model.fc = nn.Linear(512, 16)

    model.load_state_dict(torch.load(args.resume_from)['model'])

    device = 'cpu'
    if args.cuda:
        device = 'cuda'
    print(device)
    metrics = {'accuracy': Accuracy(), 'loss': Loss(criterion)}
    evaluator = create_supervised_evaluator(model, metrics, device=device)

    @trainer.on(Events.ITERATION_COMPLETED)
    def lr_step(engine):
        if model.training:
            scheduler.step()

    global pbar, desc
    pbar, desc = None, None

    @trainer.on(Events.EPOCH_STARTED)
    def create_train_pbar(engine):
        global desc, pbar
        if pbar is not None:
            pbar.close()
        desc = 'Train iteration - loss: {:.4f} - lr: {:.4f}'
        pbar = tqdm(initial=0,
                    leave=False,
                    total=len(train_loader),
                    desc=desc.format(0, lr))

    @trainer.on(Events.EPOCH_COMPLETED)
    def create_val_pbar(engine):
        global desc, pbar
        if pbar is not None:
            pbar.close()
        desc = 'Validation iteration - loss: {:.4f}'
        pbar = tqdm(initial=0,
                    leave=False,
                    total=len(val_loader),
                    desc=desc.format(0))

    # desc_val = 'Validation iteration - loss: {:.4f}'
    # pbar_val = tqdm(initial=0, leave=False, total=len(val_loader), desc=desc_val.format(0))

    log_interval = 1
    e = Events.ITERATION_COMPLETED(every=log_interval)

    train_losses = []

    @trainer.on(e)
    def log_training_loss(engine):
        lr = optimizer.param_groups[0]['lr']
        train_losses.append(engine.state.output)
        pbar.desc = desc.format(engine.state.output, lr)
        pbar.update(log_interval)
        # writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)
        # writer.add_scalar("lr", lr, engine.state.iteration)

    @evaluator.on(e)
    def log_validation_loss(engine):
        label = engine.state.batch[1].to(device)
        output = engine.state.output[0]
        pbar.desc = desc.format(criterion(output, label))
        pbar.update(log_interval)

    # if args.resume_from is not None:
    #     @trainer.on(Events.STARTED)
    #     def _(engine):
    #         pbar.n = engine.state.iteration

    # @trainer.on(Events.EPOCH_COMPLETED(every=1))
    # def log_train_results(engine):
    #     evaluator.run(train_loader) # eval on train set to check for overfitting
    #     metrics = evaluator.state.metrics
    #     avg_accuracy = metrics['accuracy']
    #     avg_nll = metrics['loss']
    #     tqdm.write(
    #         "Train Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
    #         .format(engine.state.epoch, avg_accuracy, avg_nll))
    #     pbar.n = pbar.last_print_n = 0

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        pbar.refresh()
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['loss']
        tqdm.write(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        # pbar.n = pbar.last_print_n = 0

        # writer.add_scalars("avg losses", {"train": statistics.mean(train_losses),
        #                                   "valid": avg_nll}, engine.state.epoch)
        # # writer.add_scalar("valdation/avg_loss", avg_nll, engine.state.epoch)
        # writer.add_scalar("avg_accuracy", avg_accuracy, engine.state.epoch)
        vis.line(X=np.array([engine.state.epoch]),
                 Y=np.array([avg_accuracy]),
                 win=val_avg_accuracy_window,
                 update='append')
        vis.line(X=np.column_stack(
            (np.array([engine.state.epoch]), np.array([engine.state.epoch]))),
                 Y=np.column_stack((np.array([statistics.mean(train_losses)]),
                                    np.array([avg_nll]))),
                 win=val_avg_loss_window,
                 update='append',
                 opts=dict(legend=['Train', 'Val']))
        del train_losses[:]

    objects_to_checkpoint = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler
    }
    training_checkpoint = Checkpoint(to_save=objects_to_checkpoint,
                                     save_handler=DiskSaver(
                                         args.snapshot_dir,
                                         require_empty=False))
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                              training_checkpoint)
    if args.resume_from not in [None, '']:
        tqdm.write("Resume from a checkpoint: {}".format(args.resume_from))
        checkpoint = torch.load(args.resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint,
                                checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=args.epochs)
        pbar.close()
    except Exception as e:
        import traceback
        print(traceback.format_exc())
Example #26
def run(
    train_batch_size,
    val_batch_size,
    epochs,
    lr,
    momentum,
    log_interval,
    log_dir,
    checkpoint_every,
    resume_from,
    crash_iteration=-1,
    deterministic=False,
):
    # Setup seed to have same model's initialization:
    manual_seed(75)

    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    writer = SummaryWriter(log_dir=log_dir)
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    criterion = nn.NLLLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

    # Setup trainer and evaluator
    if deterministic:
        tqdm.write("Setup deterministic trainer")
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device,
                                        deterministic=deterministic)

    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                "accuracy": Accuracy(),
                                                "nll": Loss(criterion)
                                            },
                                            device=device)

    # Apply learning rate scheduling
    @trainer.on(Events.EPOCH_COMPLETED)
    def lr_step(engine):
        lr_scheduler.step()

    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=f"Epoch {0} - loss: {0:.4f} - lr: {lr:.4f}")

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        lr = optimizer.param_groups[0]["lr"]
        pbar.desc = f"Epoch {engine.state.epoch} - loss: {engine.state.output:.4f} - lr: {lr:.4f}"
        pbar.update(log_interval)
        writer.add_scalar("training/loss", engine.state.output,
                          engine.state.iteration)
        writer.add_scalar("lr", lr, engine.state.iteration)

    if crash_iteration > 0:

        @trainer.on(Events.ITERATION_COMPLETED(once=crash_iteration))
        def _(engine):
            raise Exception(f"STOP at {engine.state.iteration}")

    if resume_from is not None:

        @trainer.on(Events.STARTED)
        def _(engine):
            pbar.n = engine.state.iteration % engine.state.epoch_length

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            f"Training Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    # Compute and log validation metrics
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            f"Validation Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        pbar.n = pbar.last_print_n = 0
        writer.add_scalar("valdation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("valdation/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    # Setup object to checkpoint
    objects_to_checkpoint = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    training_checkpoint = Checkpoint(
        to_save=objects_to_checkpoint,
        save_handler=DiskSaver(log_dir, require_empty=False),
        n_saved=None,
        global_step_transform=lambda *_: trainer.state.epoch,
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=checkpoint_every),
                              training_checkpoint)

    # Setup logger to print and dump into file: model weights, model grads and data stats
    # - first 3 iterations
    # - 4 iterations after checkpointing
    # This helps to compare resumed training with checkpointed training
    def log_event_filter(e, event):
        if event in [1, 2, 3]:
            return True
        elif 0 <= (event % (checkpoint_every * e.state.epoch_length)) < 5:
            return True
        return False

    fp = Path(log_dir) / ("run.log"
                          if resume_from is None else "resume_run.log")
    fp = fp.as_posix()
    for h in [log_data_stats, log_model_weights, log_model_grads]:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(event_filter=log_event_filter),
            h,
            model=model,
            fp=fp)

    if resume_from is not None:
        tqdm.write(f"Resume from the checkpoint: {resume_from}")
        checkpoint = torch.load(resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint,
                                checkpoint=checkpoint)

    try:
        # Synchronize random states
        manual_seed(15)
        trainer.run(train_loader, max_epochs=epochs)
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    pbar.close()
    writer.close()
Example #27
        print(timestamp + " Validation set Results - Epoch: {}  Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
              .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['loss']))
        experiment.log_metric("valid_mae", metrics['mae'])
        experiment.log_metric("valid_mse", metrics['mse'])
        experiment.log_metric("valid_loss", metrics['loss'])

        # timer
        experiment.log_metric("evaluate_timer", evaluate_timer.value())
        print("evaluate_timer ", evaluate_timer.value())

    def checkpoint_valid_mae_score_function(engine):
        score = engine.state.metrics['mae']
        return -score


    # docs on save and load
    to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer}
    save_handler = Checkpoint(to_save,
                              DiskSaver('saved_model/' + args.task_id, create_dir=True, atomic=True),
                              filename_prefix=args.task_id,
                              n_saved=5)

    save_handler_best = Checkpoint(to_save,
                                   DiskSaver('saved_model_best/' + args.task_id, create_dir=True, atomic=True),
                                   filename_prefix=args.task_id,
                                   score_name="valid_mae",
                                   score_function=checkpoint_valid_mae_score_function,
                                   n_saved=5)

    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=5), save_handler)
    evaluator_validate.add_event_handler(Events.EPOCH_COMPLETED(every=1), save_handler_best)


    trainer.run(train_loader, max_epochs=args.epochs)
Example #28
def main():
    # region Setup
    conf = parse_args()
    setup_seeds(conf.session.seed)
    tb_logger, tb_img_logger, json_logger = setup_all_loggers(conf)
    logger.info("Parsed configuration:\n" +
                pyaml.dump(OmegaConf.to_container(conf),
                           safe=True,
                           sort_dicts=False,
                           force_embed=True))

    # region Predicate classification engines
    datasets, dataset_metadata = build_datasets(conf.dataset)
    dataloaders = build_dataloaders(conf, datasets)

    model = build_model(conf.model,
                        dataset_metadata["train"]).to(conf.session.device)
    criterion = PredicateClassificationCriterion(conf.losses)

    pred_class_trainer = Trainer(pred_class_training_step, conf)
    pred_class_trainer.model = model
    pred_class_trainer.criterion = criterion
    pred_class_trainer.optimizer, scheduler = build_optimizer_and_scheduler(
        conf.optimizer, pred_class_trainer.model)

    pred_class_validator = Validator(pred_class_validation_step, conf)
    pred_class_validator.model = model
    pred_class_validator.criterion = criterion

    pred_class_tester = Validator(pred_class_validation_step, conf)
    pred_class_tester.model = model
    pred_class_tester.criterion = criterion
    # endregion

    if "resume" in conf:
        checkpoint = Path(conf.resume.checkpoint).expanduser().resolve()
        logger.debug(f"Resuming checkpoint from {checkpoint}")
        Checkpoint.load_objects(
            {
                "model": pred_class_trainer.model,
                "optimizer": pred_class_trainer.optimizer,
                "scheduler": scheduler,
                "trainer": pred_class_trainer,
            },
            checkpoint=torch.load(checkpoint,
                                  map_location=conf.session.device),
        )
        logger.info(f"Resumed from {checkpoint}, "
                    f"epoch {pred_class_trainer.state.epoch}, "
                    f"samples {pred_class_trainer.global_step()}")
    # endregion

    # region Predicate classification training callbacks
    def increment_samples(trainer: Trainer):
        images = trainer.state.batch[0]
        trainer.state.samples += len(images)

    pred_class_trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                         increment_samples)

    ProgressBar(persist=True, desc="Pred class train").attach(
        pred_class_trainer, output_transform=itemgetter("losses"))

    tb_logger.attach(
        pred_class_trainer,
        OptimizerParamsHandler(
            pred_class_trainer.optimizer,
            param_name="lr",
            tag="z",
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.EPOCH_STARTED,
    )

    pred_class_trainer.add_event_handler(
        Events.ITERATION_COMPLETED,
        PredicateClassificationMeanAveragePrecisionBatch())
    pred_class_trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                         RecallAtBatch(sizes=(5, 10)))

    tb_logger.attach(
        pred_class_trainer,
        OutputHandler(
            "train",
            output_transform=lambda o: {
                **o["losses"],
                "pc/mAP": o["pc/mAP"].mean().item(),
                **{k: r.mean().item()
                   for k, r in o["recalls"].items()},
            },
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.ITERATION_COMPLETED,
    )

    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate classification training",
        "train",
        json_logger=None,
        tb_logger=tb_logger,
        global_step_fn=pred_class_trainer.global_step,
    )
    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        PredicateClassificationLogger(
            grid=(2, 3),
            tag="train",
            logger=tb_img_logger.writer,
            metadata=dataset_metadata["train"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    tb_logger.attach(
        pred_class_trainer,
        EpochHandler(
            pred_class_trainer,
            tag="z",
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.EPOCH_COMPLETED,
    )

    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda _: pred_class_validator.run(dataloaders["val"]))
    # endregion

    # region Predicate classification validation callbacks
    ProgressBar(persist=True,
                desc="Pred class val").attach(pred_class_validator)

    if conf.losses["bce"]["weight"] > 0:
        Average(output_transform=lambda o: o["losses"]["loss/bce"]).attach(
            pred_class_validator, "loss/bce")
    if conf.losses["rank"]["weight"] > 0:
        Average(output_transform=lambda o: o["losses"]["loss/rank"]).attach(
            pred_class_validator, "loss/rank")
    Average(output_transform=lambda o: o["losses"]["loss/total"]).attach(
        pred_class_validator, "loss/total")

    PredicateClassificationMeanAveragePrecisionEpoch(
        itemgetter("target", "output")).attach(pred_class_validator, "pc/mAP")
    RecallAtEpoch((5, 10),
                  itemgetter("target",
                             "output")).attach(pred_class_validator,
                                               "pc/recall_at")

    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda val_engine: scheduler.step(val_engine.state.metrics["loss/total"
                                                                   ]),
    )
    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate classification validation",
        "val",
        json_logger,
        tb_logger,
        pred_class_trainer.global_step,
    )
    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        PredicateClassificationLogger(
            grid=(2, 3),
            tag="val",
            logger=tb_img_logger.writer,
            metadata=dataset_metadata["val"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    pred_class_validator.add_event_handler(
        Events.COMPLETED,
        EarlyStopping(
            patience=conf.session.early_stopping.patience,
            score_function=lambda val_engine: -val_engine.state.metrics[
                "loss/total"],
            trainer=pred_class_trainer,
        ),
    )
    pred_class_validator.add_event_handler(
        Events.COMPLETED,
        Checkpoint(
            {
                "model": pred_class_trainer.model,
                "optimizer": pred_class_trainer.optimizer,
                "scheduler": scheduler,
                "trainer": pred_class_trainer,
            },
            DiskSaver(
                Path(conf.checkpoint.folder).expanduser().resolve() /
                conf.fullname),
            score_function=lambda val_engine: val_engine.state.metrics[
                "pc/recall_at_5"],
            score_name="pc_recall_at_5",
            n_saved=conf.checkpoint.keep,
            global_step_transform=pred_class_trainer.global_step,
        ),
    )
    # endregion

    if "test" in conf.dataset:
        # region Predicate classification testing callbacks
        if conf.losses["bce"]["weight"] > 0:
            Average(
                output_transform=lambda o: o["losses"]["loss/bce"],
                device=conf.session.device,
            ).attach(pred_class_tester, "loss/bce")
        if conf.losses["rank"]["weight"] > 0:
            Average(
                output_transform=lambda o: o["losses"]["loss/rank"],
                device=conf.session.device,
            ).attach(pred_class_tester, "loss/rank")
        Average(
            output_transform=lambda o: o["losses"]["loss/total"],
            device=conf.session.device,
        ).attach(pred_class_tester, "loss/total")

        PredicateClassificationMeanAveragePrecisionEpoch(
            itemgetter("target", "output")).attach(pred_class_tester, "pc/mAP")
        RecallAtEpoch((5, 10),
                      itemgetter("target",
                                 "output")).attach(pred_class_tester,
                                                   "pc/recall_at")

        ProgressBar(persist=True,
                    desc="Pred class test").attach(pred_class_tester)

        pred_class_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            log_metrics,
            "Predicate classification test",
            "test",
            json_logger,
            tb_logger,
            pred_class_trainer.global_step,
        )
        pred_class_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            PredicateClassificationLogger(
                grid=(2, 3),
                tag="test",
                logger=tb_img_logger.writer,
                metadata=dataset_metadata["test"],
                global_step_fn=pred_class_trainer.global_step,
            ),
        )
        # endregion

    # region Run
    log_effective_config(conf, pred_class_trainer, tb_logger)
    if not ("resume" in conf and conf.resume.test_only):
        max_epochs = conf.session.max_epochs
        if "resume" in conf:
            max_epochs += pred_class_trainer.state.epoch
        pred_class_trainer.run(
            dataloaders["train"],
            max_epochs=max_epochs,
            seed=conf.session.seed,
            epoch_length=len(dataloaders["train"]),
        )

    if "test" in conf.dataset:
        pred_class_tester.run(dataloaders["test"])

    add_session_end(tb_logger.writer, "SUCCESS")
    tb_logger.close()
    tb_img_logger.close()
Example #29
def _setup_common_training_handlers(
    trainer: Engine,
    to_save: Optional[Mapping] = None,
    save_every_iters: int = 1000,
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
    with_gpu_stats: bool = False,
    output_names: Optional[Iterable[str]] = None,
    with_pbars: bool = True,
    with_pbar_on_iters: bool = True,
    log_every_iters: int = 100,
    stop_on_nan: bool = True,
    clear_cuda_cache: bool = True,
    save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
    **kwargs: Any,
) -> None:
    if output_path is not None and save_handler is not None:
        raise ValueError(
            "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
        )

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED,
                lambda engine: cast(_LRScheduler, lr_scheduler).step())
        elif isinstance(lr_scheduler, LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:

        if output_path is None and save_handler is None:
            raise ValueError(
                "If to_save argument is provided then output_path or save_handler arguments should be also defined"
            )
        if output_path is not None:
            save_handler = DiskSaver(dirname=output_path, require_empty=False)

        checkpoint_handler = Checkpoint(to_save,
                                        cast(Union[Callable, BaseSaveHandler],
                                             save_handler),
                                        filename_prefix="training",
                                        **kwargs)
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(every=save_every_iters),
            checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(
            trainer,
            name="gpu",
            event_name=Events.ITERATION_COMPLETED(
                every=log_every_iters)  # type: ignore[arg-type]
        )

    if output_names is not None:

        def output_transform(x: Any, index: int, name: str) -> Any:
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise TypeError(
                    "Unhandled type of update_function's output. "
                    f"It should either mapping or sequence, but given {type(x)}"
                )

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform,
                                                    index=i,
                                                    name=n),
                           epoch_bound=False).attach(trainer, n)

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer,
                metric_names="all",
                event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

        ProgressBar(persist=True,
                    bar_format="").attach(trainer,
                                          event_name=Events.EPOCH_STARTED,
                                          closing_event_name=Events.COMPLETED)
Example #30
def train(model, optimizer, loss_fn, train_loader, val_loader,
          log_dir, device, epochs, log_interval,
          load_weight_path=None, save_graph=False):
    """Training logic for the wavelet model

    Arguments:
        model {pytorch model}       -- the model to be trained
        optimizer {torch optim}     -- optimiser to be used
        loss_fn {callable}          -- loss function to be used
        train_loader {dataloader}   -- training dataloader
        val_loader {dataloader}     -- validation dataloader
        log_dir {str}               -- the log directory
        device {torch.device}       -- the device to be used e.g. cpu or cuda
        epochs {int}                -- the number of epochs
        log_interval {int}          -- the log interval for train batch loss

    Keyword Arguments:
        load_weight_path {str} -- Model weight path to be loaded (default: {None})
        save_graph {bool}      -- whether to save the model graph (default: {False})

    Returns:
        None
    """
    model.to(device)
    if load_weight_path is not None:
        model.load_state_dict(torch.load(load_weight_path))

    optimizer = optimizer(model.parameters())

    def process_function(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, _ = batch
        x = x.to(device)
        y = model(x)
        loss = loss_fn(y, x)
        loss.backward()
        optimizer.step()
        return loss.item()

    def evaluate_function(engine, batch):
        model.eval()
        with torch.no_grad():
            x, _ = batch
            x = x.to(device)
            y = model(x)
            loss = loss_fn(y, x)
            return loss.item()

    trainer = Engine(process_function)
    evaluator = Engine(evaluate_function)

    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')
    RunningAverage(output_transform=lambda x: x).attach(evaluator, 'loss')


    writer = create_summary_writer(model, train_loader, log_dir,
                                   save_graph, device)

    def score_function(engine):
        return -engine.state.metrics['loss']

    to_save = {'model': model}
    handler = Checkpoint(
        to_save,
        DiskSaver(os.path.join(log_dir, 'models'), create_dir=True),
        n_saved=5, filename_prefix='best', score_function=score_function,
        score_name="loss",
        global_step_transform=global_step_from_engine(trainer))

    evaluator.add_event_handler(Events.COMPLETED, handler)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        print(
            f"Epoch[{engine.state.epoch}] Iteration[{engine.state.iteration}/"
            f"{len(train_loader)}] Loss: {engine.state.output:.3f}"
        )
        writer.add_scalar("training/loss", engine.state.output,
                          engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_loss = metrics["loss"]
        print(
            f"Training Results - Epoch: {engine.state.epoch} Avg loss: {avg_loss:.3f}"
        )
        writer.add_scalar("training/avg_loss", avg_loss, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_loss = metrics["loss"]

        print(
            f"Validation Results - Epoch: {engine.state.epoch} Avg loss: {avg_loss:.3f}"
        )
        writer.add_scalar("validation/avg_loss", avg_loss, engine.state.epoch)

    trainer.run(train_loader, max_epochs=epochs)

    writer.close()
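
A hypothetical call to the train helper above. Note that optimizer is passed as a factory and only bound to model.parameters() inside train; Net and make_loaders are placeholders for the reader's own model and data pipeline.

# Placeholder model and data; only the call signature is the point here.
import torch
import torch.nn as nn

model = Net()                              # assumed model class
train_loader, val_loader = make_loaders()  # assumed dataloader factory

train(model,
      optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
      loss_fn=nn.MSELoss(),
      train_loader=train_loader,
      val_loader=val_loader,
      log_dir="./logs",
      device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
      epochs=10,
      log_interval=50)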