def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu):
    """Parse GPU-related CLI arguments and check the resulting Trainer's device ids.

    ``cli_args`` is a single space-separated string (or a falsy value for
    "no arguments"); ``expected_gpu`` is the value that
    ``trainer.data_parallel_device_ids`` must equal afterwards.
    """
    # Make ``torch.cuda.device_count`` report two devices for this test.
    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)

    # Program name first, then the individual CLI tokens (if any).
    fake_argv = ["any.py"]
    if cli_args:
        fake_argv += cli_args.split(' ')

    trainer_parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    trainer_parser.add_lightning_class_args(Trainer, None)

    # ``parse_args()`` is called without arguments, so it reads the patched ``sys.argv``.
    with mock.patch("sys.argv", fake_argv):
        namespace = trainer_parser.parse_args()

    assert Trainer.from_argparse_args(namespace).data_parallel_device_ids == expected_gpu
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
    """Parse CLI arguments with complex types and compare against ``expected``.

    ``cli_args`` is a list of CLI tokens, ``expected`` maps attribute names of
    the parsed namespace to their required values, and ``instantiate`` decides
    whether a ``Trainer`` is additionally built from the parsed arguments.
    """
    trainer_parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    trainer_parser.add_lightning_class_args(Trainer, None)

    fake_argv = ["any.py"] + cli_args
    # ``parse_args()`` is called without arguments, so it reads the patched ``sys.argv``.
    with mock.patch("sys.argv", fake_argv):
        parsed = trainer_parser.parse_args()

    for attr_name, wanted in expected.items():
        assert getattr(parsed, attr_name) == wanted

    if not instantiate:
        return
    assert Trainer.from_argparse_args(parsed)
def test_default_args(mock_argparse, tmpdir):
    """Build a Trainer from a default argument namespace with ``max_epochs`` overridden.

    NOTE(review): ``mock_argparse`` is presumably injected by a ``mock.patch``
    decorator outside this view; here it is only given a return value built
    from ``Trainer.default_attributes()``.
    """
    default_namespace = Namespace(**Trainer.default_attributes())
    mock_argparse.return_value = default_namespace

    # Parse an empty argument list, then override a single value by hand.
    args = LightningArgumentParser(add_help=False, parse_as_dict=False).parse_args([])
    args.max_epochs = 5

    trainer = Trainer.from_argparse_args(args)
    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5
def test_parse_args_parsing(cli_args, expected):
    """Parse simple-typed CLI arguments and compare against ``expected``.

    ``cli_args`` is a single space-separated string (or a falsy value for
    "no arguments"); ``expected`` maps attribute names of the parsed namespace
    to their required values. A ``Trainer`` is also instantiated unless the
    case involves ``tpu_cores`` while no TPU is available.
    """
    # Program name first, then the individual CLI tokens (if any).
    fake_argv = ["any.py"]
    if cli_args:
        fake_argv += cli_args.split(' ')

    trainer_parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    trainer_parser.add_lightning_class_args(Trainer, None)

    # ``parse_args()`` is called without arguments, so it reads the patched ``sys.argv``.
    with mock.patch("sys.argv", fake_argv):
        parsed = trainer_parser.parse_args()

    for attr_name, wanted in expected.items():
        assert getattr(parsed, attr_name) == wanted

    # Skip the instantiation check for TPU cases when no TPU is available.
    if 'tpu_cores' in expected and not _TPU_AVAILABLE:
        return
    assert Trainer.from_argparse_args(parsed)
def test_add_argparse_args_redefined(cli_args):
    """Override Trainer defaults via the CLI and check the Trainer still initializes.

    Also verifies that both the parsed arguments and the resulting ``Trainer``
    can be pickled, and that a few deprecated argument names are absent from
    the parsed namespace.
    """
    removed_names = ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs")

    trainer_parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    trainer_parser.add_lightning_class_args(Trainer, None)
    args = trainer_parser.parse_args(cli_args)

    # The parsed arguments must be picklable.
    pickle.dumps(args)

    # None of the deprecated names may appear in the namespace.
    for name in removed_names:
        assert name not in args

    trainer = Trainer.from_argparse_args(args=args)
    # The trainer built from them must be picklable as well.
    pickle.dumps(trainer)
    assert isinstance(trainer, Trainer)
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
    """Check that parsing unsupported CLI arguments ends in the parser's ``exit``.

    The parser's ``exit`` method is swapped for a function that raises a
    private exception, so the failure can be observed with ``pytest.raises``.
    """

    class _UnkArgError(Exception):
        """Marker exception raised in place of the parser exiting."""

    def _exit_replacement(*args):
        # Accepts positional arguments only and ignores them.
        raise _UnkArgError

    trainer_parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    trainer_parser.add_lightning_class_args(Trainer, None)

    # ``raising=True``: fail if the parser has no ``exit`` attribute to replace.
    monkeypatch.setattr(trainer_parser, "exit", _exit_replacement, raising=True)

    with pytest.raises(_UnkArgError):
        trainer_parser.parse_args(cli_args)
# Training script setup: parse command-line options, create the TensorBoard
# logger, and build the data module and model from the parsed arguments.
import pytorch_lightning as pl
from pytorch_lightning.utilities.cli import LightningArgumentParser
from pytorch_lightning.loggers import TensorBoardLogger
from argparse import Namespace  # NOTE(review): not used in the visible part of this script
import os  # NOTE(review): not used in the visible part of this script
from src.data import ContentStyleDataModule
from src.model import ASTModel

# Parse arguments
parser = LightningArgumentParser()
# Trainer arguments are registered with ``None`` as the key; model and data
# module arguments are registered under the "model" and "data" keys, which
# are read back as ``args["model"]`` / ``args["data"]`` below.
parser.add_lightning_class_args(pl.Trainer, None)
parser.add_lightning_class_args(ASTModel, "model")
parser.add_lightning_class_args(ContentStyleDataModule, "data")
# Script-level options that do not belong to any Lightning class.
parser.add_argument("--output_path", type=str, help="Directory to save outputs.", default="output/")
parser.add_argument("--experiment_name", type=str, help="Name of experiment.", default="default")
# The result is indexed like a mapping below.
# NOTE(review): this relies on the parser returning a dict-like object by default — confirm for the installed version.
args = parser.parse_args()

# Define loggers
# Logs are saved under ``output_path``, named after the experiment.
tb_logger = TensorBoardLogger(save_dir=args["output_path"], name=args["experiment_name"])

# Setup model
# The "data" and "model" argument groups are unpacked as constructor keyword arguments.
dm = ContentStyleDataModule(**args["data"])
model = ASTModel(**args["model"])