示例#1
0
文件: common.py 项目: realathu/ignite
def _setup_common_distrib_training_handlers(
    trainer,
    train_sampler=None,
    to_save=None,
    save_every_iters=1000,
    output_path=None,
    lr_scheduler=None,
    with_gpu_stats=True,
    output_names=None,
    with_pbars=True,
    with_pbar_on_iters=True,
    log_every_iters=100,
    device="cuda",
):
    """Attach the common training handlers to ``trainer`` for a distributed run.

    The shared handlers are delegated to ``_setup_common_training_handlers``
    without checkpointing, and with progress bars enabled on rank 0 only.
    On top of that:

    * if ``train_sampler`` is given, its ``set_epoch`` is called with
      ``engine.state.epoch - 1`` at the start of every epoch;
    * if ``to_save`` is given, a ``ModelCheckpoint`` writing into
      ``output_path`` runs every ``save_every_iters`` iterations, on rank 0 only.

    Raises:
        RuntimeError: if ``torch.distributed`` is unavailable or not initialized.
        ValueError: if ``to_save`` is given without ``output_path``.
        TypeError: if ``train_sampler`` has no callable ``set_epoch``.
    """
    if not (dist.is_available() and dist.is_initialized()):
        raise RuntimeError(
            "Distributed setting is not initialized, please call `dist.init_process_group` before."
        )

    # Validate arguments up front and on *every* rank. Checking only on rank 0
    # would let rank 0 raise while the other ranks carry on without it. Checking
    # the sampler first also avoids attaching handlers before failing.
    if to_save is not None and output_path is None:
        raise ValueError(
            "If to_save argument is provided then output_path argument should be also defined"
        )
    if train_sampler is not None and not callable(getattr(train_sampler, "set_epoch", None)):
        raise TypeError("Train sampler should have `set_epoch` method")

    is_main_process = dist.get_rank() == 0

    _setup_common_training_handlers(
        trainer,
        to_save=None,  # checkpointing is set up below, on rank 0 only
        lr_scheduler=lr_scheduler,
        with_gpu_stats=with_gpu_stats,
        output_names=output_names,
        with_pbars=is_main_process and with_pbars,
        with_pbar_on_iters=with_pbar_on_iters,
        log_every_iters=log_every_iters,
        device=device,
    )

    if train_sampler is not None:

        @trainer.on(Events.EPOCH_STARTED)
        def distrib_set_epoch(engine):
            # The sampler receives ``epoch - 1`` (presumably a 0-based epoch index).
            train_sampler.set_epoch(engine.state.epoch - 1)

    if is_main_process and to_save is not None:
        checkpoint_handler = ModelCheckpoint(dirname=output_path,
                                             filename_prefix="training",
                                             require_empty=False)
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(every=save_every_iters),
            checkpoint_handler, to_save)
