def test_munchify_works():
    """Smoke-test that ``run`` gives ``main`` a config with attribute access.

    The config is loaded from ``config2.yml`` next to this test module.
    Reading ``cfg.attr`` would raise if the loaded config were a plain dict.
    """
    def main(cfg, _):
        print(cfg.attr)

    here = os.path.dirname(os.path.realpath(__file__))
    fake_args = argparse.Namespace(config=os.path.join(here, 'config2.yml'), level='info')
    # https://stackoverflow.com/a/37343818
    with mock.patch('argparse.ArgumentParser.parse_args', return_value=fake_args):
        run(main, __file__)
def test_run_config(config_name, level, called, main):
    """Run ``main`` against a config file and check the last log call.

    Args:
        config_name: config file name, resolved relative to this test module.
        level: name of the ``logging.Logger`` method to patch (e.g. ``info``).
        called: the argument that method is expected to receive last.
        main: entry point handed to ``run``.

    The arguments are presumably supplied by parametrization not shown here.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    fake_args = argparse.Namespace(config=os.path.join(here, config_name), level='info')
    # https://stackoverflow.com/a/37343818 (patch parse_args)
    # https://stackoverflow.com/a/31756485 (patch Logger method)
    with mock.patch('argparse.ArgumentParser.parse_args', return_value=fake_args), \
            mock.patch(f'logging.Logger.{level}') as patched_log_method:
        run(main, __file__)
        patched_log_method.assert_called_with(called)
def test_if_name_not_main_then_not_called():
    '''
    Call `run` on a method which throws an exception, after patching
    `parse_args` to return None. This verifies no argument parsing is done
    (the patched `parse_args` must never be called) and that `main` is not
    called (it would raise).
    '''
    def main(cfg, log):
        raise Exception('I should not be called')

    # A random file name, so it should not correspond to any real config.
    # https://stackoverflow.com/a/534847
    file = str(uuid.uuid4())
    with mock.patch('argparse.ArgumentParser.parse_args',
                    return_value=None) as mock_parse_args:
        run(main, file, '__not_main__')
    # Assert the "no argument parsing" claim directly instead of relying on
    # a downstream failure when a None namespace is used.
    mock_parse_args.assert_not_called()
def test_test_set_logger_levels_from_config_file(file, levels):
    """Check that ``run`` applies the per-logger levels declared in a config.

    Args:
        file: config file name, resolved relative to this test module.
        levels: mapping of logger name -> expected value of ``Logger.level``
            after ``run`` has processed the config.

    ``logging.getLogger`` returns process-wide loggers, so their original
    levels are restored afterwards. This keeps the levels set here from
    leaking into other tests.
    """
    def main(cfg, _):
        pass

    dir_path = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(dir_path, file)
    original_levels = {name: logging.getLogger(name).level for name in levels}
    try:
        # https://stackoverflow.com/a/37343818
        with mock.patch('argparse.ArgumentParser.parse_args',
                        return_value=argparse.Namespace(config=config_path, level='info')):
            run(main, __file__)
        for k, v in levels.items():
            assert v == logging.getLogger(k).level
    finally:
        for name, original in original_levels.items():
            logging.getLogger(name).setLevel(original)
from azureml.core.run import Run
import model
from pyconfigurableml.entry import run
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

# TODO: long-term, get this from pytorch-lightning.
from _azureml import AzureMlLogger

# NOTE(review): Run, rank_zero_only and LightningLoggerBase appear unused in
# this script; confirm they are not needed for import side effects before removing.


def main(config, log):
    """Build a Lightning trainer and fit the MNIST model described by ``config``.

    ``config.trainer`` is unpacked into ``pl.Trainer`` and ``config.model``
    into ``model.Mnist``. Two loggers are attached to the trainer: an Azure ML
    logger and a TensorBoard logger writing under ``lightning_logs``.
    ``log`` is supplied by ``run`` and is unused here.
    """
    azure_logger = AzureMlLogger()
    tensorboard_logger = TensorBoardLogger('lightning_logs')
    trainer = pl.Trainer(logger=[azure_logger, tensorboard_logger], **config.trainer)
    mnist = model.Mnist(**config.model)
    trainer.fit(mnist)


run(main, __file__, __name__)
def test_run_type_checking(input) -> None:
    """``run`` must raise ``TypeError`` when given ``input`` as its entry point.

    ``input`` (which shadows the builtin) is presumably supplied by a
    parametrize decorator not visible here, so its name is kept unchanged.
    """
    pytest.raises(TypeError, run, input, __file__)