예제 #1
0
def test_poptorch_models_at_different_stages(tmpdir):
    """Check which stages appear in ``poptorch_models`` after the strategy's ``setup``.

    With the trainer function set to ``FITTING`` the strategy is expected to hold
    models for the training and validating stages. For each of the validating,
    testing and predicting functions it is expected to hold only the model of the
    matching stage.
    """
    ipu_strategy = IPUStrategy()
    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=ipu_strategy,
        accelerator="ipu",
        devices=8,
    )
    boring_model = BoringModel()
    boring_model.trainer = trainer
    ipu_strategy.model = boring_model

    trainer.optimizers = boring_model.configure_optimizers()[0]

    # Fitting: setup is expected to register both the train and the validation model.
    trainer.state.fn = TrainerFn.FITTING
    trainer.strategy.setup(trainer)
    registered_stages = list(trainer.strategy.poptorch_models)
    assert registered_stages == [RunningStage.TRAINING, RunningStage.VALIDATING]

    # Evaluation-style functions: only the model of the active stage is expected.
    evaluation_cases = (
        (TrainerFn.VALIDATING, RunningStage.VALIDATING),
        (TrainerFn.TESTING, RunningStage.TESTING),
        (TrainerFn.PREDICTING, RunningStage.PREDICTING),
    )
    for trainer_fn, running_stage in evaluation_cases:
        trainer.state.fn = trainer_fn
        trainer.state.stage = running_stage
        trainer.strategy.setup(trainer)
        registered_stages = list(trainer.strategy.poptorch_models)
        assert registered_stages == [running_stage]
예제 #2
0
def test_dataloader_source_request_from_module():
    """Check that a ``_DataLoaderSource`` built from a module calls the named hook only on request."""
    boring_model = BoringModel()
    boring_model.trainer = Trainer()
    train_loader = boring_model.train_dataloader()
    # ``foo`` stands in for a dataloader hook so that calls to it can be counted.
    boring_model.foo = Mock(return_value=train_loader)

    loader_source = _DataLoaderSource(boring_model, "foo")
    assert loader_source.is_module()

    # Building the source and querying ``is_module`` must not have called the hook yet.
    boring_model.foo.assert_not_called()

    requested = loader_source.dataloader()
    assert isinstance(requested, DataLoader)
    boring_model.foo.assert_called_once()
예제 #3
0
def test_eval_limit_batches(stage, mode, limit_batches):
    """Check the batch count reported by ``_reset_eval_dataloader`` under ``limit_{mode}_batches``.

    A float ``limit_batches`` is expected to scale the length of the model's
    ``{mode}_dataloader`` (truncated to ``int``); any other value is expected to be
    reported unchanged. The returned dataloader is expected to keep the same length
    as the model's own loader.
    """
    limit_flag = f"limit_{mode}_batches"
    hook_name = f"{mode}_dataloader"
    model = BoringModel()
    eval_loader = getattr(model, hook_name)()

    trainer_kwargs = {limit_flag: limit_batches}
    trainer = Trainer(**trainer_kwargs)
    model.trainer = trainer
    trainer._data_connector.attach_dataloaders(model)
    reset_result = trainer._data_connector._reset_eval_dataloader(stage, model=model)
    num_batches_per_loader, eval_dataloaders = reset_result

    if isinstance(limit_batches, float):
        expected_batches = int(limit_batches * len(eval_loader))
    else:
        expected_batches = limit_batches

    assert num_batches_per_loader[0] == expected_batches
    assert len(eval_dataloaders[0]) == len(eval_loader)
