Esempio n. 1
0
def test_lightning_cli_parse_kwargs_with_subcommands(tmpdir):
    """Each subcommand gets its own ``parser_kwargs`` and therefore its own default config file.

    ``fit`` loads ``fit.yaml`` (sets ``limit_train_batches``) and ``validate`` loads
    ``validate.yaml`` (sets ``limit_val_batches``). The option a subcommand's file does not set
    is expected to stay at ``1.0`` (presumably the Trainer default).
    """
    fit_path = tmpdir / "fit.yaml"
    fit_path.write_text(str({"trainer": {"limit_train_batches": 2}}), "utf8")

    validate_path = tmpdir / "validate.yaml"
    validate_path.write_text(str({"trainer": {"limit_val_batches": 3}}), "utf8")

    # one parser configuration per subcommand, each pointing at its own defaults file
    parser_kwargs = {
        "fit": {"default_config_files": [str(fit_path)]},
        "validate": {"default_config_files": [str(validate_path)]},
    }

    fit_argv = mock.patch("sys.argv", ["any.py", "fit"])
    fit_patch = mock.patch("pytorch_lightning.Trainer.fit", autospec=True)
    with fit_argv, fit_patch as fit_mock:
        cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
    fit_mock.assert_called()
    assert cli.trainer.limit_train_batches == 2
    assert cli.trainer.limit_val_batches == 1.0

    validate_argv = mock.patch("sys.argv", ["any.py", "validate"])
    validate_patch = mock.patch("pytorch_lightning.Trainer.validate", autospec=True)
    with validate_argv, validate_patch as validate_mock:
        cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
    validate_mock.assert_called()
    assert cli.trainer.limit_train_batches == 1.0
    assert cli.trainer.limit_val_batches == 3
def test_cli_config_overwrite(tmpdir):
    """A repeated run into the same ``default_root_dir`` refuses to overwrite the saved config.

    The second run must raise ``RuntimeError``; the third passes ``save_config_overwrite=True``
    and must succeed.
    """
    defaults = dict(default_root_dir=str(tmpdir), logger=False, max_steps=1, max_epochs=1)

    # first run: nothing has been saved in tmpdir yet
    with mock.patch("sys.argv", ["any.py"]):
        LightningCLI(BoringModel, trainer_defaults=defaults)

    # second run: the config left by the first run is already there
    overwrite_error = pytest.raises(RuntimeError, match="Aborting to avoid overwriting")
    with mock.patch("sys.argv", ["any.py"]), overwrite_error:
        LightningCLI(BoringModel, trainer_defaults=defaults)

    # third run: overwriting is explicitly allowed
    with mock.patch("sys.argv", ["any.py"]):
        LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=defaults)
Esempio n. 3
0
def test_cli_config_overwrite(tmpdir):
    """Saving the config into an existing run directory aborts unless ``save_config_overwrite=True``."""
    trainer_defaults = {'default_root_dir': str(tmpdir), 'logger': False, 'max_steps': 1, 'max_epochs': 1}

    def build_cli(**cli_kwargs):
        # every run sees an argument-free command line and the same trainer defaults
        with mock.patch('sys.argv', ['any.py']):
            LightningCLI(BoringModel, trainer_defaults=trainer_defaults, **cli_kwargs)

    build_cli()
    with pytest.raises(RuntimeError, match='Aborting to avoid overwriting'):
        build_cli()
    build_cli(save_config_overwrite=True)
Esempio n. 4
0
def test_lightning_cli_config_before_and_after_subcommand():
    """Configs given before and after the ``test`` subcommand are both applied.

    The config placed before the subcommand is namespaced under ``"test"``; the one placed after
    it is not. Where both set ``verbose``, the value from the later config (``False``) is the one
    forwarded to ``Trainer.test``.
    """
    before_subcommand = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    after_subcommand = {"trainer": {"fast_dev_run": 1}, "verbose": False, "ckpt_path": "foobar"}
    argv = ["any.py", f"--config={before_subcommand}", "test", f"--config={after_subcommand}"]

    test_patch = mock.patch("pytorch_lightning.Trainer.test", autospec=True)
    with mock.patch("sys.argv", argv), test_patch as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar")
    # trainer options from both configs end up on the trainer
    assert cli.trainer.limit_test_batches == 1
    assert cli.trainer.fast_dev_run == 1