示例#2
0
文件: common.py 项目: zivzone/ignite
def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters):
    """Attach training, optimizer and evaluation logging handlers to ``logger``.

    Args:
        logger: logger exposing ``attach(engine, log_handler=..., event_name=...)``.
        logger_module: module providing ``OutputHandler`` and ``OptimizerParamsHandler``.
        trainer: training engine; its output is logged under the tag ``"training"``
            every ``log_every_iters`` iterations.
        optimizers: ``None``, a single torch ``Optimizer``, or a mapping
            tag -> optimizer whose ``lr`` is logged every ``log_every_iters``
            iterations.
        evaluators: ``None``, a single ``Engine`` (tagged ``"validation"``), or a
            mapping tag -> engine whose metrics are logged when it completes,
            using the trainer's step as the global step.
        log_every_iters: logging period in iterations; ``None`` means every iteration.

    Raises:
        TypeError: if ``optimizers`` or ``evaluators`` has an unsupported type.
    """
    if optimizers is not None:
        from torch.optim.optimizer import Optimizer

        if not isinstance(optimizers, (Optimizer, Mapping)):
            raise TypeError("Argument optimizers should be either a single optimizer or a dictionary or optimizers")

    if evaluators is not None:
        if not isinstance(evaluators, (Engine, Mapping)):
            # The message names the argument that actually failed the check.
            raise TypeError("Argument evaluators should be either a single engine or a dictionary of engines")

    if log_every_iters is None:
        log_every_iters = 1

    logger.attach(trainer,
                  log_handler=logger_module.OutputHandler(tag="training", metric_names='all'),
                  event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

    if optimizers is not None:
        # Log optimizer parameters
        if isinstance(optimizers, Optimizer):
            optimizers = {None: optimizers}

        for k, optimizer in optimizers.items():
            logger.attach(trainer,
                          log_handler=logger_module.OptimizerParamsHandler(optimizer, param_name="lr", tag=k),
                          event_name=Events.ITERATION_STARTED(every=log_every_iters))

    if evaluators is not None:
        # Log evaluation metrics
        if isinstance(evaluators, Engine):
            evaluators = {"validation": evaluators}

        for k, evaluator in evaluators.items():
            gst = global_step_from_engine(trainer)
            logger.attach(evaluator,
                          log_handler=logger_module.OutputHandler(tag=k, metric_names='all', global_step_transform=gst),
                          event_name=Events.COMPLETED)
示例#3
0
def test_pbar_wrong_events_order():
    """ProgressBar.attach rejects invalid (event, closing event) combinations."""
    engine = Engine(update_fn)
    pbar = ProgressBar()

    # Pairs where the progress event would not fire before the closing event.
    misordered_pairs = [
        (Events.COMPLETED, Events.COMPLETED),
        (Events.COMPLETED, Events.EPOCH_COMPLETED),
        (Events.COMPLETED, Events.ITERATION_COMPLETED),
        (Events.EPOCH_COMPLETED, Events.EPOCH_COMPLETED),
        (Events.ITERATION_COMPLETED, Events.ITERATION_STARTED),
    ]
    for progress_event, closing_event in misordered_pairs:
        with pytest.raises(ValueError, match="should be called before closing event"):
            pbar.attach(engine, event_name=progress_event, closing_event_name=closing_event)

    # A filtered closing event is rejected on its own.
    with pytest.raises(ValueError, match="Closing event should not use any event filter"):
        pbar.attach(
            engine,
            event_name=Events.ITERATION_STARTED,
            closing_event_name=Events.EPOCH_COMPLETED(every=10),
        )
示例#4
0
def test_lr_suggestion_mnist(lr_finder, mnist_to_save, dummy_engine_mnist,
                             mnist_dataloader):
    """A short LR-finder run on MNIST yields a suggestion within [1e-4, 2]."""
    max_iters = 50

    with lr_finder.attach(dummy_engine_mnist, mnist_to_save) as finder_engine:
        # Terminate the run after ``max_iters`` iterations; the returned handle
        # is used as a context manager around the run.
        stop_handle = finder_engine.add_event_handler(
            Events.ITERATION_COMPLETED(once=max_iters),
            lambda _: finder_engine.terminate(),
        )
        with stop_handle:
            finder_engine.run(mnist_dataloader)

    suggestion = lr_finder.lr_suggestion()
    assert 1e-4 <= suggestion <= 2
示例#5
0
def create_vae_engines(
    model,
    optimizer,
    criterion=None,
    metrics=None,
    device=None,
    non_blocking=False,
    fig_dir=None,
    unflatten=None,
):
    """Build the trainer, evaluator and validation log handler for a VAE.

    Args:
        model: the VAE; its ``device`` attribute selects where steps run.
        optimizer: optimizer used by the training step.
        criterion: training/eval loss; defaults to ``get_default_autoencoder_loss()``.
        metrics: optional extra evaluator metrics. ``"loss"`` and ``"mse"`` are
            added unless already present. The caller's dict is not modified.
        device: currently ignored (see note below).
        non_blocking: forwarded to the train/eval step factories.
        fig_dir: directory passed to the image-saving callback.
        unflatten: passed to the image-saving callback.

    Returns:
        ``(trainer, evaluator, val_log_handler, val_logger)``.
    """
    # NOTE(review): the ``device`` argument is overridden here and has no
    # effect; the model's own device is always used.
    device = model.device
    if criterion is None:
        criterion = get_default_autoencoder_loss()

    train_step = create_vae_train_step(model,
                                       optimizer,
                                       criterion,
                                       device=device,
                                       non_blocking=non_blocking)
    eval_step = create_vae_eval_step(model,
                                     device=device,
                                     non_blocking=non_blocking)

    # Work on a copy so the defaults below do not leak into the caller's dict.
    metrics = {} if metrics is None else dict(metrics)
    metrics.setdefault(
        "loss",
        Loss(criterion, output_transform=loss_eval_output_transform),
    )
    metrics.setdefault("mse",
                       MeanSquaredError(output_transform=lambda x: x[:2]))
    trainer = Engine(train_step)
    evaluator = create_autoencoder_evaluator(eval_step, metrics=metrics)

    save_image_callback = create_save_image_callback(fig_dir,
                                                     unflatten=unflatten)

    def _epoch_getter():
        # ``None`` until the trainer state has an ``epoch`` entry.
        return trainer.state.__dict__.get("epoch", None)

    # Save images once per evaluation run (on its first iteration).
    evaluator.add_event_handler(
        Events.ITERATION_COMPLETED(once=1),
        save_image_callback,
        epoch=_epoch_getter,
    )

    val_log_handler, val_logger = create_log_handler(trainer)

    return trainer, evaluator, val_log_handler, val_logger
示例#6
0
def test_as_context_manager():
    """A logger used as a context manager closes its writer on exit, and the
    attached handler is called once per matching event."""
    n_epochs = 5
    data = list(range(50))

    class _DummyLogger(DummyLogger):
        def __init__(self, writer):
            self.writer = writer

        def close(self):
            self.writer.close()

    def _test(event, n_calls):
        # (A previously declared ``global close_counter`` was never read; it
        # only leaked a module-level name, so it has been removed.)
        losses = torch.rand(n_epochs * len(data))
        losses_iter = iter(losses)

        def update_fn(engine, batch):
            return next(losses_iter)

        writer = MagicMock()
        writer.close = MagicMock()

        with _DummyLogger(writer) as logger:
            assert isinstance(logger, _DummyLogger)

            trainer = Engine(update_fn)
            mock_log_handler = MagicMock()

            logger.attach(trainer,
                          log_handler=mock_log_handler,
                          event_name=event)

            trainer.run(data, max_epochs=n_epochs)

            mock_log_handler.assert_called_with(trainer, logger, event)
            assert mock_log_handler.call_count == n_calls

        # Leaving the ``with`` block must close the writer exactly once.
        writer.close.assert_called_once_with()

    _test(Events.ITERATION_STARTED, len(data) * n_epochs)
    _test(Events.ITERATION_COMPLETED, len(data) * n_epochs)
    _test(Events.EPOCH_STARTED, n_epochs)
    _test(Events.EPOCH_COMPLETED, n_epochs)
    _test(Events.STARTED, 1)
    _test(Events.COMPLETED, 1)

    _test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)