def test_distributed_sampler_with_overfit_batches():
    """Check the train sampler is a non-shuffling ``DistributedSampler`` with ``overfit_batches=1`` on ``ddp_spawn``."""
    boring_model = BoringModel()
    trainer_kwargs = {
        "overfit_batches": 1,
        "strategy": "ddp_spawn",
        "num_processes": 2,
    }
    trainer = Trainer(**trainer_kwargs)

    # Wire the model and trainer together by hand; no fit call is made in this test.
    boring_model.trainer = trainer
    trainer.model = boring_model
    trainer._data_connector.attach_dataloaders(boring_model)
    trainer.reset_train_dataloader()

    sampler = trainer.train_dataloader.loaders.sampler
    assert isinstance(sampler, DistributedSampler)
    assert sampler.shuffle is False
예제 #5
0
def test_replication_factor(tmpdir):
    """Check the strategy's ``replication_factor`` with and without user-supplied poptorch Options.

    Without custom options the replication factor is expected to equal the number of
    devices. With custom training/inference options, the factor set on the training
    options is expected during the training stage and the one set on the inference
    options during the validating, testing and predicting stages.
    """
    default_strategy = IPUStrategy()
    trainer = Trainer(
        accelerator="ipu",
        devices=2,
        default_root_dir=tmpdir,
        fast_dev_run=True,
        strategy=default_strategy,
    )
    assert isinstance(trainer.accelerator, IPUAccelerator)
    assert trainer.num_devices == 2
    assert trainer.strategy.replication_factor == 2

    boring_model = BoringModel()
    train_options = poptorch.Options()
    infer_options = poptorch.Options()
    train_options.replicationFactor(8)
    infer_options.replicationFactor(7)
    custom_strategy = IPUStrategy(
        inference_opts=infer_options,
        training_opts=train_options,
    )

    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="ipu",
        devices=1,
        strategy=custom_strategy,
    )
    trainer.optimizers = boring_model.configure_optimizers()[0]
    custom_strategy.model = boring_model
    boring_model.trainer = trainer
    trainer.state.fn = TrainerFn.FITTING
    trainer.strategy.setup(trainer)

    # While fitting, the factor read back is expected to follow the current stage.
    trainer.state.stage = RunningStage.TRAINING
    assert trainer.strategy.replication_factor == 8
    trainer.state.stage = RunningStage.VALIDATING
    assert trainer.strategy.replication_factor == 7

    # Every non-fitting function is expected to use the inference options.
    inference_cases = (
        (TrainerFn.VALIDATING, RunningStage.VALIDATING),
        (TrainerFn.TESTING, RunningStage.TESTING),
        (TrainerFn.PREDICTING, RunningStage.PREDICTING),
    )
    for trainer_fn, running_stage in inference_cases:
        trainer.state.fn = trainer_fn
        trainer.state.stage = running_stage
        trainer.strategy.setup(trainer)
        assert trainer.strategy.replication_factor == 7
예제 #6
0
def test_poptorch_models_at_different_stages(tmpdir):
    """Check which stages appear in ``poptorch_models`` after the plugin's ``pre_dispatch``.

    With the trainer function set to ``FITTING`` the plugin is expected to hold
    models for the training and validating stages. For each of the validating,
    testing and predicting functions it is expected to hold only the model of the
    matching stage.
    """
    ipu_plugin = IPUPlugin()
    trainer = Trainer(default_root_dir=tmpdir, strategy=ipu_plugin, ipus=8)
    boring_model = BoringModel()
    boring_model.trainer = trainer
    ipu_plugin.model = boring_model

    trainer.optimizers = boring_model.configure_optimizers()[0]

    # Fitting: both the train and the validation model are expected.
    trainer.state.fn = TrainerFn.FITTING
    trainer.training_type_plugin.pre_dispatch()
    registered_stages = list(trainer.training_type_plugin.poptorch_models)
    assert registered_stages == [RunningStage.TRAINING, RunningStage.VALIDATING]

    # Evaluation-style functions: only the model of the active stage is expected.
    evaluation_cases = (
        (TrainerFn.VALIDATING, RunningStage.VALIDATING),
        (TrainerFn.TESTING, RunningStage.TESTING),
        (TrainerFn.PREDICTING, RunningStage.PREDICTING),
    )
    for trainer_fn, running_stage in evaluation_cases:
        trainer.state.fn = trainer_fn
        trainer.state.stage = running_stage
        trainer.training_type_plugin.pre_dispatch()
        registered_stages = list(trainer.training_type_plugin.poptorch_models)
        assert registered_stages == [running_stage]