Esempio n. 5
0
def test_lightning_cli_submodules(tmpdir):
    """Nested ``LightningModule`` init arguments are instantiated from ``class_path`` entries.

    ``MainModule`` takes two ``LightningModule`` arguments and an int. The YAML config supplies
    all three. The test checks the parsed config and also that the values reached the model.
    """

    class MainModule(BoringModel):
        def __init__(self,
                     submodule1: LightningModule,
                     submodule2: LightningModule,
                     main_param: int = 1):
            super().__init__()
            self.submodule1 = submodule1
            self.submodule2 = submodule2
            # stored so the test can verify the value was passed to the model, not only parsed
            self.main_param = main_param

    config = """model:
        main_param: 2
        submodule1:
            class_path: tests.helpers.BoringModel
        submodule2:
            class_path: tests.helpers.BoringModel
    """
    config_path = tmpdir / "config.yaml"
    with open(config_path, "w") as f:
        f.write(config)

    cli_args = [
        f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1",
        f"--config={str(config_path)}"
    ]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(MainModule)

    assert cli.config["model"]["main_param"] == 2
    assert cli.model.main_param == 2
    assert isinstance(cli.model.submodule1, BoringModel)
    assert isinstance(cli.model.submodule2, BoringModel)
Esempio n. 6
0
def test_lightning_cli_submodules(tmpdir):
    """Nested ``LightningModule`` init arguments are instantiated from ``class_path`` entries.

    ``MainModule`` takes two ``LightningModule`` arguments and an int. The YAML config supplies
    all three. The test checks the parsed config and also that the values reached the model.
    """

    class MainModule(BoringModel):
        def __init__(
            self,
            submodule1: LightningModule,
            submodule2: LightningModule,
            main_param: int = 1,
        ):
            super().__init__()
            self.submodule1 = submodule1
            self.submodule2 = submodule2
            # stored so the test can verify the value was passed to the model, not only parsed
            self.main_param = main_param

    config = """model:
        main_param: 2
        submodule1:
            class_path: tests.helpers.BoringModel
        submodule2:
            class_path: tests.helpers.BoringModel
    """
    config_path = tmpdir / 'config.yaml'
    with open(config_path, 'w') as f:
        f.write(config)

    cli_args = [
        f'--trainer.default_root_dir={tmpdir}',
        '--trainer.max_epochs=1',
        f'--config={str(config_path)}',
    ]

    with mock.patch('sys.argv', ['any.py'] + cli_args):
        cli = LightningCLI(MainModule)

    assert cli.config['model']['main_param'] == 2
    assert cli.model.main_param == 2
    assert isinstance(cli.model.submodule1, BoringModel)
    assert isinstance(cli.model.submodule2, BoringModel)
Esempio n. 7
0
def test_lightning_cli_args(tmpdir):
    """Command-line arguments are applied and the resulting config is written to the log dir."""
    argv = ['any.py']
    argv += [
        f'--data.data_dir={tmpdir}',
        f'--trainer.default_root_dir={tmpdir}',
        '--trainer.max_epochs=1',
        '--trainer.weights_summary=null',
        '--seed_everything=1234',
    ]

    with mock.patch('sys.argv', argv):
        cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={'callbacks': [LearningRateMonitor()]})

    assert cli.fit_result == 1
    assert cli.config['seed_everything'] == 1234

    saved_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml'
    assert os.path.isfile(saved_path)
    with open(saved_path) as fh:
        saved = yaml.safe_load(fh.read())

    # no model arguments to include, so neither config has a "model" section
    assert 'model' not in saved and 'model' not in cli.config
    for section in ('data', 'trainer'):
        assert saved[section] == cli.config[section]
