def test_default_args(mock_argparse, tmpdir):
    """Check that ``Trainer.from_argparse_args`` honours an overridden default.

    ``mock_argparse`` is configured to return a Namespace holding every
    ``Trainer.default_attributes()`` entry; presumably it is the mock created
    by a ``mock.patch`` on ``parse_args`` applied outside this view — confirm.
    ``tmpdir`` is accepted but not used here.
    """
    trainer_defaults = Trainer.default_attributes()
    mock_argparse.return_value = Namespace(**trainer_defaults)

    cli_parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    namespace = cli_parser.parse_args([])

    # Override a single default before building the Trainer from the namespace.
    namespace.max_epochs = 5
    built_trainer = Trainer.from_argparse_args(namespace)

    assert isinstance(built_trainer, Trainer)
    assert built_trainer.max_epochs == 5
def test_default_args(mock_argparse, tmpdir):
    """Build a Trainer from parsed CLI defaults plus a logger and one override.

    ``mock_argparse`` has its return value set to a Namespace of the
    ``Trainer.default_attributes()``; presumably it is a mock standing in for
    ``ArgumentParser.parse_args`` (patch applied outside this view — confirm).
    """
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())

    # logger file to get meta
    default_logger = tutils.get_default_logger(tmpdir)

    namespace = ArgumentParser(add_help=False).parse_args()
    namespace.logger = default_logger
    namespace.max_epochs = 5

    built_trainer = Trainer.from_argparse_args(namespace)

    assert isinstance(built_trainer, Trainer)
    assert built_trainer.max_epochs == 5
import inspect
import pickle
import sys
from argparse import ArgumentParser, Namespace
from unittest import mock

import pytest
import torch

import tests.base.utils as tutils
from pytorch_lightning import Trainer


def test_default_args(tmpdir):
    """Tests default argument parser for Trainer.

    ``ArgumentParser.parse_args`` is patched to return a Namespace holding
    ``Trainer.default_attributes()``. A logger and a ``max_epochs``
    override are then attached. The test checks that
    ``Trainer.from_argparse_args`` builds a Trainer honouring the override.
    """
    # Patch with a context manager instead of a decorator. Decorator-form
    # ``mock.patch`` without ``new=`` appends the created MagicMock as an extra
    # positional argument, and pytest then no longer injects ``tmpdir``, so
    # ``tmpdir`` silently became the mock. Building the Namespace here also
    # gives each run a fresh object. The old decorator built it once at import
    # time, and the test then mutated it.
    default_namespace = Namespace(**Trainer.default_attributes())
    with mock.patch('argparse.ArgumentParser.parse_args', return_value=default_namespace):
        # logger file to get meta
        logger = tutils.get_default_logger(tmpdir)

        parser = ArgumentParser(add_help=False)
        args = parser.parse_args()
        args.logger = logger

        args.max_epochs = 5
        trainer = Trainer.from_argparse_args(args)

    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5
# Misconfig when neither test_step or test_end is implemented with pytest.raises(MisconfigurationException): model = LocalModelNoStep(hparams) Trainer().test(model) # No exceptions when one or both of test_step or test_end are implemented model = LocalModelNoEnd(hparams) Trainer().test(model) model = LightningTestModel(hparams) Trainer().test(model) @mock.patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(**Trainer.default_attributes())) def test_default_args(tmpdir): """Tests default argument parser for Trainer""" tutils.reset_seed() # logger file to get meta logger = tutils.get_test_tube_logger(tmpdir, False) parser = argparse.ArgumentParser(add_help=False) args = parser.parse_args() args.logger = logger args.max_epochs = 5 trainer = Trainer.from_argparse_args(args) assert isinstance(trainer, Trainer)