示例#7
0
def test_callable_events_with_wrong_inputs():
    """Calling an event with invalid filter arguments raises a descriptive error."""
    only_one_pattern = r"Only one of the input arguments should be specified"

    # (expected exception, message pattern, keyword arguments for the call)
    bad_calls = [
        (ValueError, only_one_pattern, {}),
        (ValueError, only_one_pattern, {"event_filter": "123", "every": 12}),
        (TypeError, r"Argument event_filter should be a callable", {"event_filter": "123"}),
        (ValueError, r"Argument every should be integer and greater than one", {"every": -1}),
        (ValueError, r"but will be called with", {"event_filter": lambda x: x}),
    ]
    for expected_exc, pattern, call_kwargs in bad_calls:
        with pytest.raises(expected_exc, match=pattern):
            Events.ITERATION_STARTED(**call_kwargs)
def create_trainer_and_evaluators(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: nn.Module,
    data_loaders: Dict[str, DataLoader],
    metrics: Dict[str, Metric],
    config: ConfigSchema,
    logger: Logger,
) -> Tuple[Engine, Dict[str, Engine]]:
    """Build the trainer and evaluators, attach handlers, and optionally resume.

    The trainer gets:
    * the common training handlers, which checkpoint ``trainer``, ``model``,
      ``optimizer`` and ``lr_scheduler`` every ``config.checkpoint_every``
      iterations;
    * the LR scheduler, run at every epoch start;
    * a progress bar updated every ``config.log_every_iters`` iterations.

    If ``config.resume_from`` is set, those objects are restored from that
    checkpoint file.

    Returns:
        The trainer and the evaluators dict returned by ``get_evaluators``.

    Raises:
        FileNotFoundError: if ``config.resume_from`` points to a missing file.
    """
    trainer = get_trainer(model, criterion, optimizer)
    trainer.logger = logger

    evaluators = get_evaluators(model, metrics)
    setup_evaluation(trainer, evaluators, data_loaders, logger)

    lr_scheduler = get_lr_scheduler(config, optimizer, trainer, evaluators["val"])

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }

    common.setup_common_training_handlers(
        trainer=trainer,
        to_save=to_save,
        save_every_iters=config.checkpoint_every,
        save_handler=get_save_handler(config),
        with_pbars=False,  # a custom progress bar is attached below
        train_sampler=data_loaders["train"].sampler,
    )
    trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler)
    ProgressBar(persist=False).attach(
        trainer,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=config.log_every_iters),
    )

    resume_from = config.resume_from
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        checkpoint_path = checkpoint_fp.as_posix()
        # An explicit check instead of ``assert``, which is stripped under ``python -O``.
        if not checkpoint_fp.exists():
            raise FileNotFoundError("Checkpoint '{}' is not found".format(checkpoint_path))
        logger.info("Resume from a checkpoint: {}".format(checkpoint_path))
        # NOTE: ``torch.load`` unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer, evaluators
示例#9
0
def test_attach():
    """``logger.attach`` calls the handler once per fired event and forwards extra kwargs."""
    n_epochs = 5
    data = list(range(50))

    def _test(event, n_calls, kwargs=None):
        # ``None`` instead of a mutable ``{}`` default shared across calls.
        kwargs = {} if kwargs is None else kwargs

        losses = torch.rand(n_epochs * len(data))
        losses_iter = iter(losses)

        def update_fn(engine, batch):
            return next(losses_iter)

        trainer = Engine(update_fn)

        logger = DummyLogger()

        mock_log_handler = MagicMock()

        logger.attach(trainer,
                      log_handler=mock_log_handler,
                      event_name=event,
                      **kwargs)

        trainer.run(data, max_epochs=n_epochs)

        # An EventsList (e.g. ``A | B``) expands to one expected call per event.
        events = list(event) if isinstance(event, EventsList) else [event]
        # ``call(..., **{})`` equals ``call(...)``, so one form covers both cases.
        calls = [call(trainer, logger, e, **kwargs) for e in events]

        mock_log_handler.assert_has_calls(calls)
        assert mock_log_handler.call_count == n_calls

    _test(Events.ITERATION_STARTED, len(data) * n_epochs, kwargs={"a": 0})
    _test(Events.ITERATION_COMPLETED, len(data) * n_epochs)
    _test(Events.EPOCH_STARTED, n_epochs)
    _test(Events.EPOCH_COMPLETED, n_epochs)
    _test(Events.STARTED, 1)
    _test(Events.COMPLETED, 1)

    _test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)

    _test(Events.STARTED | Events.COMPLETED, 2)
示例#10
0
def run(subj_ind: int,
        result_name: str,
        dataset_path: str,
        deep4_path: str,
        result_path: str,
        config: dict = default_config,
        model_builder: ProgressiveModelBuilder = default_model_builder):
    """Train a progressive GAN for one subject.

    Artifacts go to ``<result_path>/<result_name>/<subj_ind>``. The config and
    the model builder are dumped there with joblib before training starts.
    """
    subject_dir = os.path.join(result_path, result_name, str(subj_ind))
    os.makedirs(subject_dir, exist_ok=True)

    # Persist the exact configuration and model builder next to the results.
    config_file = os.path.join(subject_dir, 'config.dict')
    joblib.dump(config, config_file, compress=False)
    builder_file = os.path.join(subject_dir, 'model_builder.jblb')
    joblib.dump(model_builder, builder_file, compress=True)

    discriminator = model_builder.build_discriminator()
    generator = model_builder.build_generator()

    # Initialise weights (generator first, then discriminator).
    for net in (generator, discriminator):
        net.apply(weight_filler)

    trainer = GanSoftplusTrainer(10, discriminator, generator,
                                 config['r1_gamma'], config['r2_gamma'])

    # Handles the stage progression; ``advance_alpha`` runs after every epoch.
    progression_handler = ProgressionHandler(
        discriminator, generator,
        config['n_stages'], config['use_fade'], config['n_epochs_fade'],
        freeze_stages=config['freeze_stages'])
    progression_handler.set_progression(0, 1.)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                              progression_handler.advance_alpha)

    for net in (generator, discriminator):
        net.train()

    train(subj_ind, dataset_path, deep4_path, subject_dir,
          progression_handler, trainer,
          config['n_batch'], config['lr_d'], config['lr_g'], config['betas'],
          config['n_epochs_per_stage'], config['n_epochs_metrics'],
          config['plot_every_epoch'], config['orig_fs'])