Esempio n. 8
0
def test_lightning_cli_config_and_subclass_mode(tmpdir):
    """Model and datamodule chosen via ``class_path`` in subclass mode; the saved config matches the parsed one."""
    input_config = {
        "model": {"class_path": "tests.helpers.BoringModel"},
        "data": {"class_path": "tests.helpers.BoringDataModule", "init_args": {"data_dir": str(tmpdir)}},
        "trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "weights_summary": None},
    }
    input_path = tmpdir / "config.yaml"
    with open(input_path, "w") as fh:
        fh.write(yaml.dump(input_config))

    with mock.patch("sys.argv", ["any.py", "--config", str(input_path)]):
        cli = LightningCLI(
            BoringModel,
            BoringDataModule,
            subclass_mode_model=True,
            subclass_mode_data=True,
            trainer_defaults={"callbacks": LearningRateMonitor()},
        )

    saved_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
    assert os.path.isfile(saved_path)
    with open(saved_path) as fh:
        saved_config = yaml.safe_load(fh.read())

    # every section written to the log dir must equal what the CLI parsed
    for section in ("model", "data", "trainer"):
        assert saved_config[section] == cli.config[section]
Esempio n. 9
0
def test_lightning_cli_args_callbacks(tmpdir):
    """Callbacks passed as ``class_path``/``init_args`` JSON via ``--trainer.callbacks`` are instantiated."""
    callbacks_arg = json.dumps([
        {
            'class_path': 'pytorch_lightning.callbacks.LearningRateMonitor',
            'init_args': {'logging_interval': 'epoch', 'log_momentum': True},
        },
        {'class_path': 'pytorch_lightning.callbacks.ModelCheckpoint', 'init_args': {'monitor': 'NAME'}},
    ])

    class TestModel(BoringModel):

        def on_fit_start(self):
            # inspect the callbacks the CLI attached to the trainer
            lr_monitors = [cb for cb in self.trainer.callbacks if isinstance(cb, LearningRateMonitor)]
            assert len(lr_monitors) == 1
            monitor = lr_monitors[0]
            assert monitor.logging_interval == 'epoch'
            assert monitor.log_momentum is True

            checkpoints = [cb for cb in self.trainer.callbacks if isinstance(cb, ModelCheckpoint)]
            assert len(checkpoints) == 1
            assert checkpoints[0].monitor == 'NAME'

            # flag read by the outer test to prove this hook actually ran
            self.trainer.ran_asserts = True

    with mock.patch('sys.argv', ['any.py', f'--trainer.callbacks={callbacks_arg}']):
        cli = LightningCLI(TestModel, trainer_defaults={'default_root_dir': str(tmpdir), 'fast_dev_run': True})

    assert cli.trainer.ran_asserts
Esempio n. 10
0
def test_lightning_cli_torch_modules(tmpdir):
    """``torch.nn.Module`` init arguments (a single one and a list) are built from ``class_path`` entries."""

    class TestModule(BoringModel):
        # ``activation`` defaults to None, so it is annotated as Optional (no implicit Optional, PEP 484),
        # matching how ``transform`` is declared
        def __init__(
            self,
            activation: Optional[torch.nn.Module] = None,
            transform: Optional[List[torch.nn.Module]] = None,
        ):
            super().__init__()
            self.activation = activation
            self.transform = transform

    config = """model:
        activation:
          class_path: torch.nn.LeakyReLU
          init_args:
            negative_slope: 0.2
        transform:
          - class_path: torchvision.transforms.Resize
            init_args:
              size: 64
          - class_path: torchvision.transforms.CenterCrop
            init_args:
              size: 64
    """
    config_path = tmpdir / "config.yaml"
    with open(config_path, "w") as f:
        f.write(config)

    cli_args = [f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", f"--config={str(config_path)}"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(TestModule)

    assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
    assert cli.model.activation.negative_slope == 0.2
    assert len(cli.model.transform) == 2
    assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)
