Esempio n. 1
0
def test_replication_factor(tmpdir):
    """Check that the IPU plugin reports the expected replication factor per stage.

    With no explicit ``poptorch.Options`` the test expects the replication
    factor to equal the requested ``ipus`` (2). With custom training (8) and
    inference (7) options, the training stage must report the training value
    and every evaluation stage (validate / test / predict) the inference value.
    """
    # Default plugin: replication factor is expected to mirror ``ipus``.
    default_plugin = IPUPlugin()
    trainer = Trainer(ipus=2, default_root_dir=tmpdir, fast_dev_run=True, strategy=default_plugin)
    assert trainer.ipus == 2
    assert trainer.training_type_plugin.replication_factor == 2

    # Custom plugin: distinct replication factors for training and inference.
    model = BoringModel()
    train_options = poptorch.Options()
    eval_options = poptorch.Options()
    train_options.replicationFactor(8)
    eval_options.replicationFactor(7)
    custom_plugin = IPUPlugin(inference_opts=eval_options, training_opts=train_options)

    # Wire the model/trainer by hand so ``pre_dispatch`` can run without ``fit``.
    trainer = Trainer(default_root_dir=tmpdir, ipus=1, strategy=custom_plugin)
    trainer.optimizers = model.configure_optimizers()[0]
    custom_plugin.model = model
    model.trainer = trainer
    trainer.state.fn = TrainerFn.FITTING
    trainer.training_type_plugin.pre_dispatch()

    # While fitting, the reported factor depends on the current running stage.
    fitting_expectations = (
        (RunningStage.TRAINING, 8),
        (RunningStage.VALIDATING, 7),
    )
    for running_stage, expected_factor in fitting_expectations:
        trainer.state.stage = running_stage
        assert trainer.training_type_plugin.replication_factor == expected_factor

    # Evaluation-only entry points must always pick up the inference options.
    evaluation_steps = [
        (TrainerFn.VALIDATING, RunningStage.VALIDATING),
        (TrainerFn.TESTING, RunningStage.TESTING),
        (TrainerFn.PREDICTING, RunningStage.PREDICTING),
    ]
    for trainer_fn, running_stage in evaluation_steps:
        trainer.state.fn = trainer_fn
        trainer.state.stage = running_stage
        trainer.training_type_plugin.pre_dispatch()
        assert trainer.training_type_plugin.replication_factor == 7
Esempio n. 2
0
def test_poptorch_models_at_different_stages(tmpdir):
    """Check which poptorch models ``pre_dispatch`` creates for each trainer function.

    Fitting is expected to create a training and a validation model (in that
    insertion order); each evaluation-only function is expected to create
    exactly one model, keyed by its own running stage.
    """
    ipu_plugin = IPUPlugin()
    trainer = Trainer(default_root_dir=tmpdir, strategy=ipu_plugin, ipus=8)
    model = BoringModel()
    # Manual wiring so ``pre_dispatch`` can be driven directly.
    model.trainer = trainer
    ipu_plugin.model = model

    trainer.optimizers = model.configure_optimizers()[0]
    trainer.state.fn = TrainerFn.FITTING
    trainer.training_type_plugin.pre_dispatch()
    expected_fit_stages = [RunningStage.TRAINING, RunningStage.VALIDATING]
    assert list(trainer.training_type_plugin.poptorch_models) == expected_fit_stages

    # Map each evaluation-only trainer function to the single stage it should build.
    stage_for_fn = {
        TrainerFn.VALIDATING: RunningStage.VALIDATING,
        TrainerFn.TESTING: RunningStage.TESTING,
        TrainerFn.PREDICTING: RunningStage.PREDICTING,
    }
    for trainer_fn, running_stage in stage_for_fn.items():
        trainer.state.fn = trainer_fn
        trainer.state.stage = running_stage
        trainer.training_type_plugin.pre_dispatch()
        assert list(trainer.training_type_plugin.poptorch_models) == [running_stage]