コード例 #1
0
ファイル: test_cli.py プロジェクト: zihua/pytorch-lightning
def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu):
    """Test parsing of gpus and instantiation of Trainer."""
    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
    cli_args = cli_args.split(' ') if cli_args else []
    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        args = parser.parse_args()

    trainer = Trainer.from_argparse_args(args)
    assert trainer.data_parallel_device_ids == expected_gpu
コード例 #2
0
ファイル: test_cli.py プロジェクト: zihua/pytorch-lightning
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
    """Test parsing complex types."""
    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        args = parser.parse_args()

    for k, v in expected.items():
        assert getattr(args, k) == v
    if instantiate:
        assert Trainer.from_argparse_args(args)
コード例 #3
0
ファイル: test_cli.py プロジェクト: zihua/pytorch-lightning
def test_default_args(mock_argparse, tmpdir):
    """Tests default argument parser for Trainer"""
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())

    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    args = parser.parse_args([])

    args.max_epochs = 5
    trainer = Trainer.from_argparse_args(args)

    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5
コード例 #4
0
ファイル: test_cli.py プロジェクト: zihua/pytorch-lightning
def test_parse_args_parsing(cli_args, expected):
    """Test parsing simple types and None optionals not modified."""
    cli_args = cli_args.split(' ') if cli_args else []
    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        args = parser.parse_args()

    for k, v in expected.items():
        assert getattr(args, k) == v
    if 'tpu_cores' not in expected or _TPU_AVAILABLE:
        assert Trainer.from_argparse_args(args)
コード例 #5
0
def test_add_argparse_args_redefined(cli_args):
    """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)

    args = parser.parse_args(cli_args)

    # make sure we can pickle args
    pickle.dumps(args)

    # Check few deprecated args are not in namespace:
    for depr_name in ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs"):
        assert depr_name not in args

    trainer = Trainer.from_argparse_args(args=args)
    pickle.dumps(trainer)

    assert isinstance(trainer, Trainer)
コード例 #6
0
ファイル: test_cli.py プロジェクト: jspaezp/pytorch-lightning
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
    """Asserts error raised in case of passing not default cli arguments."""
    class _UnkArgError(Exception):
        pass

    def _raise():
        raise _UnkArgError

    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)

    monkeypatch.setattr(parser, "exit", lambda *args: _raise(), raising=True)

    with pytest.raises(_UnkArgError):
        parser.parse_args(cli_args)
コード例 #7
0
ファイル: train.py プロジェクト: bwconrad/AdaIn-StyleTransfer
import pytorch_lightning as pl
from pytorch_lightning.utilities.cli import LightningArgumentParser
from pytorch_lightning.loggers import TensorBoardLogger
from argparse import Namespace
import os

from src.data import ContentStyleDataModule
from src.model import ASTModel

# Parse arguments
parser = LightningArgumentParser()
parser.add_lightning_class_args(pl.Trainer, None)
parser.add_lightning_class_args(ASTModel, "model")
parser.add_lightning_class_args(ContentStyleDataModule, "data")
parser.add_argument("--output_path",
                    type=str,
                    help="Directory to save outputs.",
                    default="output/")
parser.add_argument("--experiment_name",
                    type=str,
                    help="Name of experiment.",
                    default="default")
args = parser.parse_args()

# Define loggers
tb_logger = TensorBoardLogger(save_dir=args["output_path"],
                              name=args["experiment_name"])

# Setup model
dm = ContentStyleDataModule(**args["data"])
model = ASTModel(**args["model"])