Esempio n. 1
0
def test_default_args(mock_argparse, tmpdir):
    """Build a Trainer from parsed arguments and check an override sticks.

    ``mock_argparse`` is primed to return a ``Namespace`` holding every
    ``Trainer.default_attributes()`` entry. NOTE(review): it is presumably
    the mock injected by a ``mock.patch`` decorator not visible here.
    ``tmpdir`` is accepted but not used by this test body.
    """
    defaults = Trainer.default_attributes()
    mock_argparse.return_value = Namespace(**defaults)

    lightning_parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parsed = lightning_parser.parse_args([])
    # Override one value so we can verify it is forwarded to the Trainer.
    parsed.max_epochs = 5

    built = Trainer.from_argparse_args(parsed)

    assert isinstance(built, Trainer)
    assert built.max_epochs == 5
def test_default_args(mock_argparse, tmpdir):
    """Build a Trainer from an argparse ``Namespace`` carrying a logger.

    ``mock_argparse`` is primed to return a ``Namespace`` of the Trainer's
    default attributes. NOTE(review): presumably it patches
    ``ArgumentParser.parse_args`` via a decorator not visible here.
    NOTE(review): this definition reuses the name ``test_default_args``
    of the function defined just above, so at module level it replaces
    that one and only this version would be collected.
    """
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())

    # Test logger created from tmpdir (used to get metadata output).
    default_logger = tutils.get_default_logger(tmpdir)

    namespace = ArgumentParser(add_help=False).parse_args()
    namespace.max_epochs = 5
    namespace.logger = default_logger

    trainer = Trainer.from_argparse_args(namespace)

    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5
Esempio n. 3
0
import inspect
import pickle
import sys
from argparse import ArgumentParser, Namespace
from unittest import mock

import pytest
import torch

import tests.base.utils as tutils
from pytorch_lightning import Trainer


@mock.patch('argparse.ArgumentParser.parse_args')
def test_default_args(mock_argparse, tmpdir):
    """Tests default argument parser for Trainer.

    ``mock.patch`` used as a decorator without ``new`` passes the created
    mock as an extra positional argument, and pytest skips that leading
    parameter when resolving fixtures. The mock must therefore be accepted
    explicitly as ``mock_argparse``. Otherwise it would be bound to
    ``tmpdir`` and the logger would receive a MagicMock instead of a
    temporary directory.
    """
    # Build a fresh Namespace per run. A return_value evaluated in the
    # decorator is created once at import time and would be mutated below.
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    parser = ArgumentParser(add_help=False)
    args = parser.parse_args()  # patched: returns the Namespace built above
    args.logger = logger

    args.max_epochs = 5
    trainer = Trainer.from_argparse_args(args)

    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5
Esempio n. 4
0
    # Misconfig when neither test_step or test_end is implemented
    with pytest.raises(MisconfigurationException):
        model = LocalModelNoStep(hparams)
        Trainer().test(model)

    # No exceptions when one or both of test_step or test_end are implemented
    model = LocalModelNoEnd(hparams)
    Trainer().test(model)

    model = LightningTestModel(hparams)
    Trainer().test(model)


@mock.patch('argparse.ArgumentParser.parse_args',
            return_value=argparse.Namespace(**Trainer.default_attributes()))
def test_default_args(tmpdir):
    """Tests default argument parser for Trainer"""
    tutils.reset_seed()

    # logger file to get meta
    logger = tutils.get_test_tube_logger(tmpdir, False)

    parser = argparse.ArgumentParser(add_help=False)
    args = parser.parse_args()
    args.logger = logger

    args.max_epochs = 5
    trainer = Trainer.from_argparse_args(args)

    assert isinstance(trainer, Trainer)