Exemplo n.º 1
0
def test_cli_config_overwrite(tmpdir):
    """Run the CLI three times with identical defaults and check the overwrite guard.

    The second run must raise ``RuntimeError`` mentioning "Aborting to avoid
    overwriting"; the third run passes ``save_config_overwrite=True`` and is
    expected to complete.
    """
    defaults = {"default_root_dir": str(tmpdir), "logger": False, "max_steps": 1, "max_epochs": 1}

    def bare_argv():
        # A fresh patch for every run: program name only, no extra options.
        return mock.patch("sys.argv", ["any.py"])

    # 1) First run in the directory.
    with bare_argv():
        LightningCLI(BoringModel, trainer_defaults=defaults)

    # 2) Same setup again: expected to abort.
    with bare_argv(), pytest.raises(RuntimeError, match="Aborting to avoid overwriting"):
        LightningCLI(BoringModel, trainer_defaults=defaults)

    # 3) Same setup, but overwriting explicitly enabled.
    with bare_argv():
        LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=defaults)
Exemplo n.º 2
0
def test_lightning_cli_submodules(tmpdir):
    """Check that submodules given via ``class_path`` in a config file end up on the model.

    A YAML file names ``BoringModel`` for both ``submodule1`` and ``submodule2``
    and sets ``main_param`` to 2; the CLI is then run with that file.
    """

    class MainModule(BoringModel):
        # Model whose constructor takes two LightningModule arguments and one int.
        def __init__(
            self,
            submodule1: LightningModule,
            submodule2: LightningModule,
            main_param: int = 1,
        ):
            super().__init__()
            self.submodule1, self.submodule2 = submodule1, submodule2

    yaml_text = """model:
        main_param: 2
        submodule1:
            class_path: tests.helpers.boring_model.BoringModel
        submodule2:
            class_path: tests.helpers.boring_model.BoringModel
    """
    yaml_file = tmpdir / "config.yaml"
    with open(yaml_file, "w") as handle:
        handle.write(yaml_text)

    argv = [
        "any.py",
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.max_epochs=1",
        f"--config={yaml_file}",
    ]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(MainModule)

    assert cli.config["model"]["main_param"] == 2
    for submodule in (cli.model.submodule1, cli.model.submodule2):
        assert isinstance(submodule, BoringModel)
Exemplo n.º 3
0
def test_lightning_cli_config_and_subclass_mode(tmpdir):
    """Run the CLI in subclass mode from a config file and compare the saved config.

    The config written under ``lightning_logs/version_0`` must match
    ``cli.config`` for the ``model``, ``data`` and ``trainer`` sections.
    """
    root_dir = str(tmpdir)
    input_config = {
        "model": {"class_path": "tests.helpers.boring_model.BoringModel"},
        "data": {
            "class_path": "tests.helpers.boring_model.BoringDataModule",
            "init_args": {"data_dir": root_dir},
        },
        "trainer": {"default_root_dir": root_dir, "max_epochs": 1, "weights_summary": None},
    }
    input_path = tmpdir / "config.yaml"
    serialized = yaml.dump(input_config)
    with open(input_path, "w") as out:
        out.write(serialized)

    argv = ["any.py", "--config", str(input_path)]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(
            BoringModel,
            BoringDataModule,
            subclass_mode_model=True,
            subclass_mode_data=True,
            trainer_defaults={"callbacks": LearningRateMonitor()},
        )

    # The run is expected to leave its config next to the first logger version.
    saved_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
    assert os.path.isfile(saved_path)
    with open(saved_path) as src:
        saved = yaml.safe_load(src)

    for section in ("model", "data", "trainer"):
        assert saved[section] == cli.config[section]
Exemplo n.º 4
0
def any_model_any_data_cli():
    """Instantiate ``LightningCLI`` on the base model and data module classes.

    Subclass mode is switched on for both the model and the data module.
    """
    subclass_modes = {"subclass_mode_model": True, "subclass_mode_data": True}
    LightningCLI(LightningModule, LightningDataModule, **subclass_modes)
