def test_munchify_works():
    """Run a `main` that reads ``cfg.attr`` through attribute access.

    ``parse_args`` is patched so that `run` loads ``config2.yml`` (located
    next to this test file) with level ``'info'``. The test passes as long as
    ``cfg.attr`` can be read and printed without raising.
    """
    def main(cfg, _):
        print(cfg.attr)

    here = os.path.dirname(os.path.realpath(__file__))
    fake_args = argparse.Namespace(config=os.path.join(here, 'config2.yml'),
                                   level='info')

    # Patching parse_args: https://stackoverflow.com/a/37343818
    patched_parse = mock.patch('argparse.ArgumentParser.parse_args',
                               return_value=fake_args)
    with patched_parse:
        run(main, __file__)
def test_run_config(config_name, level, called, main):
    """Check that running `main` with `config_name` logs `called` at `level`.

    ``parse_args`` is patched to point at `config_name`, resolved relative to
    this test file. The ``logging.Logger`` method named by `level` is mocked
    so that its most recent call can be asserted to be ``called``.
    """
    config_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), config_name)
    cli_args = argparse.Namespace(config=config_path, level='info')

    # parse_args patch: https://stackoverflow.com/a/37343818
    # Logger method patch: https://stackoverflow.com/a/31756485
    with mock.patch('argparse.ArgumentParser.parse_args',
                    return_value=cli_args), \
            mock.patch(f'logging.Logger.{level}') as logger_method:
        run(main, __file__)
        logger_method.assert_called_with(called)
def test_if_name_not_main_then_not_called():
    """Calling `run` with a module name other than ``'__main__'`` is a no-op.

    ``parse_args`` is patched to return ``None`` and `main` raises if it is
    invoked. This is intended to verify that no argument parsing is done and
    that `main` is never called.
    """
    def main(cfg, log):
        raise Exception('I should not be called')

    # Random file name via uuid: https://stackoverflow.com/a/534847
    random_file = str(uuid.uuid4())
    parse_args_patch = mock.patch('argparse.ArgumentParser.parse_args',
                                  return_value=None)
    with parse_args_patch:
        run(main, random_file, '__not_main__')
def test_test_set_logger_levels_from_config_file(file, levels):
    """After `run` loads `file`, each named logger has its expected level.

    `levels` maps logger names to expected ``Logger.level`` values. `main` is
    a no-op, so the only thing checked is the logger levels left behind by
    `run`.
    """
    def main(cfg, _):
        pass

    config_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), file)
    fake_args = argparse.Namespace(config=config_path, level='info')

    # Patching parse_args: https://stackoverflow.com/a/37343818
    with mock.patch('argparse.ArgumentParser.parse_args',
                    return_value=fake_args):
        run(main, __file__)

    for logger_name, expected_level in levels.items():
        actual_level = logging.getLogger(logger_name).level
        assert expected_level == actual_level
# Example #5
from azureml.core.run import Run
import model
from pyconfigurableml.entry import run
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger


# TODO: long-term, get this from pytorch-lightning.
from _azureml import AzureMlLogger


def main(config, log):
    """Train a ``model.Mnist`` network with a PyTorch Lightning ``Trainer``.

    The trainer logs to both an ``AzureMlLogger`` and a ``TensorBoardLogger``
    writing under ``lightning_logs``.

    Args:
        config: Parsed configuration. ``config.trainer`` is unpacked as keyword
            arguments to ``pl.Trainer``, and ``config.model`` is unpacked as
            keyword arguments to ``model.Mnist``.
        log: Logger supplied by ``run``; not used here.
    """
    azure_logger = AzureMlLogger()
    tensorboard_logger = TensorBoardLogger('lightning_logs')

    trainer = pl.Trainer(logger=[azure_logger, tensorboard_logger],
                         **config.trainer)
    network = model.Mnist(**config.model)
    trainer.fit(network)


# Module entry point: `run` receives this module's `__name__` so it can decide
# whether to parse command-line arguments and invoke `main` (presumably only
# when executed as a script, i.e. `__name__ == '__main__'` — confirm against
# pyconfigurableml.entry.run).
run(main, __file__, __name__)
def test_run_type_checking(input) -> None:
    """`run` must raise ``TypeError`` when passed `input` as its first argument.

    The parameter name `input` shadows the builtin, but it is kept because it
    is the name under which the test's argument is supplied.
    """
    # Function-call form of pytest.raises: invokes run(input, __file__) and
    # fails the test unless TypeError is raised.
    pytest.raises(TypeError, run, input, __file__)