def test_replication_factor(tmpdir):
    """Check that replication factors from user-supplied poptorch Options are honoured per stage.

    With a default ``IPUPlugin``, the plugin's replication factor matches the ``ipus`` value given to
    the Trainer (2 here). With custom ``training_opts``/``inference_opts`` (factors 8 and 7), the
    explicit options take precedence over ``ipus=1``. The training stage must report 8, and the
    validation, testing and prediction stages must report 7.
    """
    # Default options: the replication factor mirrors the requested IPU count.
    default_plugin = IPUPlugin()
    default_trainer = Trainer(ipus=2, default_root_dir=tmpdir, fast_dev_run=True, strategy=default_plugin)
    assert default_trainer.ipus == 2
    assert default_trainer.training_type_plugin.replication_factor == 2

    # Custom options: distinct factors for training (8) and inference (7).
    model = BoringModel()
    train_options = poptorch.Options()
    infer_options = poptorch.Options()
    train_options.replicationFactor(8)
    infer_options.replicationFactor(7)
    custom_plugin = IPUPlugin(inference_opts=infer_options, training_opts=train_options)

    # Wire the model, trainer and plugin together by hand instead of calling ``trainer.fit``.
    trainer = Trainer(default_root_dir=tmpdir, ipus=1, strategy=custom_plugin)
    trainer.optimizers = model.configure_optimizers()[0]
    custom_plugin.model = model
    model.trainer = trainer

    trainer.state.fn = TrainerFn.FITTING
    trainer.training_type_plugin.pre_dispatch()

    # While fitting, the reported factor depends on the currently running stage.
    fit_expectations = ((RunningStage.TRAINING, 8), (RunningStage.VALIDATING, 7))
    for running_stage, expected_factor in fit_expectations:
        trainer.state.stage = running_stage
        assert trainer.training_type_plugin.replication_factor == expected_factor

    # Every standalone inference entry point must pick up the inference options.
    inference_fns = (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING)
    inference_stages = (RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING)
    for trainer_fn, running_stage in zip(inference_fns, inference_stages):
        trainer.state.fn = trainer_fn
        trainer.state.stage = running_stage
        trainer.training_type_plugin.pre_dispatch()
        assert trainer.training_type_plugin.replication_factor == 7
def test_poptorch_models_at_different_stages(tmpdir):
    """Check which stages appear in the IPU plugin's ``poptorch_models`` for each trainer entry point.

    After ``pre_dispatch`` while fitting, ``poptorch_models`` must list the training and validation
    stages (in that order). After ``pre_dispatch`` for a standalone validate/test/predict run, it must
    list only that run's stage.
    """
    ipu_plugin = IPUPlugin()
    trainer = Trainer(default_root_dir=tmpdir, strategy=ipu_plugin, ipus=8)
    model = BoringModel()
    model.trainer = trainer
    ipu_plugin.model = model
    trainer.optimizers = model.configure_optimizers()[0]

    # Fitting prepares both the training and the validation stages.
    trainer.state.fn = TrainerFn.FITTING
    trainer.training_type_plugin.pre_dispatch()
    built_stages = list(trainer.training_type_plugin.poptorch_models)
    assert built_stages == [RunningStage.TRAINING, RunningStage.VALIDATING]

    # Each standalone run must leave exactly its own stage behind.
    standalone_runs = (
        (TrainerFn.VALIDATING, RunningStage.VALIDATING),
        (TrainerFn.TESTING, RunningStage.TESTING),
        (TrainerFn.PREDICTING, RunningStage.PREDICTING),
    )
    for trainer_fn, running_stage in standalone_runs:
        trainer.state.fn = trainer_fn
        trainer.state.stage = running_stage
        trainer.training_type_plugin.pre_dispatch()
        built_stages = list(trainer.training_type_plugin.poptorch_models)
        assert built_stages == [running_stage]