Exemplo n.º 5
0
def test_lightning_cli_args_callbacks(tmpdir):
    """Pass callbacks as a JSON ``--trainer.callbacks`` argument and verify them in ``on_fit_start``."""
    callback_specs = [
        {
            "class_path": "pytorch_lightning.callbacks.LearningRateMonitor",
            "init_args": {"logging_interval": "epoch", "log_momentum": True},
        },
        {
            "class_path": "pytorch_lightning.callbacks.ModelCheckpoint",
            "init_args": {"monitor": "NAME"},
        },
    ]

    class TestModel(BoringModel):
        def on_fit_start(self):
            # Exactly one LearningRateMonitor, carrying the requested init args.
            lr_monitors = [cb for cb in self.trainer.callbacks if isinstance(cb, LearningRateMonitor)]
            assert len(lr_monitors) == 1
            lr_monitor = lr_monitors[0]
            assert lr_monitor.logging_interval == "epoch"
            assert lr_monitor.log_momentum is True

            # Exactly one ModelCheckpoint, monitoring the requested name.
            checkpoints = [cb for cb in self.trainer.callbacks if isinstance(cb, ModelCheckpoint)]
            assert len(checkpoints) == 1
            assert checkpoints[0].monitor == "NAME"

            # Flag read after the run to prove this hook was executed.
            self.trainer.ran_asserts = True

    argv = ["any.py", f"--trainer.callbacks={json.dumps(callback_specs)}"]
    defaults = {"default_root_dir": str(tmpdir), "fast_dev_run": True}
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(TestModel, trainer_defaults=defaults)

    assert cli.trainer.ran_asserts
Exemplo n.º 6
0
def test_lightning_cli(trainer_class, model_class, monkeypatch):
    """Verify that LightningCLI builds the model and trainer from argv and invokes fit.

    ``Trainer.fit`` and ``SaveConfigCallback.on_train_start`` are replaced with
    checking stand-ins; the latter sets ``trainer.ran_asserts`` once its checks pass.
    """
    expected_model = {"model_param": 7}
    expected_trainer = {"limit_train_batches": 100}

    def fit(trainer, model):
        # Values given on the command line must be visible as attributes.
        for name, value in expected_model.items():
            assert getattr(model, name) == value
        for name, value in expected_trainer.items():
            assert getattr(trainer, name) == value
        # Exactly one config-saving callback is expected; trigger it by hand.
        save_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, SaveConfigCallback)]
        assert len(save_callbacks) == 1
        save_callbacks[0].on_train_start(trainer, model)

    def on_train_start(callback, trainer, _):
        dumped = callback.parser.dump(callback.config, skip_none=False)
        # Model entries first, then trainer entries, each as an indented "key: value" line.
        for name, value in {**expected_model, **expected_trainer}.items():
            assert f"  {name}: {value}" in dumped
        trainer.ran_asserts = True

    monkeypatch.setattr(Trainer, "fit", fit)
    monkeypatch.setattr(SaveConfigCallback, "on_train_start", on_train_start)

    argv = ["any.py", "--model.model_param=7", "--trainer.limit_train_batches=100"]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(model_class, trainer_class=trainer_class, save_config_callback=SaveConfigCallback)
        assert getattr(cli.trainer, "ran_asserts", False)
Exemplo n.º 7
0
def test_lightning_cli_args_cluster_environments(tmpdir):
    """Pass a ``SLURMEnvironment`` plugin as JSON on the command line and check it is used."""
    plugin_specs = [{"class_path": "pytorch_lightning.plugins.environments.SLURMEnvironment"}]

    class TestModel(BoringModel):
        def on_fit_start(self):
            # The connector must hold the requested SLURMEnvironment rather than
            # the default LightningEnvironment.
            cluster_env = self.trainer.accelerator_connector._cluster_environment
            assert isinstance(cluster_env, SLURMEnvironment)
            # Flag read after the run to prove this hook was executed.
            self.trainer.ran_asserts = True

    argv = ["any.py", f"--trainer.plugins={json.dumps(plugin_specs)}"]
    defaults = {"default_root_dir": str(tmpdir), "fast_dev_run": True}
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(TestModel, trainer_defaults=defaults)

    assert cli.trainer.ran_asserts
