def test_rpc_sequential_plugin_manual_amp(tmpdir, args=None):
    """Verify ``RPCSequentialPlugin`` refuses to run under native AMP.

    Mixed precision is not supported by the RPC sequential plugin, so
    ``trainer.fit`` must raise a ``MisconfigurationException`` up front.
    """
    seq_model = SequentialModelRPCManual()
    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        gpus=2,
        precision=16,
        amp_backend="native",
        distributed_backend="ddp",
        plugins=[RPCSequentialPlugin(balance=[2, 1])],
    )
    # Fitting should fail fast with a descriptive configuration error.
    with pytest.raises(
        MisconfigurationException,
        match='`RPCSequentialPlugin` is currently not supported in Automatic Mixed Precision'
    ):
        trainer.fit(seq_model)
def test_rpc_sequential_plugin_automatic(tmpdir, args=None):
    """Run a short automatic-optimization fit through ``RPCSequentialPlugin``.

    Checks that rank 0 records progress-bar metrics and that the RPC
    processes are torn down afterwards.
    """
    seq_model = SequentialModelRPCAutomatic()
    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        gpus=2,
        distributed_backend="ddp",
        plugins=[RPCSequentialPlugin(balance=[2, 1])],
    )
    trainer.fit(seq_model)

    # Only rank 0 is expected to have accumulated progress-bar metrics.
    if torch_distrib.is_initialized() and torch_distrib.get_rank() == 0:
        assert len(trainer.dev_debugger.pbar_added_metrics) > 0

    if trainer.accelerator_backend.rpc_enabled:
        # Called at the end of trainer to ensure all processes are killed
        trainer.accelerator_backend.training_type_plugin.exit_rpc_process()
def test_rpc_sequential_plugin_with_wrong_balance(tmpdir, args=None):
    """Verify a mismatched pipeline ``balance`` is rejected with a clear error.

    The balance ``[2, 2]`` sums to 4 while the sequential model has 3
    layers, so ``trainer.fit`` must raise a ``MisconfigurationException``.
    """
    seq_model = SequentialModelRPCAutomatic()
    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        gpus=2,
        distributed_backend="ddp",
        plugins=[RPCSequentialPlugin(balance=[2, 2])],
    )
    with pytest.raises(
        MisconfigurationException,
        match="The provided balance sum: 4 does not match your Sequential length: 3"
    ):
        trainer.fit(seq_model)

    if trainer.accelerator_backend.rpc_enabled:
        # Called at the end of trainer to ensure all processes are killed
        trainer.accelerator_backend.training_type_plugin.exit_rpc_process()