示例#11
0
    def train_epochs(self, max_epochs):
        """Train for ``max_epochs`` epochs with TensorBoard logging and best-model checkpointing.

        After every epoch the model is evaluated on ``self.train_loader`` and
        then on ``self.valid_loader``. Loss and accuracy for both runs are
        written to TensorBoard. The single best checkpoint (lowest loss) is
        scored on the validation run only.
        """
        self.trainer = Engine(self.train_one_step)
        self.evaluator = Engine(self.evaluate_one_step)
        self.metrics = {'Loss': Loss(self.criterion), 'Acc': Accuracy()}
        for name, metric in self.metrics.items():
            metric.attach(self.evaluator, name)

        # Same value as the former ``str(type(self))[17:-2]`` slice for classes
        # defined in ``__main__``, and still correct for classes defined elsewhere.
        class_name = type(self).__name__

        # Keep the best model on validation loss (score is negated: higher is better).
        best_model_handler = ModelCheckpoint(
            dirname='.',
            require_empty=False,
            filename_prefix="best",
            n_saved=1,
            score_function=lambda engine: -engine.state.metrics['Loss'],
            score_name="val_loss")
        # The former code referenced an undefined name ``model``. This assumes
        # the instance itself is the trained module (it provides eval()/train()).
        to_save = {f'Transform{class_name}': self}

        with SummaryWriter(log_dir="/tmp/tensorboard/Transform" + class_name) as writer:

            @self.trainer.on(Events.EPOCH_COMPLETED(every=1))  # every epoch
            def log_results(engine):
                self.eval()

                # Evaluate on the training set.
                self.evaluator.run(self.train_loader)
                writer.add_scalar("train/loss",
                                  self.evaluator.state.metrics['Loss'],
                                  engine.state.epoch)
                writer.add_scalar("train/accy",
                                  self.evaluator.state.metrics['Acc'],
                                  engine.state.epoch)

                # Evaluate on the validation set.
                self.evaluator.run(self.valid_loader)
                writer.add_scalar("valid/loss",
                                  self.evaluator.state.metrics['Loss'],
                                  engine.state.epoch)
                writer.add_scalar("valid/accy",
                                  self.evaluator.state.metrics['Acc'],
                                  engine.state.epoch)

                # Checkpoint here rather than on evaluator COMPLETED. Otherwise
                # the handler would also fire after the training-set run and
                # score the model on training loss.
                best_model_handler(self.evaluator, to_save)
                self.train()

            # Run inside the ``with`` so the writer is still open while handlers log.
            self.trainer.run(self.train_loader, max_epochs=max_epochs)
示例#12
0
def test_remove_event_handler_on_callable_events():
    """A handler attached through a filtered (callable) event can be removed
    via the plain event, but passing a filtered event to removal is rejected."""
    engine = Engine(lambda e, b: 1)

    def foo(e):
        pass

    assert not engine.has_event_handler(foo)

    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    assert engine.has_event_handler(foo)
    engine.remove_event_handler(foo, Events.EPOCH_STARTED)
    assert not engine.has_event_handler(foo)

    def bar(e):
        pass

    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    engine.remove_event_handler(bar, Events.EPOCH_COMPLETED)
    # Check the handler that was just removed. Re-checking ``foo`` would pass
    # trivially and not test the removal at all.
    assert not engine.has_event_handler(bar)

    with pytest.raises(TypeError, match=r"Argument event_name should not be a filtered event"):
        engine.remove_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
示例#13
0
 def set_defaults(self, is_training=True):
     """
     Fill in the default events and print format for training or evaluation specs.

     ``metrics`` of ``None`` becomes an empty dict. ``plot_event`` still equal
     to ``'default'`` becomes ``None``. ``log_event``, ``print_event`` and
     ``print_fmt`` still equal to ``'default'`` are replaced by:
     every 100 iterations / TRAIN_MESSAGE when training, or
     EPOCH_COMPLETED / EVAL_MESSAGE otherwise.
     """
     if self.metrics is None:
         self.metrics = {}
     if self.plot_event == 'default':
         # Plotting is off by default (Events.EPOCH_COMPLETED is a possible alternative).
         self.plot_event = None

     # Factories keep each default lazy: it is only built when actually needed.
     if is_training:
         # Log and print every 100 training iterations.
         fallbacks = (
             ('log_event', lambda: Events.ITERATION_COMPLETED(every=100)),
             ('print_event', lambda: Events.ITERATION_COMPLETED(every=100)),
             ('print_fmt', lambda: TRAIN_MESSAGE),
         )
     else:
         # Log and print at the end of each evaluation.
         fallbacks = (
             ('log_event', lambda: Events.EPOCH_COMPLETED),
             ('print_event', lambda: Events.EPOCH_COMPLETED),
             ('print_fmt', lambda: EVAL_MESSAGE),
         )

     for attr_name, make_default in fallbacks:
         if getattr(self, attr_name) == 'default':
             setattr(self, attr_name, make_default())