Exemplo n.º 8
0
def test_lightning_cli_save_config_cases(tmpdir):
    """Check when ``config.yaml`` is written to the root dir and when a rerun aborts.

    Three runs share the same root directory with the logger disabled:

    1. ``fast_dev_run=1``: no config file must be written.
    2. ``max_epochs=1``: the config file must exist afterwards.
    3. Same as (2) again: a ``RuntimeError`` about overwriting must be raised.
    """
    config_path = tmpdir / "config.yaml"
    common_args = [
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.logger=False",
    ]
    # Separate lists per scenario instead of mutating one list in place.
    fast_dev_run_args = common_args + ["--trainer.fast_dev_run=1"]
    full_run_args = common_args + ["--trainer.max_epochs=1"]

    # With fast_dev_run!=False config should not be saved
    with mock.patch("sys.argv", ["any.py"] + fast_dev_run_args):
        LightningCLI(BoringModel)
    assert not os.path.isfile(config_path)

    # With fast_dev_run==False config should be saved
    with mock.patch("sys.argv", ["any.py"] + full_run_args):
        LightningCLI(BoringModel)
    assert os.path.isfile(config_path)

    # If run again on same directory exception should be raised since config file already exists.
    # The message is pinned so that an unrelated RuntimeError cannot satisfy the check.
    with mock.patch("sys.argv", ["any.py"] + full_run_args), pytest.raises(
        RuntimeError, match="Aborting to avoid overwriting"
    ):
        LightningCLI(BoringModel)
Exemplo n.º 9
0
def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs):
    """Run the CLI with a model that interrupts the run and check the config was still saved.

    The run is expected to end with ``KeyboardInterrupt``. With a logger the
    config is looked up under ``lightning_logs/version_0`` (and no other version
    directory may exist); without one it is looked up directly in ``tmpdir``.
    """
    defaults = {
        "default_root_dir": str(tmpdir),
        "logger": logger,
        "max_steps": 1,
        "max_epochs": 1,
    }
    # Parametrized extras are applied last, as in a trailing ``**`` expansion.
    defaults.update(trainer_kwargs)

    with mock.patch("sys.argv", ["any.py"]), pytest.raises(KeyboardInterrupt):
        LightningCLI(EarlyExitTestModel, trainer_defaults=defaults)

    if not logger:
        expected_config = tmpdir / "config.yaml"
    else:
        logs_dir = tmpdir / "lightning_logs"
        # no more version dirs should get created
        assert os.listdir(logs_dir) == ["version_0"]
        expected_config = logs_dir / "version_0" / "config.yaml"
    assert os.path.isfile(expected_config)
Exemplo n.º 10
0
def test_lightning_cli_torch_modules(tmpdir):
    """Check that ``torch.nn.Module`` arguments can be configured through a YAML file.

    The config supplies a ``LeakyReLU`` activation and a list of two
    torchvision transforms for the model's constructor.
    """

    class TestModule(BoringModel):
        # Model accepting a single module and an optional list of modules.
        def __init__(
            self,
            activation: torch.nn.Module = None,
            transform: Optional[List[torch.nn.Module]] = None,
        ):
            super().__init__()
            self.activation = activation
            self.transform = transform

    yaml_text = """model:
        activation:
          class_path: torch.nn.LeakyReLU
          init_args:
            negative_slope: 0.2
        transform:
          - class_path: torchvision.transforms.Resize
            init_args:
              size: 64
          - class_path: torchvision.transforms.CenterCrop
            init_args:
              size: 64
    """
    yaml_file = tmpdir / "config.yaml"
    with open(yaml_file, "w") as handle:
        handle.write(yaml_text)

    argv = [
        "any.py",
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.max_epochs=1",
        f"--config={yaml_file}",
    ]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(TestModule)

    activation = cli.model.activation
    assert isinstance(activation, torch.nn.LeakyReLU)
    assert activation.negative_slope == 0.2

    transforms = cli.model.transform
    assert len(transforms) == 2
    for transform in transforms:
        assert isinstance(transform, torch.nn.Module)
Exemplo n.º 11
0
def test_lightning_cli_args(tmpdir):
    """Run the CLI from plain command line options and compare the saved config.

    Checks the parsed ``seed_everything`` value, that a config file exists under
    ``lightning_logs/version_0``, and that its ``data`` and ``trainer`` sections
    match ``cli.config``.
    """
    argv = [
        "any.py",
        f"--data.data_dir={tmpdir}",
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.max_epochs=1",
        "--trainer.weights_summary=null",
        "--seed_everything=1234",
    ]
    defaults = {"callbacks": [LearningRateMonitor()]}

    with mock.patch("sys.argv", argv):
        cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults=defaults)

    assert cli.config["seed_everything"] == 1234

    saved_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
    assert os.path.isfile(saved_path)
    with open(saved_path) as src:
        saved = yaml.safe_load(src)

    # no arguments to include, so neither config carries a "model" section
    assert "model" not in saved
    assert "model" not in cli.config
    for section in ("data", "trainer"):
        assert saved[section] == cli.config[section]