Example 1
0
def test_poptorch_models_at_different_stages(tmpdir):
    """Check which poptorch models the IPU strategy builds for each trainer function.

    When fitting, ``setup`` is expected to create poptorch models for both the training
    and the validation running stages. For validate, test and predict, only the model for
    the matching running stage is expected to be present afterwards.
    """
    strategy = IPUStrategy()
    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=strategy,
        accelerator="ipu",
        devices=8,
    )
    model = BoringModel()
    model.trainer = trainer
    strategy.model = model

    # The first element returned by configure_optimizers() is used as the optimizers.
    trainer.optimizers = model.configure_optimizers()[0]

    # Fitting: one poptorch model per stage that fit runs (train + validate).
    trainer.state.fn = TrainerFn.FITTING
    trainer.strategy.setup(trainer)
    expected_fit_stages = [RunningStage.TRAINING, RunningStage.VALIDATING]
    assert list(trainer.strategy.poptorch_models) == expected_fit_stages

    # Single-stage functions: exactly one model, keyed by the matching running stage.
    # The dict preserves insertion order, so the functions are visited in the same order.
    stage_for_fn = {
        TrainerFn.VALIDATING: RunningStage.VALIDATING,
        TrainerFn.TESTING: RunningStage.TESTING,
        TrainerFn.PREDICTING: RunningStage.PREDICTING,
    }
    for trainer_fn, running_stage in stage_for_fn.items():
        trainer.state.fn = trainer_fn
        trainer.state.stage = running_stage
        trainer.strategy.setup(trainer)
        assert list(trainer.strategy.poptorch_models) == [running_stage]
Example 2
0
def test_replication_factor(tmpdir):
    """Ensure the IPU strategy reports the expected replication factor.

    With default options, the replication factor follows the ``devices`` argument. When
    the user passes manual poptorch ``Options`` with custom replication factors, the
    strategy reports the training value during the training stage. It reports the
    inference value during validation, and again after ``setup`` for the
    validate/test/predict functions.
    """
    # Default options: the replication factor mirrors the number of requested devices.
    default_strategy = IPUStrategy()
    trainer = Trainer(
        accelerator="ipu",
        devices=2,
        default_root_dir=tmpdir,
        fast_dev_run=True,
        strategy=default_strategy,
    )
    assert isinstance(trainer.accelerator, IPUAccelerator)
    assert trainer.num_devices == 2
    assert trainer.strategy.replication_factor == 2

    # User-supplied options: the factors set on the Options objects should win over
    # ``devices=1`` below.
    training_factor = 8
    inference_factor = 7

    model = BoringModel()
    training_opts = poptorch.Options()
    inference_opts = poptorch.Options()
    training_opts.replicationFactor(training_factor)
    inference_opts.replicationFactor(inference_factor)
    custom_strategy = IPUStrategy(
        inference_opts=inference_opts,
        training_opts=training_opts,
    )

    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="ipu",
        devices=1,
        strategy=custom_strategy,
    )
    trainer.optimizers = model.configure_optimizers()[0]
    custom_strategy.model = model
    model.trainer = trainer
    trainer.state.fn = TrainerFn.FITTING
    trainer.strategy.setup(trainer)

    # During fit, the reported factor depends on the currently running stage.
    fit_expectations = (
        (RunningStage.TRAINING, training_factor),
        (RunningStage.VALIDATING, inference_factor),
    )
    for running_stage, expected_factor in fit_expectations:
        trainer.state.stage = running_stage
        assert trainer.strategy.replication_factor == expected_factor

    # Every non-fit trainer function should use the inference options.
    eval_runs = (
        (TrainerFn.VALIDATING, RunningStage.VALIDATING),
        (TrainerFn.TESTING, RunningStage.TESTING),
        (TrainerFn.PREDICTING, RunningStage.PREDICTING),
    )
    for trainer_fn, running_stage in eval_runs:
        trainer.state.fn = trainer_fn
        trainer.state.stage = running_stage
        trainer.strategy.setup(trainer)
        assert trainer.strategy.replication_factor == inference_factor