Ejemplo n.º 1
0
def test_rpc_sequential_plugin_manual_amp(tmpdir, args=None):
    """Check that ``RPCSequentialPlugin`` combined with 16-bit native AMP is rejected.

    A ``Trainer`` is configured with ``precision=16``, ``amp_backend="native"``,
    two GPUs, the ``ddp`` backend and an ``RPCSequentialPlugin`` with balance
    ``[2, 1]``. Calling ``fit`` is expected to raise
    ``MisconfigurationException`` with the "not supported in Automatic Mixed
    Precision" message.

    Args:
        tmpdir: Accepted but not used by this test.
        args: Accepted but not used by this test.
    """
    rpc_model = SequentialModelRPCManual()
    pipe_plugin = RPCSequentialPlugin(balance=[2, 1])

    # Short run limits: the test only needs `fit` to reach the configuration check.
    trainer_options = dict(
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        gpus=2,
        precision=16,
        amp_backend="native",
        distributed_backend="ddp",
        plugins=[pipe_plugin],
    )
    amp_trainer = Trainer(**trainer_options)

    expected_message = '`RPCSequentialPlugin` is currently not supported in Automatic Mixed Precision'
    with pytest.raises(MisconfigurationException, match=expected_message):
        amp_trainer.fit(rpc_model)
Ejemplo n.º 2
0
def test_rpc_sequential_plugin_automatic(tmpdir, args=None):
    """Fit ``SequentialModelRPCAutomatic`` with ``RPCSequentialPlugin`` on two GPUs.

    After ``fit`` completes, the process with distributed rank 0 (when the
    default process group is initialized) asserts that at least one metric was
    recorded in ``trainer.dev_debugger.pbar_added_metrics``. Finally, if RPC is
    enabled on the accelerator backend, the RPC process is shut down.

    Args:
        tmpdir: Accepted but not used by this test.
        args: Accepted but not used by this test.
    """
    model = SequentialModelRPCAutomatic()
    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        gpus=2,
        distributed_backend="ddp",
        plugins=[RPCSequentialPlugin(balance=[2, 1])],
    )

    trainer.fit(model)

    try:
        # Only rank 0 checks the progress-bar metrics.
        if torch_distrib.is_initialized() and torch_distrib.get_rank() == 0:
            assert len(trainer.dev_debugger.pbar_added_metrics) > 0
    finally:
        # Run the shutdown even when the assertion above fails; otherwise a
        # failing check would skip it and could leave RPC processes alive.
        if trainer.accelerator_backend.rpc_enabled:
            # Called at the end of trainer to ensure all processes are killed
            trainer.accelerator_backend.training_type_plugin.exit_rpc_process()
Ejemplo n.º 3
0
def test_rpc_sequential_plugin_with_wrong_balance(tmpdir, args=None):
    """Check that a balance whose sum differs from the Sequential length is rejected.

    The plugin is created with ``balance=[2, 2]`` (sum 4), and ``fit`` is
    expected to raise ``MisconfigurationException`` reporting a mismatch
    against a Sequential length of 3. Afterwards, if RPC is enabled on the
    accelerator backend, the RPC process is shut down.

    Args:
        tmpdir: Accepted but not used by this test.
        args: Accepted but not used by this test.
    """
    rpc_model = SequentialModelRPCAutomatic()
    mismatched_plugin = RPCSequentialPlugin(balance=[2, 2])

    trainer_options = dict(
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        gpus=2,
        distributed_backend="ddp",
        plugins=[mismatched_plugin],
    )
    trainer = Trainer(**trainer_options)

    expected_message = "The provided balance sum: 4 does not match your Sequential length: 3"
    with pytest.raises(MisconfigurationException, match=expected_message):
        trainer.fit(rpc_model)

    if not trainer.accelerator_backend.rpc_enabled:
        return

    # Called at the end of trainer to ensure all processes are killed
    trainer.accelerator_backend.training_type_plugin.exit_rpc_process()