Esempio n. 11
0
def test_lightning_cli_config_and_subclass_mode(tmpdir):
    """Subclass mode with ``class_path`` configs; the config saved in the log dir equals the parsed one."""
    tmp = str(tmpdir)
    source_config = {
        'model': {'class_path': 'tests.helpers.BoringModel'},
        'data': {'class_path': 'tests.helpers.BoringDataModule', 'init_args': {'data_dir': tmp}},
        'trainer': {'default_root_dir': tmp, 'max_epochs': 1, 'weights_summary': None},
    }
    source_path = tmpdir / 'config.yaml'
    with open(source_path, 'w') as out:
        out.write(yaml.dump(source_config))

    cli_kwargs = {
        'subclass_mode_model': True,
        'subclass_mode_data': True,
        'trainer_defaults': {'callbacks': LearningRateMonitor()},
    }
    with mock.patch('sys.argv', ['any.py', '--config', str(source_path)]):
        cli = LightningCLI(BoringModel, BoringDataModule, **cli_kwargs)

    dumped_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml'
    assert os.path.isfile(dumped_path)
    with open(dumped_path) as src:
        dumped = yaml.safe_load(src.read())
    assert dumped['model'] == cli.config['model']
    assert dumped['data'] == cli.config['data']
    assert dumped['trainer'] == cli.config['trainer']
Esempio n. 12
0
def any_model_any_data_cli():
    """Start a ``LightningCLI`` with subclass mode enabled for both the model and the datamodule.

    The base classes ``LightningModule`` / ``LightningDataModule`` are passed, so the concrete
    classes are presumably chosen on the command line or in a config (``class_path``).
    """
    subclass_flags = dict(subclass_mode_model=True, subclass_mode_data=True)
    LightningCLI(LightningModule, LightningDataModule, **subclass_flags)
Esempio n. 13
0
def test_lightning_cli(trainer_class, model_class, monkeypatch):
    """Test that LightningCLI correctly instantiates model, trainer and calls fit."""
    expected_model = {"model_param": 7}
    expected_trainer = {"limit_train_batches": 100}

    def fit(trainer, model):
        # stand-in for Trainer.fit: check the instantiated objects, then fire the config callback by hand
        for target, expected in ((model, expected_model), (trainer, expected_trainer)):
            for attr, value in expected.items():
                assert getattr(target, attr) == value
        config_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, SaveConfigCallback)]
        assert len(config_callbacks) == 1
        config_callbacks[0].on_train_start(trainer, model)

    def on_train_start(callback, trainer, _):
        # stand-in for SaveConfigCallback.on_train_start: inspect the parser's dump of the config
        dumped = callback.parser.dump(callback.config, skip_none=False)
        for expected in (expected_model, expected_trainer):
            for attr, value in expected.items():
                assert f"  {attr}: {value}" in dumped
        trainer.ran_asserts = True

    monkeypatch.setattr(Trainer, "fit", fit)
    monkeypatch.setattr(SaveConfigCallback, "on_train_start", on_train_start)

    argv = ["any.py", "--model.model_param=7", "--trainer.limit_train_batches=100"]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(model_class, trainer_class=trainer_class, save_config_callback=SaveConfigCallback)
        assert getattr(cli.trainer, "ran_asserts", False)
Esempio n. 14
0
def cli_main():
    """Build the classifier CLI, run ``test`` and ``predict``, and print the first prediction."""
    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
    trainer, model, datamodule = cli.trainer, cli.model, cli.datamodule
    trainer.test(model, datamodule=datamodule)
    print(trainer.predict(model, datamodule=datamodule)[0])
