def test_lightning_cli_parse_kwargs_with_subcommands(tmpdir):
    fit_config = {"trainer": {"limit_train_batches": 2}}
    fit_config_path = tmpdir / "fit.yaml"
    fit_config_path.write_text(str(fit_config), "utf8")

    validate_config = {"trainer": {"limit_val_batches": 3}}
    validate_config_path = tmpdir / "validate.yaml"
    validate_config_path.write_text(str(validate_config), "utf8")

    parser_kwargs = {
        "fit": {"default_config_files": [str(fit_config_path)]},
        "validate": {"default_config_files": [str(validate_config_path)]},
    }

    with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch(
        "pytorch_lightning.Trainer.fit", autospec=True
    ) as fit_mock:
        cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
    fit_mock.assert_called()
    assert cli.trainer.limit_train_batches == 2
    assert cli.trainer.limit_val_batches == 1.0

    with mock.patch("sys.argv", ["any.py", "validate"]), mock.patch(
        "pytorch_lightning.Trainer.validate", autospec=True
    ) as validate_mock:
        cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
    validate_mock.assert_called()
    assert cli.trainer.limit_train_batches == 1.0
    assert cli.trainer.limit_val_batches == 3
def test_cli_config_overwrite(tmpdir): trainer_defaults = {"default_root_dir": str(tmpdir), "logger": False, "max_steps": 1, "max_epochs": 1} with mock.patch("sys.argv", ["any.py"]): LightningCLI(BoringModel, trainer_defaults=trainer_defaults) with mock.patch("sys.argv", ["any.py"]), pytest.raises(RuntimeError, match="Aborting to avoid overwriting"): LightningCLI(BoringModel, trainer_defaults=trainer_defaults) with mock.patch("sys.argv", ["any.py"]): LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=trainer_defaults)
def test_lightning_cli_config_before_and_after_subcommand():
    config1 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    config2 = {"trainer": {"fast_dev_run": 1}, "verbose": False, "ckpt_path": "foobar"}

    with mock.patch(
        "sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]
    ), mock.patch("pytorch_lightning.Trainer.test", autospec=True) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1
    assert cli.trainer.fast_dev_run == 1
def test_lightning_cli_submodules(tmpdir):

    class MainModule(BoringModel):

        def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
            super().__init__()
            self.submodule1 = submodule1
            self.submodule2 = submodule2

    config = """model:
        main_param: 2
        submodule1:
            class_path: tests.helpers.BoringModel
        submodule2:
            class_path: tests.helpers.BoringModel
    """
    config_path = tmpdir / "config.yaml"
    with open(config_path, "w") as f:
        f.write(config)

    cli_args = [
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.max_epochs=1",
        f"--config={str(config_path)}",
    ]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(MainModule)

    assert cli.config["model"]["main_param"] == 2
    assert isinstance(cli.model.submodule1, BoringModel)
    assert isinstance(cli.model.submodule2, BoringModel)
def test_lightning_cli_args(tmpdir):
    cli_args = [
        f'--data.data_dir={tmpdir}',
        f'--trainer.default_root_dir={tmpdir}',
        '--trainer.max_epochs=1',
        '--trainer.weights_summary=null',
        '--seed_everything=1234',
    ]

    with mock.patch('sys.argv', ['any.py'] + cli_args):
        cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={'callbacks': [LearningRateMonitor()]})

    assert cli.fit_result == 1
    assert cli.config['seed_everything'] == 1234
    config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml'
    assert os.path.isfile(config_path)
    with open(config_path) as f:
        config = yaml.safe_load(f.read())
    assert 'model' not in config and 'model' not in cli.config  # no arguments to include
    assert config['data'] == cli.config['data']
    assert config['trainer'] == cli.config['trainer']