예제 #7
0
def test_default_level_for_hooks_that_support_logging():
    """Check the default ``on_step``/``on_epoch`` values forwarded by ``model.log`` for each hook.

    ``_ResultCollection.log`` is patched so the forwarded arguments can be inspected.
    Hooks are checked in groups sharing the same expected defaults, and at the end
    every hook with a truthy entry in ``_FxValidator.functions`` must have been
    covered by one of the groups.
    """
    trainer = Trainer()
    model = BoringModel()
    model.trainer = trainer

    # Any ``_ResultCollection.log`` parameter not asserted explicitly is matched with ``ANY``.
    explicit_params = ["self", "fx", "name", "value", "on_step", "on_epoch"]
    extra_kwargs = {}
    for param_name in inspect.signature(_ResultCollection.log).parameters:
        if param_name not in explicit_params:
            extra_kwargs[param_name] = ANY

    # Hooks that still need to be covered by one of the groups below.
    fx_functions = _FxValidator.functions
    unchecked_hooks = {fx_name for fx_name in fx_functions if fx_functions[fx_name]}

    patch_target = "pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection.log"
    with mock.patch(patch_target, return_value=None) as log_mock:

        def _check_defaults(hook_names, on_step, on_epoch):
            # Tick the hooks off first, then log from each one and inspect the forwarded call.
            unchecked_hooks.difference_update(hook_names)
            for hook_name in hook_names:
                model._current_fx_name = hook_name
                model.log(hook_name, 1)
                log_mock.assert_called_with(
                    hook_name,
                    hook_name,
                    torch.tensor(1),
                    on_step=on_step,
                    on_epoch=on_epoch,
                    **extra_kwargs,
                )

        trainer.state.stage = RunningStage.TRAINING
        step_level_train_hooks = (
            "on_before_backward",
            "backward",
            "on_after_backward",
            "on_before_optimizer_step",
            "optimizer_step",
            "on_before_zero_grad",
            "optimizer_zero_grad",
            "training_step",
            "training_step_end",
            "on_batch_start",
            "on_batch_end",
            "on_train_batch_start",
            "on_train_batch_end",
        )
        _check_defaults(step_level_train_hooks, on_step=True, on_epoch=False)

        epoch_level_train_hooks = (
            "on_train_start",
            "on_train_epoch_start",
            "on_train_epoch_end",
            "on_epoch_start",
            "on_epoch_end",
            "training_epoch_end",
        )
        _check_defaults(epoch_level_train_hooks, on_step=False, on_epoch=True)

        trainer.state.stage = RunningStage.VALIDATING
        trainer.state.fn = TrainerFn.VALIDATING
        validation_hooks = (
            "on_validation_start",
            "on_validation_epoch_start",
            "on_validation_epoch_end",
            "on_validation_batch_start",
            "on_validation_batch_end",
            "validation_step",
            "validation_step_end",
            "validation_epoch_end",
        )
        _check_defaults(validation_hooks, on_step=False, on_epoch=True)

        trainer.state.stage = RunningStage.TESTING
        trainer.state.fn = TrainerFn.TESTING
        test_hooks = (
            "on_test_start",
            "on_test_epoch_start",
            "on_test_epoch_end",
            "on_test_batch_start",
            "on_test_batch_end",
            "test_step",
            "test_step_end",
            "test_epoch_end",
        )
        _check_defaults(test_hooks, on_step=False, on_epoch=True)

    # just to ensure we checked all possible logging hooks here
    assert not unchecked_hooks