示例#14
0
def test_ema_two_handlers(get_dummy_model):
    """Two EMA handlers on one engine keep separate state and update at their own intervals."""
    # Warmup momentum equals the final momentum, so momentum is constantly 0.5.
    ema_config = {"momentum_warmup": 0.5, "momentum": 0.5, "warmup_iters": 1}

    model_1 = get_dummy_model()
    ema_handler_1 = EMAHandler(model_1, **ema_config)

    model_2 = get_dummy_model()
    ema_handler_2 = EMAHandler(model_2, **ema_config)

    def _step_fn(engine: Engine, batch: Any):
        # Every iteration shifts both models' weights by one.
        model_1.weight.data.add_(1)
        model_2.weight.data.add_(1)
        return 0

    engine = Engine(_step_fn)

    # handler_1 updates model_1's EMA after every iteration.
    assert not hasattr(engine.state, "ema_momentum_1")
    ema_handler_1.attach(engine, "ema_momentum_1", event=Events.ITERATION_COMPLETED)
    assert hasattr(engine.state, "ema_momentum_1")

    # handler_2 updates model_2's EMA every second iteration.
    ema_handler_2.attach(engine, "ema_momentum_2", event=Events.ITERATION_COMPLETED(every=2))
    assert hasattr(engine.state, "ema_momentum_2")

    # 2 batches x 2 epochs = 4 iterations.
    engine.run(range(2), max_epochs=2)

    expectations = ((ema_handler_1, 4.0625), (ema_handler_2, 3.5))
    for handler, expected_value in expectations:
        torch.testing.assert_allclose(handler.ema_model.weight.data,
                                      torch.full((1, 2), expected_value))

    assert engine.state.ema_momentum_1 == 0.5
    assert engine.state.ema_momentum_2 == 0.5

    # A third handler must not reuse an existing state attribute name.
    ema_handler_3 = EMAHandler(get_dummy_model())
    with pytest.raises(ValueError, match="Please select another name"):
        ema_handler_3.attach(engine, "ema_momentum_2")
示例#15
0
def test_mnist_lr_suggestion(lr_finder, mnist_model, mnist_optimizer,
                             mnist_dataloader):
    """The LR finder on a supervised MNIST trainer suggests a rate within [1e-4, 10]."""
    loss_fn = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(mnist_model, mnist_optimizer, loss_fn)
    objects_to_restore = {"model": mnist_model, "optimizer": mnist_optimizer}

    max_iters = 50

    with lr_finder.attach(trainer, objects_to_restore) as finder_engine:
        # Stop after ``max_iters`` iterations; the returned handle is used as
        # a context manager around the run.
        stop_handle = finder_engine.add_event_handler(
            Events.ITERATION_COMPLETED(once=max_iters),
            lambda _: finder_engine.terminate(),
        )
        with stop_handle:
            finder_engine.run(mnist_dataloader)

    suggestion = lr_finder.lr_suggestion()
    assert 1e-4 <= suggestion <= 10
示例#16
0
def test_pbar_on_callable_events(capsys):
    """A progress bar driven by a filtered event renders the last reached step."""
    n_epochs = 1
    loader = list(range(100))
    engine = Engine(update_fn)

    pbar = ProgressBar()
    pbar.attach(
        engine,
        event_name=Events.ITERATION_STARTED(every=10),
        closing_event_name=Events.EPOCH_COMPLETED,
    )
    engine.run(loader, max_epochs=n_epochs)

    # The bar redraws using carriage returns; keep only the non-blank frames.
    stderr_text = capsys.readouterr().err
    frames = [chunk.strip() for chunk in stderr_text.split("\r")]
    frames = [frame for frame in frames if frame]

    last_frame = frames[-1]
    assert last_frame == "Iteration: [90/100]  90%|█████████  [00:00<00:00]"
示例#17
0
def set_handlers(trainer: Engine, evaluator: Engine, valloader: DataLoader,
                 model: nn.Module, optimizer: optim.Optimizer,
                 args: Namespace) -> None:
    """Wire metrics, progress bars, periodic validation and best-model checkpointing.

    The evaluator gets ROC-AUC, accuracy and BCE-with-logits loss. It runs on
    ``valloader`` every ``args.evaluation_interval`` trainer iterations. After
    each evaluation, the single best checkpoint by ``roc_auc`` is kept in
    ``args.checkpoint_dir``.
    """
    def _logits_and_labels(output):
        return output.logit, output.label

    def _predictions_and_labels(output):
        # A positive logit means class 1.
        return (output.logit > 0).long(), output.label

    def _logits_and_float_labels(output):
        return output.logit, output.label.float()

    # Validation metrics.
    roc_auc = ROC_AUC(output_transform=_logits_and_labels)
    roc_auc.attach(engine=evaluator, name='roc_auc')
    accuracy = Accuracy(output_transform=_predictions_and_labels)
    accuracy.attach(engine=evaluator, name='accuracy')
    loss = Loss(loss_fn=nn.BCEWithLogitsLoss(), output_transform=_logits_and_float_labels)
    loss.attach(engine=evaluator, name='loss')

    # Progress bars: per-iteration loss for training, live and summary bars for evaluation.
    ProgressBar(persist=True, desc='Epoch').attach(
        engine=trainer, output_transform=lambda output: {'loss': output.loss})
    ProgressBar(persist=False, desc='Eval').attach(engine=evaluator)
    ProgressBar(persist=True, desc='Eval').attach(
        engine=evaluator,
        metric_names=['roc_auc', 'accuracy', 'loss'],
        event_name=Events.EPOCH_COMPLETED,
        closing_event_name=Events.COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED(every=args.evaluation_interval))
    def _evaluate(engine: Engine):
        evaluator.run(valloader, max_epochs=1)

    disk_saver = DiskSaver(dirname=args.checkpoint_dir,
                           atomic=True,
                           create_dir=True,
                           require_empty=False)
    best_checkpoint = Checkpoint(
        to_save={'model': model, 'optimizer': optimizer, 'trainer': trainer},
        save_handler=disk_saver,
        filename_prefix='best',
        score_function=lambda engine: engine.state.metrics['roc_auc'],
        score_name='val_roc_auc',
        n_saved=1,
        global_step_transform=global_step_from_engine(trainer))
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=best_checkpoint)
示例#18
0
def test_neg_event_filter_threshold_handlers_profiler():
    """A filtered handler whose event never fires is reported as "not triggered"."""
    handler_sleep = 0.1
    max_epochs = 1
    num_iters = 1

    profiler = HandlersTimeProfiler()
    trainer = Engine(_do_nothing_update_fn)
    profiler.attach(trainer)

    # Only one epoch runs, so the once=2 filter never matches.
    @trainer.on(Events.EPOCH_STARTED(once=2))
    def do_something_once_on_2_epoch():
        time.sleep(handler_sleep)

    trainer.run(range(num_iters), max_epochs=max_epochs)

    first_row = profiler.get_results()[0]
    assert "do_something_once_on_2_epoch" in first_row[0]
    assert first_row[1] == "EPOCH_STARTED"
    assert first_row[2] == "not triggered"