def test_lightning_cli_config_and_subclass_mode(tmpdir): config = dict( model=dict(class_path="tests.helpers.BoringModel"), data=dict(class_path="tests.helpers.BoringDataModule", init_args=dict(data_dir=str(tmpdir))), trainer=dict(default_root_dir=str(tmpdir), max_epochs=1, weights_summary=None), ) config_path = tmpdir / "config.yaml" with open(config_path, "w") as f: f.write(yaml.dump(config)) with mock.patch("sys.argv", ["any.py", "--config", str(config_path)]): cli = LightningCLI( BoringModel, BoringDataModule, subclass_mode_model=True, subclass_mode_data=True, trainer_defaults={"callbacks": LearningRateMonitor()}, ) config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml" assert os.path.isfile(config_path) with open(config_path) as f: config = yaml.safe_load(f.read()) assert config["model"] == cli.config["model"] assert config["data"] == cli.config["data"] assert config["trainer"] == cli.config["trainer"]
def test_lightning_cli_args_callbacks(tmpdir):

    callbacks = [
        dict(
            class_path='pytorch_lightning.callbacks.LearningRateMonitor',
            init_args=dict(logging_interval='epoch', log_momentum=True),
        ),
        dict(class_path='pytorch_lightning.callbacks.ModelCheckpoint', init_args=dict(monitor='NAME')),
    ]

    class TestModel(BoringModel):

        def on_fit_start(self):
            callback = [c for c in self.trainer.callbacks if isinstance(c, LearningRateMonitor)]
            assert len(callback) == 1
            assert callback[0].logging_interval == 'epoch'
            assert callback[0].log_momentum is True
            callback = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
            assert len(callback) == 1
            assert callback[0].monitor == 'NAME'
            self.trainer.ran_asserts = True

    with mock.patch('sys.argv', ['any.py', f'--trainer.callbacks={json.dumps(callbacks)}']):
        cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))

    assert cli.trainer.ran_asserts
def test_lightning_cli_torch_modules(tmpdir):

    class TestModule(BoringModel):

        def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None):
            super().__init__()
            self.activation = activation
            self.transform = transform

    config = """model:
        activation:
            class_path: torch.nn.LeakyReLU
            init_args:
                negative_slope: 0.2
        transform:
            - class_path: torchvision.transforms.Resize
              init_args:
                size: 64
            - class_path: torchvision.transforms.CenterCrop
              init_args:
                size: 64
    """
    config_path = tmpdir / "config.yaml"
    with open(config_path, "w") as f:
        f.write(config)

    cli_args = [f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", f"--config={str(config_path)}"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(TestModule)

    assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
    assert cli.model.activation.negative_slope == 0.2
    assert len(cli.model.transform) == 2
    assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)
def any_model_any_data_cli():
    LightningCLI(
        LightningModule,
        LightningDataModule,
        subclass_mode_model=True,
        subclass_mode_data=True,
    )
def test_lightning_cli(trainer_class, model_class, monkeypatch):
    """Test that LightningCLI correctly instantiates model, trainer and calls fit."""
    expected_model = dict(model_param=7)
    expected_trainer = dict(limit_train_batches=100)

    def fit(trainer, model):
        for k, v in expected_model.items():
            assert getattr(model, k) == v
        for k, v in expected_trainer.items():
            assert getattr(trainer, k) == v
        save_callback = [x for x in trainer.callbacks if isinstance(x, SaveConfigCallback)]
        assert len(save_callback) == 1
        save_callback[0].on_train_start(trainer, model)

    def on_train_start(callback, trainer, _):
        config_dump = callback.parser.dump(callback.config, skip_none=False)
        for k, v in expected_model.items():
            assert f" {k}: {v}" in config_dump
        for k, v in expected_trainer.items():
            assert f" {k}: {v}" in config_dump
        trainer.ran_asserts = True

    monkeypatch.setattr(Trainer, "fit", fit)
    monkeypatch.setattr(SaveConfigCallback, "on_train_start", on_train_start)

    with mock.patch("sys.argv", ["any.py", "--model.model_param=7", "--trainer.limit_train_batches=100"]):
        cli = LightningCLI(model_class, trainer_class=trainer_class, save_config_callback=SaveConfigCallback)

    assert hasattr(cli.trainer, "ran_asserts") and cli.trainer.ran_asserts