Esempio n. 15
0
def cli_main():
    """Run the profiling example CLI, falling back to ``DEFAULT_CMD_LINE`` when no arguments are given."""
    # only the program name present: append the default arguments in place
    if len(sys.argv) == 1:
        sys.argv.extend(DEFAULT_CMD_LINE)

    trainer_defaults = {"profiler": PyTorchProfiler()}
    LightningCLI(ModelToProfile, CIFAR10DataModule, save_config_overwrite=True, trainer_defaults=trainer_defaults)
Esempio n. 16
0
def test_lightning_cli_args(tmpdir):
    """Arguments to the ``fit`` subcommand are applied and saved under the ``fit`` key of ``config.yaml``."""
    argv = [
        "any.py",
        "fit",
        f"--data.data_dir={tmpdir}",
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.max_epochs=1",
        "--trainer.weights_summary=null",
        "--seed_everything=1234",
    ]

    with mock.patch("sys.argv", argv):
        cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={"callbacks": [LearningRateMonitor()]})

    saved_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
    assert os.path.isfile(saved_path)
    with open(saved_path) as fh:
        saved_fit = yaml.safe_load(fh.read())["fit"]
    parsed_fit = cli.config["fit"]

    assert parsed_fit["seed_everything"] == 1234
    assert "model" not in saved_fit and "model" not in parsed_fit  # no arguments to include
    for section in ("data", "trainer"):
        assert saved_fit[section] == parsed_fit[section]
Esempio n. 17
0
def test_lightning_cli(cli_args, expected_model, expected_trainer, monkeypatch):
    """Test that LightningCLI correctly instantiates model, trainer and calls fit."""

    def fit(trainer, model):
        # stand-in for Trainer.fit; the expectations travel on the model class
        for target, expected in ((model, model.expected_model), (trainer, model.expected_trainer)):
            for attr, value in expected.items():
                assert getattr(target, attr) == value
        config_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, SaveConfigCallback)]
        assert len(config_callbacks) == 1
        config_callbacks[0].on_train_start(trainer, model)

    def on_train_start(callback, trainer, model):
        # stand-in for SaveConfigCallback.on_train_start: inspect the parser's dump of the config
        dumped = callback.parser.dump(callback.config, skip_none=False)
        for expected in (model.expected_model, model.expected_trainer):
            for attr, value in expected.items():
                assert f'  {attr}: {value}' in dumped
        trainer.ran_asserts = True

    monkeypatch.setattr(Trainer, 'fit', fit)
    monkeypatch.setattr(SaveConfigCallback, 'on_train_start', on_train_start)

    class TestModel(LightningModule):

        def __init__(self, model_param: int):
            super().__init__()
            self.model_param = model_param

    # attach the expectations as class attributes so the patched hooks can read them from the model
    TestModel.expected_model, TestModel.expected_trainer = expected_model, expected_trainer

    with mock.patch('sys.argv', ['any.py'] + cli_args):
        cli = LightningCLI(TestModel, trainer_class=Trainer, save_config_callback=SaveConfigCallback)
        assert getattr(cli.trainer, 'ran_asserts', False)
Esempio n. 18
0
def test_lightning_cli_run():
    """``run=False`` only instantiates (no steps taken); the ``fit`` subcommand also trains one step."""

    def check_instances(built_cli):
        # both runs must produce a real Trainer and LightningModule
        assert isinstance(built_cli.trainer, Trainer)
        assert isinstance(built_cli.model, LightningModule)

    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
    assert cli.trainer.global_step == 0
    check_instances(cli)

    single_step = {"max_steps": 1, "max_epochs": 1}
    with mock.patch("sys.argv", ["any.py", "fit"]):
        cli = LightningCLI(BoringModel, trainer_defaults=single_step)
    assert cli.trainer.global_step == 1
    check_instances(cli)
def cli_main():
    """Build the DALI example CLI, run ``test`` and print its result; does nothing without DALI."""
    if _DALI_AVAILABLE:
        cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
        test_result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
        print(test_result)