示例#19
0
def test_concepts_snippet_warning():
    """Run a DeterministicEngine over an infinite random stream while a handler
    reseeds the global torch RNG every third iteration."""
    def random_train_data_generator():
        while True:
            yield torch.randint(0, 100, size=(1, ))

    def print_train_data(engine, batch):
        state = engine.state
        print("train", state.epoch, state.iteration, batch.tolist())

    trainer = DeterministicEngine(print_train_data)

    @trainer.on(Events.ITERATION_COMPLETED(every=3))
    def user_handler(_):
        # Reseed, then draw once; the draw advances the global RNG state.
        torch.manual_seed(12)
        torch.rand(1)

    trainer.run(random_train_data_generator(), max_epochs=3, epoch_length=5)
示例#20
0
def test_ema_two_handlers(get_dummy_model):
    """Two EMA handlers on one engine update independently; reusing a state name warns."""
    model_1 = get_dummy_model()
    ema_handler_1 = EMAHandler(model_1, momentum=0.5)

    model_2 = get_dummy_model()
    ema_handler_2 = EMAHandler(model_2, momentum=0.5)

    def _step_fn(engine: Engine, batch: Any):
        # Every iteration shifts both models' weights by one.
        model_1.weight.data.add_(1)
        model_2.weight.data.add_(1)
        return 0

    engine = Engine(_step_fn)

    # handler_1 updates model_1's EMA after every iteration.
    assert not hasattr(engine.state, "ema_momentum_1")
    ema_handler_1.attach(engine, "ema_momentum_1", event=Events.ITERATION_COMPLETED)
    assert hasattr(engine.state, "ema_momentum_1")

    # handler_2 updates model_2's EMA every second iteration.
    ema_handler_2.attach(engine, "ema_momentum_2", event=Events.ITERATION_COMPLETED(every=2))
    assert hasattr(engine.state, "ema_momentum_2")

    # 2 batches x 2 epochs = 4 iterations.
    engine.run(range(2), max_epochs=2)

    # Cast to float32 explicitly to avoid test failures on XLA devices.
    expectations = ((ema_handler_1, 4.0625), (ema_handler_2, 3.5))
    for handler, expected_value in expectations:
        ema_weight = handler.ema_model.weight.data.to(torch.float32)
        assert ema_weight.allclose(ema_weight.new_full((1, 2), expected_value))

    assert engine.state.ema_momentum_1 == 0.5
    assert engine.state.ema_momentum_2 == 0.5

    # Reusing an existing state attribute name only emits a warning.
    ema_handler_3 = EMAHandler(get_dummy_model())
    with pytest.warns(UserWarning,
                      match="Attribute 'ema_momentum_1' already exists"):
        ema_handler_3.attach(engine, name="ema_momentum_1")
示例#21
0
def setup_evaluation(
    trainer: Engine,
    evaluators: Dict[str, Engine],
    data_loaders: Dict[str, DataLoader],
    logger: Logger,
    validate_every: int = 1,
) -> None:
    """Periodically run the train/val/test evaluators and log their metrics.

    Args:
        trainer: engine whose epochs trigger the evaluation.
        evaluators: engines keyed by split; must contain "train", "val" and "test".
        data_loaders: loaders keyed by the same split names.
        logger: destination passed to ``log_metrics``.
        validate_every: evaluate every ``validate_every`` completed epochs,
            and always once more when training completes.
    """
    def _evaluation(engine: Engine) -> None:
        epoch = trainer.state.epoch
        for split in ["train", "val", "test"]:
            state = evaluators[split].run(data_loaders[split])
            log_metrics(logger, epoch, state.times["COMPLETED"], split, state.metrics)

    # The interval is an explicit argument: this function has no ``config`` in
    # scope, so reading ``config.validate_every`` here raised NameError.
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=validate_every) | Events.COMPLETED,
        _evaluation,
    )
示例#22
0
def test_get_intermediate_results_during_run_basic_profiler(capsys):
    """Printing BasicTimeProfiler results mid-run shows no private names, NaNs
    or zero minima."""
    handler_time = 0.0645
    max_epochs = 2
    num_iters = 5

    profiler = BasicTimeProfiler()
    trainer = get_prepared_engine_for_basic_profiler(handler_time)
    profiler.attach(trainer)

    @trainer.on(Events.ITERATION_COMPLETED(every=3))
    def log_results(_):
        profiler.print_results(profiler.get_results())
        out = capsys.readouterr().out
        for forbidden in ("BasicTimeProfiler._", "nan"):
            assert forbidden not in out
        assert " min/index: (0.0, " not in out, out

    trainer.run(range(num_iters), max_epochs=max_epochs)