def cli_main():
    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
    cli.trainer.test(cli.model, datamodule=cli.datamodule)
    predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
    print(predictions[0])
def cli_main():
    if len(sys.argv) == 1:
        sys.argv += DEFAULT_CMD_LINE

    LightningCLI(
        ModelToProfile,
        CIFAR10DataModule,
        save_config_overwrite=True,
        trainer_defaults={"profiler": PyTorchProfiler()},
    )
def test_lightning_cli_args(tmpdir): cli_args = [ "fit", f"--data.data_dir={tmpdir}", f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", "--trainer.weights_summary=null", "--seed_everything=1234", ] with mock.patch("sys.argv", ["any.py"] + cli_args): cli = LightningCLI( BoringModel, BoringDataModule, trainer_defaults={"callbacks": [LearningRateMonitor()]}) config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml" assert os.path.isfile(config_path) with open(config_path) as f: loaded_config = yaml.safe_load(f.read()) loaded_config = loaded_config["fit"] cli_config = cli.config["fit"] assert cli_config["seed_everything"] == 1234 assert "model" not in loaded_config and "model" not in cli_config # no arguments to include assert loaded_config["data"] == cli_config["data"] assert loaded_config["trainer"] == cli_config["trainer"]
def test_lightning_cli(cli_args, expected_model, expected_trainer, monkeypatch):
    """Test that LightningCLI correctly instantiates model, trainer and calls fit."""

    def fit(trainer, model):
        for k, v in model.expected_model.items():
            assert getattr(model, k) == v
        for k, v in model.expected_trainer.items():
            assert getattr(trainer, k) == v
        save_callback = [x for x in trainer.callbacks if isinstance(x, SaveConfigCallback)]
        assert len(save_callback) == 1
        save_callback[0].on_train_start(trainer, model)

    def on_train_start(callback, trainer, model):
        config_dump = callback.parser.dump(callback.config, skip_none=False)
        for k, v in model.expected_model.items():
            assert f' {k}: {v}' in config_dump
        for k, v in model.expected_trainer.items():
            assert f' {k}: {v}' in config_dump
        trainer.ran_asserts = True

    monkeypatch.setattr(Trainer, 'fit', fit)
    monkeypatch.setattr(SaveConfigCallback, 'on_train_start', on_train_start)

    class TestModel(LightningModule):

        def __init__(self, model_param: int):
            super().__init__()
            self.model_param = model_param

    TestModel.expected_model = expected_model
    TestModel.expected_trainer = expected_trainer

    with mock.patch('sys.argv', ['any.py'] + cli_args):
        cli = LightningCLI(TestModel, trainer_class=Trainer, save_config_callback=SaveConfigCallback)

    assert hasattr(cli.trainer, 'ran_asserts') and cli.trainer.ran_asserts
def test_lightning_cli_run():
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
    assert cli.trainer.global_step == 0
    assert isinstance(cli.trainer, Trainer)
    assert isinstance(cli.model, LightningModule)

    with mock.patch("sys.argv", ["any.py", "fit"]):
        cli = LightningCLI(BoringModel, trainer_defaults={"max_steps": 1, "max_epochs": 1})
    assert cli.trainer.global_step == 1
    assert isinstance(cli.trainer, Trainer)
    assert isinstance(cli.model, LightningModule)
def cli_main():
    if not _DALI_AVAILABLE:
        return

    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
    result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
    print(result)
def cli_main():
    cli = LightningCLI(LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True)
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
def cli_main():
    cli = LightningCLI(
        LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
    )
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    cli.trainer.test(ckpt_path="best")
    predictions = cli.trainer.predict(ckpt_path="best")
    print(predictions[0])
def cli_main():
    # The LightningCLI removes all the boilerplate associated with arguments parsing. This is purely optional.
    cli = LightningCLI(ImageClassifier, seed_everything_default=42, save_config_overwrite=True, run=False)
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
def test_lightning_cli_disabled_run(run):
    with mock.patch("sys.argv", ["any.py"]), mock.patch("pytorch_lightning.Trainer.fit") as fit_mock:
        cli = LightningCLI(BoringModel, run=run)
    # without the `assert`, the comparison below was a no-op and never checked anything
    assert fit_mock.call_count == run
    assert isinstance(cli.trainer, Trainer)
    assert isinstance(cli.model, LightningModule)
def test_lightning_cli_config_before_subcommand_two_configs():
    config1 = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}}
    config2 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}

    with mock.patch(
        "sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]
    ), mock.patch("pytorch_lightning.Trainer.test", autospec=True) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1

    with mock.patch(
        "sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]
    ), mock.patch("pytorch_lightning.Trainer.validate", autospec=True) as validate_mock:
        cli = LightningCLI(BoringModel)

    validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
    assert cli.trainer.limit_val_batches == 1