Esempio n. 20
0
def cli_main():
    """Parse the CLI, then explicitly fit and test the MNIST classifier.

    ``run=False`` keeps ``LightningCLI`` from calling ``fit`` itself during construction, so the
    model is trained exactly once, by the explicit ``fit`` call below. The other ``cli_main``
    examples that drive the trainer manually do the same.
    """
    cli = LightningCLI(LitClassifier,
                       MNISTDataModule,
                       seed_everything_default=1234,
                       save_config_overwrite=True,
                       run=False)  # the trainer is driven explicitly below
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
Esempio n. 21
0
def cli_main():
    """Fit the autoencoder, then test and predict with ``ckpt_path="best"`` and print the first prediction."""
    cli = LightningCLI(
        LitAutoEncoder,
        MyDataModule,
        seed_everything_default=1234,
        save_config_overwrite=True,
        run=False,  # the trainer is driven explicitly below
    )
    trainer = cli.trainer
    trainer.fit(cli.model, datamodule=cli.datamodule)
    trainer.test(ckpt_path="best")
    first_prediction = trainer.predict(ckpt_path="best")[0]
    print(first_prediction)
def cli_main():
    """Parse the CLI for ``ImageClassifier``, then fit and test with ``ckpt_path="best"``."""
    # The LightningCLI removes all the boilerplate associated with arguments parsing. This is purely optional.
    cli = LightningCLI(ImageClassifier, seed_everything_default=42, save_config_overwrite=True, run=False)
    trainer, model, datamodule = cli.trainer, cli.model, cli.datamodule
    trainer.fit(model, datamodule=datamodule)
    trainer.test(ckpt_path="best", datamodule=datamodule)
Esempio n. 23
0
def test_lightning_cli_disabled_run(run):
    """``run`` controls whether ``LightningCLI`` calls ``Trainer.fit`` while it is being constructed.

    Args:
        run: the value passed through to ``LightningCLI(run=...)`` (presumably a parametrized bool).
    """
    with mock.patch(
            "sys.argv",
        ["any.py"]), mock.patch("pytorch_lightning.Trainer.fit") as fit_mock:
        cli = LightningCLI(BoringModel, run=run)
    # a bool compares as 0/1, so this checks "fit called once" vs "fit never called".
    # Without ``assert`` the comparison result was discarded and the test checked nothing.
    assert fit_mock.call_count == run
    assert isinstance(cli.trainer, Trainer)
    assert isinstance(cli.model, LightningModule)
Esempio n. 24
0
def test_lightning_cli_config_before_subcommand_two_configs():
    """Two namespaced configs before the subcommand: each subcommand picks up only its own section."""
    validate_section = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}}
    test_section = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    config_args = [f"--config={validate_section}", f"--config={test_section}"]

    # "test" subcommand: only the "test" section applies
    test_patch = mock.patch("pytorch_lightning.Trainer.test", autospec=True)
    with mock.patch("sys.argv", ["any.py", *config_args, "test"]), test_patch as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1

    # "validate" subcommand: only the "validate" section applies
    validate_patch = mock.patch("pytorch_lightning.Trainer.validate", autospec=True)
    with mock.patch("sys.argv", ["any.py", *config_args, "validate"]), validate_patch as validate_mock:
        cli = LightningCLI(BoringModel)

    validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
    assert cli.trainer.limit_val_batches == 1
Esempio n. 25
0
def cli_main():
    """Build the DALI example CLI and run ``test``; returns immediately when DALI is unavailable."""
    if _DALI_AVAILABLE:
        cli = LightningCLI(
            LitClassifier,
            MyDataModule,
            seed_everything_default=1234,
            save_config_overwrite=True,
        )
        cli.trainer.test(cli.model, datamodule=cli.datamodule)
