def test_gpu_override(base_parser):
    """Check that ``--ngpus`` given on the command line lands in the parsed config.

    The CLI value is expected to take precedence over whatever the YAML
    config file specifies for ``environment.ngpus``.
    """
    command_line = '--config examples/mnist/mnist_fp.yaml --ngpus 8'
    cli_args = base_parser.parse_args(command_line.split(' '))
    environment = parse_config(cli_args)['environment']
    assert environment['ngpus'] == 8
def test_standard_args(base_parser):
    """Parse a command line that only names a config file and check the result.

    Verifies the experiment name, the environment section, and that the
    optional checkpoint/restore keys are absent and training is not skipped.
    """
    command_line = '--config examples/mnist/mnist_fp.yaml'
    config = parse_config(base_parser.parse_args(command_line.split(' ')))

    # The experiment name must be a non-empty string.
    experiment_name = config['experiment_name']
    assert isinstance(experiment_name, str)
    assert len(experiment_name)

    environment = config['environment']
    assert environment['platform'] == 'local'
    # One GPU is expected when CUDA is available, none otherwise.
    expected_ngpus = 1 if torch.cuda.is_available() else 0
    assert environment['ngpus'] == expected_ngpus

    # Neither checkpoint-related key should appear unless explicitly requested.
    for absent_key in ('init_from_checkpoint', 'restore_experiment'):
        assert absent_key not in config

    assert not config['skip_training']
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#

"""Driver script for running MNIST."""

from quant.common.compute_platform import LocalComputePlatform
from quant.common.experiment import Experiment
from quant.common.parser import get_base_argument_parser, parse_config
from quant.common.tasks import classification_task
from quant.data.data_loaders import MNISTDataLoader
from quant.utils.visualization import get_tensorboard_hooks


def main() -> None:
    """Parse the command line, build the MNIST experiment, and run it locally.

    Arguments are read from ``sys.argv`` via the base argument parser, turned
    into a config with ``parse_config``, and the resulting ``Experiment`` is
    handed to a ``LocalComputePlatform`` to run.
    """
    parser = get_base_argument_parser('Driver script for running MNIST.')
    args = parser.parse_args()
    config = parse_config(args)

    # Falls back to '.' when the config's `log` section has no
    # `root_experiments_dir` entry.
    platform = LocalComputePlatform(config['log'].get('root_experiments_dir', '.'))
    experiment = Experiment(
        classification_task, config, MNISTDataLoader, get_tensorboard_hooks
    )
    platform.run(experiment)


if __name__ == '__main__':
    main()
def test_missing_config(base_parser):
    """Check that ``parse_config`` raises ``ValueError`` for an empty command line."""
    # An empty argument list means no config file was supplied.
    empty_args = base_parser.parse_args([])
    with pytest.raises(ValueError):
        parse_config(empty_args)