def cli_main():
    if not _DALI_AVAILABLE:
        return

    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
    cli.trainer.test(cli.model, datamodule=cli.datamodule)
def test_lightning_cli_reinstantiate_trainer():
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
    assert cli.trainer.max_epochs == 1000

    class TestCallback(Callback):
        ...

    # make sure a new trainer can be easily created
    trainer = cli.instantiate_trainer(max_epochs=123, callbacks=[TestCallback()])
    # the new config is used
    assert trainer.max_epochs == 123
    assert {c.__class__ for c in trainer.callbacks} == {c.__class__ for c in cli.trainer.callbacks}.union(
        {TestCallback}
    )
    # the existing config is not updated
    assert cli.config_init["trainer"]["max_epochs"] is None
def test_lightning_cli_save_config_cases(tmpdir):
    config_path = tmpdir / "config.yaml"
    cli_args = [f"--trainer.default_root_dir={tmpdir}", "--trainer.logger=False", "--trainer.fast_dev_run=1"]

    # With fast_dev_run!=False config should not be saved
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        LightningCLI(BoringModel)
    assert not os.path.isfile(config_path)

    # With fast_dev_run==False config should be saved
    cli_args[-1] = "--trainer.max_epochs=1"
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        LightningCLI(BoringModel)
    assert os.path.isfile(config_path)

    # If run again on same directory exception should be raised since config file already exists
    with mock.patch("sys.argv", ["any.py"] + cli_args), pytest.raises(RuntimeError):
        LightningCLI(BoringModel)
def test_lightning_cli_subcommands():
    subcommands = LightningCLI.subcommands()
    trainer = Trainer()
    for subcommand, exclude in subcommands.items():
        fn = getattr(trainer, subcommand)
        parameters = list(inspect.signature(fn).parameters)
        for e in exclude:
            # if this fails, it's because the parameter has been removed from the associated `Trainer` function
            # and the `LightningCLI` subcommand exclusion list needs to be updated
            assert e in parameters
def test_lightning_cli_args_cluster_environments(tmpdir):
    plugins = [dict(class_path="pytorch_lightning.plugins.environments.SLURMEnvironment")]

    class TestModel(BoringModel):

        def on_fit_start(self):
            # Ensure SLURMEnvironment is set, instead of default LightningEnvironment
            assert isinstance(self.trainer.accelerator_connector._cluster_environment, SLURMEnvironment)
            self.trainer.ran_asserts = True

    with mock.patch("sys.argv", ["any.py", f"--trainer.plugins={json.dumps(plugins)}"]):
        cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))

    assert cli.trainer.ran_asserts
def cli_main():
    cli = LightningCLI(
        LitAutoEncoder,
        MyDataModule,
        seed_everything_default=1234,
        save_config_overwrite=True,
        run=False,  # used to de-activate automatic fitting.
        trainer_defaults={"callbacks": ImageSampler(), "max_epochs": 10},
    )
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    cli.trainer.test(ckpt_path="best")
    predictions = cli.trainer.predict(ckpt_path="best")
    print(predictions[0])