Esempio n. 26
0
def test_lightning_cli_reinstantiate_trainer():
    """``instantiate_trainer`` builds a new Trainer from overrides without updating the stored config."""
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
    assert cli.trainer.max_epochs == 1000

    class TestCallback(Callback):
        ...

    # make sure a new trainer can be easily created
    new_trainer = cli.instantiate_trainer(max_epochs=123, callbacks=[TestCallback()])

    # the new config is used: overridden epochs, original callback types plus the extra one
    assert new_trainer.max_epochs == 123
    new_types = {type(cb) for cb in new_trainer.callbacks}
    original_types = {type(cb) for cb in cli.trainer.callbacks}
    assert new_types == original_types | {TestCallback}

    # the existing config is not updated
    assert cli.config_init["trainer"]["max_epochs"] is None
Esempio n. 27
0
def test_lightning_cli_save_config_cases(tmpdir):
    """The config is saved only when ``fast_dev_run`` is off; a second save into the same dir raises."""
    config_path = tmpdir / "config.yaml"
    common_args = [f"--trainer.default_root_dir={tmpdir}", "--trainer.logger=False"]

    # With fast_dev_run!=False config should not be saved
    with mock.patch("sys.argv", ["any.py", *common_args, "--trainer.fast_dev_run=1"]):
        LightningCLI(BoringModel)
    assert not os.path.isfile(config_path)

    # With fast_dev_run==False config should be saved
    with mock.patch("sys.argv", ["any.py", *common_args, "--trainer.max_epochs=1"]):
        LightningCLI(BoringModel)
    assert os.path.isfile(config_path)

    # If run again on same directory exception should be raised since config file already exists
    with mock.patch("sys.argv", ["any.py", *common_args, "--trainer.max_epochs=1"]), pytest.raises(RuntimeError):
        LightningCLI(BoringModel)
Esempio n. 28
0
def test_lightning_cli_subcommands():
    """Every parameter a subcommand excludes must still exist on the matching ``Trainer`` method."""
    exclusions = LightningCLI.subcommands()
    trainer = Trainer()
    for method_name, excluded_params in exclusions.items():
        signature_params = inspect.signature(getattr(trainer, method_name)).parameters
        for param in excluded_params:
            # if this fails, it's because the parameter has been removed from the associated `Trainer` function
            # and the `LightningCLI` subcommand exclusion list needs to be updated
            assert param in signature_params
Esempio n. 29
0
def test_lightning_cli_args_cluster_environments(tmpdir):
    """A cluster environment passed via ``--trainer.plugins`` is the one the trainer ends up using."""
    plugins_arg = json.dumps([{"class_path": "pytorch_lightning.plugins.environments.SLURMEnvironment"}])

    class TestModel(BoringModel):
        def on_fit_start(self):
            environment = self.trainer.accelerator_connector._cluster_environment
            # Ensure SLURMEnvironment is set, instead of default LightningEnvironment
            assert isinstance(environment, SLURMEnvironment)
            # flag read by the outer test to prove this hook actually ran
            self.trainer.ran_asserts = True

    defaults = {"default_root_dir": str(tmpdir), "fast_dev_run": True}
    with mock.patch("sys.argv", ["any.py", f"--trainer.plugins={plugins_arg}"]):
        cli = LightningCLI(TestModel, trainer_defaults=defaults)

    assert cli.trainer.ran_asserts
Esempio n. 30
0
def cli_main():
    """Fit the autoencoder with an ``ImageSampler`` callback, then test and predict from ``ckpt_path="best"``."""
    trainer_defaults = {"callbacks": ImageSampler(), "max_epochs": 10}
    cli = LightningCLI(
        LitAutoEncoder,
        MyDataModule,
        seed_everything_default=1234,
        save_config_overwrite=True,
        run=False,  # used to de-activate automatic fitting.
        trainer_defaults=trainer_defaults,
    )
    trainer = cli.trainer
    trainer.fit(cli.model, datamodule=cli.datamodule)
    trainer.test(ckpt_path="best")
    first_prediction = trainer.predict(ckpt_path="best")[0]
    print(first_prediction)