示例#23
0
def test_pos_event_filter_threshold_handlers_profiler():
    """A filtered handler that does fire is timed, and its total is reported."""
    handler_sleep = HandlersTimeProfiler.EVENT_FILTER_THESHOLD_TIME
    max_epochs = 2
    num_iters = 1

    profiler = HandlersTimeProfiler()
    trainer = Engine(_do_nothing_update_fn)
    profiler.attach(trainer)

    # With two epochs the once=2 filter matches exactly once.
    @trainer.on(Events.EPOCH_STARTED(once=2))
    def do_something_once_on_2_epoch():
        time.sleep(handler_sleep)

    trainer.run(range(num_iters), max_epochs=max_epochs)

    first_row = profiler.get_results()[0]
    assert "do_something_once_on_2_epoch" in first_row[0]
    assert first_row[1] == "EPOCH_STARTED"
    expected_total = (max_epochs * num_iters * handler_sleep) / 2
    assert first_row[2] == approx(expected_total, abs=1e-1)  # total
示例#24
0
def test_run_finite_iterator_no_epoch_length_2():
    """Replacing an exhausted finite iterator from a handler lets the engine run
    several epochs without an explicit epoch_length.

    FR: https://github.com/pytorch/ignite/issues/871
    """
    known_size = 11
    n_epochs = 5

    def finite_size_data_iter(size):
        yield from range(size)

    checker = BatchChecker(data=list(range(known_size)))
    engine = Engine(lambda e, b: checker.check(b))

    @engine.on(Events.ITERATION_COMPLETED(every=known_size))
    def restart_iter():
        # Swap in a fresh iterator each time a full pass has been consumed.
        engine.state.dataloader = finite_size_data_iter(known_size)

    engine.run(finite_size_data_iter(known_size), max_epochs=n_epochs)

    assert engine.state.epoch == n_epochs
    assert engine.state.iteration == known_size * n_epochs
示例#25
0
def _test_ema_final_weight(model, device=None, ddp=False, interval=1):
    """Check devices and final weights of a model and its EMA after 4 iterations.

    The EMA weight value is only checked for ``interval`` 1 or 2.
    """
    if device is None:
        # No device given: use idist's current device (e.g. lets horovod decide).
        device = idist.device()
    if isinstance(device, str):
        device = torch.device(device)

    model = model.to(device)
    if ddp:
        model = idist.auto_model(model)
    engine = Engine(_get_dummy_step_fn(model))

    # Warmup momentum equals the final momentum, so momentum is constantly 0.5.
    ema_handler = EMAHandler(model, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)
    ema_handler.attach(engine, "model", event=Events.ITERATION_COMPLETED(every=interval))

    # 2 batches x 2 epochs = 4 iterations.
    engine.run(range(2), max_epochs=2)

    # ema_model and model may be wrapped (DP/DDP).
    ema_weight = _unwrap_model(ema_handler.ema_model).weight.data
    model_weight = _unwrap_model(model).weight.data
    assert ema_weight.device == device
    assert model_weight.device == device

    expected_ema_by_interval = {1: 4.0625, 2: 3.5}
    expected_ema = expected_ema_by_interval.get(interval)
    if expected_ema is not None:
        torch.testing.assert_allclose(
            ema_weight, torch.full((1, 2), expected_ema, device=device))
    torch.testing.assert_allclose(model_weight, torch.full((1, 2), 5.0, device=device))
示例#26
0
def test_custom_event_with_arg_handlers_profiler():
    """HandlersTimeProfiler times a custom-event handler registered with an extra argument."""
    true_event_handler_time = 0.1
    true_max_epochs = 1
    true_num_iters = 2

    profiler = HandlersTimeProfiler()
    dummy_trainer = Engine(_do_nothing_update_fn)
    dummy_trainer.register_events("custom_event")
    profiler.attach(dummy_trainer)

    # Fire the custom event after every iteration.
    @dummy_trainer.on(Events.ITERATION_COMPLETED(every=1))
    def trigger_custom_event():
        dummy_trainer.fire_event("custom_event")

    handler_args = [122, 324]

    # The handler name is asserted on below, so it must stay `on_custom_event`.
    @dummy_trainer.on("custom_event", handler_args)
    def on_custom_event(args):
        time.sleep(true_event_handler_time)

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)

    # Pick the first profiler row whose second column is the custom event.
    event_results = next(
        (row for row in profiler.get_results() if row[1] == "custom_event"), None
    )
    assert event_results is not None
    assert "on_custom_event" in event_results[0]

    expected_total = true_max_epochs * true_num_iters * true_event_handler_time
    assert event_results[2] == approx(expected_total, abs=1e-1)  # total
    assert event_results[3][0] == approx(true_event_handler_time, abs=1e-1)  # min
    assert event_results[4][0] == approx(true_event_handler_time, abs=1e-1)  # max
    assert event_results[5] == approx(true_event_handler_time, abs=1e-1)  # mean
    assert event_results[6] == approx(0.0, abs=1e-1)  # stddev
