예제 #1
0
def test_remote():
    """Check that ``json()`` output equals the expected canonical values.

    Three objects are checked in turn: a config built from ``minimal_json``,
    the module-level ``minimal_class`` object, and a config built from
    ``detailed_json``.
    """
    from_minimal = ExperimentConfig(**minimal_json)
    minimal_dump = from_minimal.json()
    assert minimal_dump == minimal_canon

    # The pre-constructed object must serialize to the same canonical form.
    class_dump = minimal_class.json()
    assert class_dump == minimal_canon

    from_detailed = ExperimentConfig(**detailed_json)
    detailed_dump = from_detailed.json()
    assert detailed_dump == detailed_canon
예제 #2
0
파일: launcher.py 프로젝트: wang7393/nni
def _validate_v2(config, path):
    """Build an ``ExperimentConfig`` from ``config`` and return its ``json()``.

    The parent directory of ``path`` is passed to ``ExperimentConfig`` as
    ``_base_path``. If construction or serialization raises, the error is
    reported via ``print_error`` and the function falls through, yielding
    ``None`` to the caller.
    """
    config_dir = Path(path).parent
    try:
        return ExperimentConfig(_base_path=config_dir, **config).json()
    except Exception as err:
        print_error(f'Config V2 validation failed: {repr(err)}')
    return None
예제 #3
0
파일: launcher.py 프로젝트: yimikai/nni
def create_experiment(args):
    '''Start a new experiment from the config file named by ``args.config``.

    Steps:
      1. Generate a random 8-character experiment id.
      2. Load the config file; exit with status 1 if the path does not exist.
      3. Try to validate the content with the V2 schema (``ExperimentConfig``).
         If that fails, validate it as V1 (``validate_all_content``) and
         convert it to V2; exit with status 1 if the V1 validation fails too.
      4. Launch the experiment. When the training service platform is one of
         ``k8s_training_services`` the raw config is launched with schema
         version 1, otherwise the V2 config is launched with schema version 2.
      5. If launching raises, kill the REST server recorded for this
         experiment (if any), report the error and exit with status 1.
    '''
    # ``sys.exit`` instead of the site-provided ``exit`` helper: the latter is
    # meant for the interactive shell and is not guaranteed to exist
    # (e.g. ``python -S`` or frozen builds).
    import sys

    experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        print_error('Please set correct config path!')
        sys.exit(1)
    config_yml = get_yml_content(config_path)

    try:
        config = ExperimentConfig(_base_path=Path(config_path).parent, **config_yml)
        config_v2 = config.json()
    except Exception as error_v2:
        print_warning('Validation with V2 schema failed. Trying to convert from V1 format...')
        try:
            validate_all_content(config_yml, config_path)
        except Exception as error_v1:
            # Report both failures so the user can fix whichever format was intended.
            print_error(f'Convert from v1 format failed: {repr(error_v1)}')
            print_error(f'Config in v2 format validation failed: {repr(error_v2)}')
            sys.exit(1)
        from nni.experiment.config import convert
        config_v2 = convert.to_v2(config_yml).json()

    try:
        # ``config_v2`` is the output of ``json()``, so its 'trainingService'
        # entry is a plain mapping: the platform has to be read as a key.
        # ``getattr`` on a dict never finds it, which made the k8s branch
        # unreachable.  The attribute lookup is kept as a fallback for
        # non-dict values (for which it yields None when absent).
        training_service = config_v2['trainingService']
        if isinstance(training_service, dict):
            platform = training_service.get('platform')
        else:
            platform = getattr(training_service, 'platform', None)

        if platform in k8s_training_services:
            launch_experiment(args, config_yml, 'new', experiment_id, 1)
        else:
            launch_experiment(args, config_v2, 'new', experiment_id, 2)
    except Exception as exception:
        # Clean up a REST server that may already have been started for this id.
        restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
        if restServerPid:
            kill_command(restServerPid)
        print_error(exception)
        sys.exit(1)
예제 #4
0
def test_all():
    """Check ``json()`` output for minimal and detailed configurations.

    Covers a config built from ``minimal_json``, the module-level
    ``minimal_class`` object, and a config loaded from
    ``assets/config.yaml``.
    """
    minimal_dump = ExperimentConfig(**minimal_json).json()
    assert minimal_dump == minimal_canon

    class_dump = minimal_class.json()
    assert class_dump == minimal_canon_2

    # The detailed config comes from a YAML file rather than an in-memory dict.
    yaml_path = expand_path('assets/config.yaml')
    detailed_dump = ExperimentConfig.load(yaml_path).json()
    assert detailed_dump == detailed_canon