示例#27
0
 def __init__(self,
              model,
              loss,
              optimizer,
              lr_scheduler,
              device,
              logger,
              log_interval,
              output_dir=None):
     """Build the supervised trainer engine and register its handlers.

     Args:
         model: model handed to ``create_supervised_trainer``.
         loss: loss function handed to ``create_supervised_trainer``.
         optimizer: optimizer handed to ``create_supervised_trainer``.
         lr_scheduler: stored as ``self.lr_scheduler`` (presumably used by
             ``self.lr_step``, which is not visible here — confirm).
         device: device handed to ``create_supervised_trainer``.
         logger: stored as ``self.logger``.
         log_interval: iteration period at which ``self.log_training_loss`` runs.
         output_dir: accepted but not used by this constructor.
     """
     self.logger = logger
     self.lr_scheduler = lr_scheduler
     self.log_interval = log_interval
     self.progress_bar_desc = "ITERATION - loss: {:.2f}"

     engine = create_supervised_trainer(model, optimizer, loss, device=device)
     self.trainer_engine = engine
     # Periodic loss logging, then one lr_step call at the end of each epoch.
     engine.add_event_handler(Events.ITERATION_COMPLETED(every=log_interval),
                              self.log_training_loss)
     engine.add_event_handler(Events.EPOCH_COMPLETED, self.lr_step)
def setup_snapshots(trainer, sample_images, conf):
    # type: (Engine, SampleImages, DictConfig) -> None
    """Register periodic snapshot-image generation on ``trainer``.

    Snapshots are produced only when both ``conf.snapshots.enabled`` and
    ``conf.G_smoothing.enabled`` are true. If snapshots are enabled but
    G_smoothing is not, a warning is logged and nothing is registered.

    Args:
        trainer: engine to attach the snapshot handler to.
        sample_images: forwarded to ``handle_snapshot_images``.
        conf: config exposing ``snapshots.enabled``,
            ``snapshots.interval_iteration``, ``snapshots.dynamic_range``,
            an optional ``snapshots.save_dir`` (default ``<cwd>/images``)
            and ``G_smoothing.enabled``.
    """
    snapshots = conf.snapshots
    use_ema = conf.G_smoothing.enabled
    if not snapshots.enabled:
        return
    if not use_ema:
        logging.warning(
            "Snapshot generation requires G_smoothing.enabled=true. "
            "Snapshots will be turned off for this run.")
        return

    snap_event = Events.ITERATION_COMPLETED(every=snapshots.interval_iteration)
    snap_path = snapshots.get('save_dir', os.path.join(os.getcwd(), 'images'))
    # exist_ok=True removes the check-then-create race: if another process
    # creates the directory between an exists() check and makedirs(), the
    # old code would raise FileExistsError.
    os.makedirs(snap_path, exist_ok=True)
    logging.info("Saving snapshot images to %s", snap_path)
    trainer.add_event_handler(snap_event,
                              handle_snapshot_images,
                              sample_images,
                              snap_path,
                              dynamic_range=tuple(snapshots.dynamic_range))
示例#29
0
def train():
    """Build the model, optimizer and ignite engines, then run training.

    Relies on module-level names defined outside this function (e.g.
    ``train_param``, ``model_param``, ``train_loader``, ``dev_loader``,
    ``metric``, ``score_fn``, ``log_training_loss``, ``log_results``).
    Early stopping and checkpointing are driven by ``score_fn`` evaluated on
    the dev evaluator after each of its runs.
    """
    set_seed(train_param.seed)
    model = Model(model_param)
    optimizer = AdamW(model.parameters(), lr=train_param.lr, eps=1e-8)
    # Schedule length = epochs x len(train_loader), i.e. one step per batch.
    total_steps = train_param.epoch * len(train_loader)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=total_steps)
    loss_fn = [translate, MSELoss()]
    device = torch.device(f'cuda:{train_param.device}')

    trainer = create_trainer(model, optimizer, scheduler, loss_fn,
                             train_param.grad_norm, device)
    train_evaluator = create_evaluator(model, metric, device)
    dev_evaluator = create_evaluator(model, metric, device)

    # Handler registration order matters: ignite runs handlers in the order added.
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(every=train_param.interval),
        log_training_loss)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_results,
                              train_evaluator, train_loader, 'Train')
    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_results,
                              dev_evaluator, dev_loader, 'Dev')

    es_handler = EarlyStopping(patience=train_param.patience,
                               score_function=score_fn,
                               trainer=trainer)
    dev_evaluator.add_event_handler(Events.COMPLETED, es_handler)

    ckpt_handler = ModelCheckpoint(dirname=train_param.save_path,
                                   filename_prefix='',
                                   score_function=score_fn,
                                   score_name='score',
                                   require_empty=False)
    # NOTE(review): ModelCheckpoint expects every saved object to provide
    # state_dict(); confirm that model_param does.
    to_save = {'model': model, 'param': model_param}
    dev_evaluator.add_event_handler(Events.COMPLETED, ckpt_handler, to_save)

    run_name = train_param.save_path.split("/")[-1]
    print(f'Start running {run_name} at device: {train_param.device}\t'
          f'lr: {train_param.lr}')
    trainer.run(train_loader, max_epochs=train_param.epoch)
示例#30
0
def test_attach():

    n_epochs = 5
    data = list(range(50))

    def _test(event, n_calls):

        losses = torch.rand(n_epochs * len(data))
        losses_iter = iter(losses)

        def update_fn(engine, batch):
            return next(losses_iter)

        trainer = Engine(update_fn)

        logger = DummyLogger()

        mock_log_handler = MagicMock()

        logger.attach(trainer, log_handler=mock_log_handler, event_name=event)

        trainer.run(data, max_epochs=n_epochs)

        if isinstance(event, EventWithFilter):
            event = event.event

        mock_log_handler.assert_called_with(trainer, logger, event)
        assert mock_log_handler.call_count == n_calls

    _test(Events.ITERATION_STARTED, len(data) * n_epochs)
    _test(Events.ITERATION_COMPLETED, len(data) * n_epochs)
    _test(Events.EPOCH_STARTED, n_epochs)
    _test(Events.EPOCH_COMPLETED, n_epochs)
    _test(Events.STARTED, 1)
    _test(Events.COMPLETED, 1)

    _test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)