コード例 #1
0
def update_training_service_config(config, training_service, config_file_path):
    """Apply the integration-test training-service settings to ``config``.

    The settings are read from ``config/training_service.yml``. ``config`` is
    handed to ``deep_update`` twice: first with the ``all`` section, then with
    the section named by ``training_service`` (presumably merging in place).

    Args:
        config: experiment config mapping; its ``trial`` entry is modified
            for the kubeflow and frameworkcontroller services.
        training_service: name of the section to apply.
        config_file_path: path of the test config file; only used in adl
            mode to derive the container code directory.
    """
    it_ts_config = get_yml_content(os.path.join('config', 'training_service.yml'))

    # hack for kubeflow / frameworkcontroller trial config: the command moves
    # from the experiment's trial section into the service's role definition
    if training_service in ('kubeflow', 'frameworkcontroller'):
        trial = config['trial']
        command = trial['command']
        service_trial = it_ts_config[training_service]['trial']
        if training_service == 'kubeflow':
            role = service_trial['worker']
        else:
            role = service_trial['taskRoles'][0]
        role['command'] = command
        del trial['command']
        trial.pop('gpuNum', None)

    if training_service == 'adl':
        # hack for adl trial config, codeDir in adl mode refers to path in container
        code_dir = config['trial']['codeDir']
        if code_dir == '.':
            # replace metric test folders to container folder
            container_dir = '/' + config_file_path[:config_file_path.rfind('/')]
        elif code_dir == '../naive_trial':
            container_dir = '/test/config/naive_trial'
        elif '../../../' in code_dir:
            # replace example folders to container folder
            container_dir = code_dir.replace('../../../', '/')
        else:
            container_dir = code_dir
        it_ts_config[training_service]['trial']['codeDir'] = container_dir
        it_ts_config[training_service]['trial']['command'] = 'cd {0} && {1}'.format(
            container_dir, config['trial']['command'])

    for section in ('all', training_service):
        deep_update(config, it_ts_config[section])
コード例 #2
0
ファイル: run_tests.py プロジェクト: microsoft/nni
def prepare_config_file(test_case_config, it_config, args):
    """Build the temporary experiment config file for one test case.

    Loads the yml named by ``test_case_config['configFile']`` (relative to
    ``args.nni_source_dir``), applies the test case's own ``config``
    overrides, then the training-service settings, and writes the result next
    to the original file with a ``.tmp`` suffix.

    Args:
        test_case_config: mapping describing the test case.
        it_config: not used by this function.
        args: object providing ``nni_source_dir``, ``ts`` and ``reuse_mode``.

    Returns:
        Path of the generated ``.tmp`` config file.
    """
    config_path = args.nni_source_dir + test_case_config['configFile']
    experiment = get_yml_content(config_path)

    # apply test case specific config
    overrides = test_case_config.get('config')
    if overrides is not None:
        deep_update(experiment, overrides)

    # hack for windows
    if sys.platform == 'win32' and args.ts == 'local':
        trial = experiment['trial']
        trial['command'] = trial['command'].replace('python3', 'python')

    # apply training service config
    # user's gpuNum, logCollection config is overwritten by the config in training_service.yml
    # the hack for kubeflow should be applied at last step
    update_training_service_config(
        experiment,
        args.ts,
        test_case_config['configFile'],
        args.nni_source_dir,
        args.reuse_mode,
    )

    # generate temporary config yml file to launch experiment
    tmp_config_path = config_path + '.tmp'
    dump_yml_content(tmp_config_path, experiment)
    rendered = yaml.safe_dump(experiment, default_flow_style=False)
    print(rendered, flush=True)

    return tmp_config_path
コード例 #3
0
def update_training_service_config(args):
    """Write command-line overrides into the training-service yml file.

    Reads ``TRAINING_SERVICE_FILE``, replaces entries of the ``args.ts``
    section with every option on ``args`` that is not ``None``, and dumps the
    result back to the same file.

    Args:
        args: object exposing ``ts``, ``nni_manager_ip`` and the pai /
            kubeflow specific options read below.
    """
    config = get_yml_content(TRAINING_SERVICE_FILE)

    def override(value, *key_path):
        # Options that were not supplied leave the yml untouched; the nested
        # lookup only happens when there is something to write.
        if value is None:
            return
        node = config[args.ts]
        for key in key_path[:-1]:
            node = node[key]
        node[key_path[-1]] = value

    override(args.nni_manager_ip, 'nniManagerIp')
    if args.ts == 'pai':
        override(args.pai_user, 'paiConfig', 'userName')
        override(args.pai_pwd, 'paiConfig', 'passWord')
        override(args.pai_host, 'paiConfig', 'host')
        override(args.nni_docker_image, 'trial', 'image')
        override(args.data_dir, 'trial', 'dataDir')
        override(args.output_dir, 'trial', 'outputDir')
    elif args.ts == 'kubeflow':
        override(args.nfs_server, 'kubeflowConfig', 'nfs', 'server')
        override(args.nfs_path, 'kubeflowConfig', 'nfs', 'path')
        override(args.keyvault_vaultname, 'kubeflowConfig', 'keyVault', 'vaultName')
        override(args.keyvault_name, 'kubeflowConfig', 'keyVault', 'name')
        override(args.azs_account, 'kubeflowConfig', 'azureStorage', 'accountName')
        override(args.azs_share, 'kubeflowConfig', 'azureStorage', 'azureShare')
        override(args.nni_docker_image, 'trial', 'worker', 'image')

    dump_yml_content(TRAINING_SERVICE_FILE, config)
コード例 #4
0
def gen_new_config(config_file, training_service='local'):
    '''
    Generate a temporary config file for an integration test; the file
    should be deleted after testing.

    The ``training_service`` section of ``training_service.yml`` is merged
    into the content of ``config_file`` with a shallow ``dict.update`` and
    written to ``config_file + '.tmp'``.

    Returns a ``(tmp_path, merged_config)`` tuple.
    '''
    experiment = get_yml_content(config_file)
    tmp_path = config_file + '.tmp'

    service_section = get_yml_content('training_service.yml')[training_service]
    # Trace both inputs before merging.
    for snapshot in (experiment, service_section):
        print(snapshot)
    # Shallow merge: top-level keys of the service section replace the
    # experiment's keys wholesale.
    experiment.update(service_section)
    print(experiment)
    dump_yml_content(tmp_path, experiment)

    return tmp_path, experiment
コード例 #5
0
def run(args):
    """Run every selected test case listed in the integration-test config.

    A case is skipped when it is excluded via ``args.exclude``, not selected
    by a non-empty ``args.cases``, or does not match the current platform or
    ``args.ts``. Keys missing from a case are filled from
    ``defaultTestCaseConfig`` before matching.

    Args:
        args: object providing ``config``, ``exclude``, ``cases`` and ``ts``.
    """
    it_config = get_yml_content(args.config)

    for case in it_config['testCases']:
        name = case['name']
        if case_excluded(name, args.exclude):
            print('{} excluded'.format(name))
            continue
        if args.cases and not case_included(name, args.cases):
            continue

        # fill test case default config
        defaults = it_config['defaultTestCaseConfig']
        for key in defaults:
            if key not in case:
                case[key] = defaults[key]
        print(json.dumps(case, indent=4))

        if not match_platform(case):
            message = 'skipped {}, platform {} not match [{}]'.format(
                name, sys.platform, case['platform'])
            print(message)
            continue

        if not match_training_service(case, args.ts):
            message = 'skipped {}, training service {} not match [{}]'.format(
                name, args.ts, case['trainingService'])
            print(message)
            continue

        wait_for_port_available(8080, 30)
        print('{}Testing: {}{}'.format(GREEN, name, CLEAR))
        started_at = time.time()

        run_test_case(case, it_config, args)
        elapsed = int(time.time() - started_at)
        print('{}Test {}: TEST PASS IN {} SECONDS{}'.format(
            GREEN, name, elapsed, CLEAR), flush=True)
コード例 #6
0
ファイル: run_tests.py プロジェクト: yinfupai/nni
def run(args):
    """Run every selected test case and report progress for each one.

    For each case a banner and a ``##vso[task.setprogress ...]`` line are
    printed before any filtering. A case is then skipped when it is excluded,
    not selected by a non-empty ``args.cases``, or does not match the
    platform, ``args.ts`` or (remote mode only) the remote config.

    Args:
        args: object providing ``config``, ``exclude``, ``cases``, ``ts`` and
            ``nni_source_dir``.
    """
    it_config = get_yml_content(args.config)
    test_cases = it_config['testCases']

    for position, case in enumerate(test_cases, start=1):
        name = case['name']
        print(GREEN + '=' * 80 + CLEAR)
        print('## {}Testing: {}{} ##'.format(GREEN, name, CLEAR))

        # Print progress on devops
        percent = int(position / len(test_cases) * 100)
        print(f'##vso[task.setprogress value={percent};]{name}')

        if case_excluded(name, args.exclude):
            print('{} excluded'.format(name))
            continue
        if args.cases and not case_included(name, args.cases):
            continue

        # fill test case default config
        defaults = it_config['defaultTestCaseConfig']
        for key in defaults:
            if key not in case:
                case[key] = defaults[key]
        print(json.dumps(case, indent=4))

        if not match_platform(case):
            print('skipped {}, platform {} not match [{}]'.format(
                name, sys.platform, case['platform']))
            continue

        if not match_training_service(case, args.ts):
            print('skipped {}, training service {} not match [{}]'.format(
                name, args.ts, case['trainingService']))
            continue

        if args.ts == 'remote' and not match_remoteConfig(case, args.nni_source_dir):
            print('skipped {}, remoteConfig not match.'.format(name))
            continue

        # remote mode need more time to cleanup
        if args.ts in ('remote', 'hybrid'):
            # some training services need one more port (8081) to listen metrics
            for port in (8080, 8081):
                wait_for_port_available(port, 240)

        # adl mode need more time to cleanup PVC
        if args.ts == 'adl' and name == 'nnictl-resume-2':
            time.sleep(30)

        started_at = time.time()

        run_test_case(case, it_config, args)
        elapsed = int(time.time() - started_at)
        print('{}Test {}: TEST PASS IN {} SECONDS{}\n\n'.format(
            GREEN, name, elapsed, CLEAR), flush=True)
コード例 #7
0
ファイル: run_tests.py プロジェクト: microsoft/nni
def get_max_values(config_file):
    """Return ``(max duration, max trial count)`` from an experiment config.

    If the config has a truthy ``maxExecDuration`` entry, the
    ``maxExecDuration`` / ``maxTrialNum`` keys are used; otherwise
    ``maxExperimentDuration`` / ``maxTrialNumber``. The duration goes through
    ``parse_max_duration_time``; the trial count is returned as stored.
    """
    experiment_config = get_yml_content(config_file)
    if experiment_config.get('maxExecDuration'):
        duration_key, trial_key = 'maxExecDuration', 'maxTrialNum'
    else:
        duration_key, trial_key = 'maxExperimentDuration', 'maxTrialNumber'
    max_duration = parse_max_duration_time(experiment_config[duration_key])
    return max_duration, experiment_config[trial_key]
コード例 #8
0
def convert_command():
    '''Rewrite trial commands in yml configs for Windows.

    Does nothing unless ``sys.platform`` is ``win32``. Otherwise every yml
    file matched by ``./**/*.yml`` or ``./**/**/*.yml`` that has a non-empty
    ``trial.command`` gets ``python3`` replaced by ``python`` and is dumped
    back in place. glob is called without ``recursive=True``.
    '''
    if sys.platform != 'win32':
        return None
    candidates = glob.glob('./**/*.yml') + glob.glob('./**/**/*.yml')
    for config_file in candidates:
        print('processing {}'.format(config_file))
        yml_content = get_yml_content(config_file)
        trial = yml_content.get('trial')
        if not trial:
            continue
        command = trial.get('command')
        if not command:
            continue
        trial['command'] = command.replace('python3', 'python')
        dump_yml_content(config_file, yml_content)
コード例 #9
0
def gen_new_config(config_file, training_service='local'):
    '''
    Generate a temporary config file for an integration test; the file
    should be deleted after testing.

    The ``training_service`` section of ``training_service.yml`` is applied
    to the content of ``config_file`` through ``deep_update`` and the result
    is written to ``config_file + '.tmp'``.

    Returns a ``(tmp_path, merged_config)`` tuple.
    '''
    experiment = get_yml_content(config_file)
    tmp_path = config_file + '.tmp'

    service_section = get_yml_content('training_service.yml')[training_service]
    print(service_section)

    # hack for kubeflow trial config: the command belongs to the worker role,
    # so it is moved out of the experiment's trial section
    if training_service == 'kubeflow':
        trial = experiment['trial']
        service_section['trial']['worker']['command'] = trial['command']
        del trial['command']
        trial.pop('gpuNum', None)

    deep_update(experiment, service_section)
    print(experiment)
    dump_yml_content(tmp_path, experiment)

    return tmp_path, experiment
コード例 #10
0
ファイル: run_tests.py プロジェクト: microsoft/nni
def run(args):
    """Run every selected test case listed in the integration-test config.

    A case is skipped when it is excluded, not selected by a non-empty
    ``args.cases``, or does not match the platform, ``args.ts`` or (remote
    mode only) the remote config. Before each run the function waits for
    port 8080, with a longer timeout in remote and hybrid mode.

    Args:
        args: object providing ``config``, ``exclude``, ``cases``, ``ts`` and
            ``nni_source_dir``.
    """
    it_config = get_yml_content(args.config)

    for case in it_config['testCases']:
        name = case['name']
        if case_excluded(name, args.exclude):
            print('{} excluded'.format(name))
            continue
        if args.cases and not case_included(name, args.cases):
            continue

        # fill test case default config
        defaults = it_config['defaultTestCaseConfig']
        for key in defaults:
            if key not in case:
                case[key] = defaults[key]
        print(json.dumps(case, indent=4))

        if not match_platform(case):
            print('skipped {}, platform {} not match [{}]'.format(
                name, sys.platform, case['platform']))
            continue

        if not match_training_service(case, args.ts):
            print('skipped {}, training service {} not match [{}]'.format(
                name, args.ts, case['trainingService']))
            continue

        if args.ts == 'remote' and not match_remoteConfig(case, args.nni_source_dir):
            print('skipped {}, remoteConfig not match.'.format(name))
            continue

        # remote mode need more time to cleanup
        port_timeout = 240 if args.ts in ('remote', 'hybrid') else 60
        wait_for_port_available(8080, port_timeout)

        # adl mode need more time to cleanup PVC
        if args.ts == 'adl' and name == 'nnictl-resume-2':
            time.sleep(30)
        print('## {}Testing: {}{} ##'.format(GREEN, name, CLEAR))
        started_at = time.time()

        run_test_case(case, it_config, args)
        elapsed = int(time.time() - started_at)
        print('{}Test {}: TEST PASS IN {} SECONDS{}'.format(
            GREEN, name, elapsed, CLEAR), flush=True)
コード例 #11
0
ファイル: tuner_test.py プロジェクト: simbazad/nni
def switch(dispatch_type, dispatch_name):
    '''Change the dispatcher entry in ``tuner_test/local.yml``.

    The entry stored under ``dispatch_type.lower()`` is replaced by one that
    names ``dispatch_name`` as the builtin. Every dispatcher except
    ``GridSearch`` and ``BatchTuner`` also gets
    ``classArgs.optimize_mode = 'maximize'``.
    '''
    config_path = 'tuner_test/local.yml'
    experiment_config = get_yml_content(config_path)

    dispatcher = {'builtin' + dispatch_type + 'Name': dispatch_name}
    if dispatch_name not in ('GridSearch', 'BatchTuner'):
        dispatcher['classArgs'] = {'optimize_mode': 'maximize'}
    experiment_config[dispatch_type.lower()] = dispatcher

    dump_yml_content(config_path, experiment_config)
コード例 #12
0
ファイル: run_tests.py プロジェクト: zhyj3038/nni
def update_training_service_config(config, training_service):
    """Apply the integration-test training-service settings to ``config``.

    The settings are read from ``config/training_service.yml``. ``config`` is
    handed to ``deep_update`` with the ``all`` section and then with the
    section named by ``training_service`` (presumably merging in place).

    For kubeflow and frameworkcontroller the trial command is first moved
    from ``config['trial']`` into the service's role definition, and any
    ``gpuNum`` entry is dropped from ``config['trial']``.
    """
    it_ts_config = get_yml_content(os.path.join('config', 'training_service.yml'))

    # hack for kubeflow / frameworkcontroller trial config
    if training_service in ('kubeflow', 'frameworkcontroller'):
        trial = config['trial']
        command = trial['command']
        service_trial = it_ts_config[training_service]['trial']
        if training_service == 'kubeflow':
            role = service_trial['worker']
        else:
            role = service_trial['taskRoles'][0]
        role['command'] = command
        del trial['command']
        trial.pop('gpuNum', None)

    for section in ('all', training_service):
        deep_update(config, it_ts_config[section])
コード例 #13
0
ファイル: tuner_test.py プロジェクト: wtzl12519/nni
def switch(dispatch_type, dispatch_name):
    '''Change the dispatcher entry in the tuner-test config file.

    The entry stored under ``dispatch_type.lower()`` is replaced by one that
    names ``dispatch_name`` as the builtin. Every dispatcher except
    ``GridSearch``, ``BatchTuner`` and ``Random`` also gets
    ``classArgs.optimize_mode = 'maximize'``. ``searchSpacePath`` is pointed
    at the batch-tuner search space for ``BatchTuner`` and at the default
    one otherwise.
    '''
    config_path = get_config_file_path()
    experiment_config = get_yml_content(config_path)

    dispatcher = {'builtin' + dispatch_type + 'Name': dispatch_name}
    if dispatch_name not in ('GridSearch', 'BatchTuner', 'Random'):
        dispatcher['classArgs'] = {'optimize_mode': 'maximize'}
    experiment_config[dispatch_type.lower()] = dispatcher

    experiment_config['searchSpacePath'] = (
        'batchtuner_search_space.json'
        if dispatch_name == 'BatchTuner' else 'search_space.json')
    dump_yml_content(config_path, experiment_config)
コード例 #14
0
ファイル: generate_ts_config.py プロジェクト: un-knight/nni
def update_training_service_config(args):
    """Write command-line overrides into the training-service yml file.

    Reads ``TRAINING_SERVICE_FILE``, replaces entries of the ``args.ts``
    section with every option on ``args`` that is not ``None``, and dumps the
    result back to the same file. Options that are ``None`` leave the
    corresponding yml entry untouched.

    Args:
        args: object exposing ``ts``, ``nni_manager_ip`` and the options of
            the selected training service (pai, kubeflow,
            frameworkcontroller, remote, adl or aml) read below.

    Raises:
        KeyError / IndexError: if the yml lacks the section or nested entry
            that a supplied option targets.
    """
    config = get_yml_content(TRAINING_SERVICE_FILE)

    def override(value, *key_path):
        # Options that were not supplied leave the yml untouched; the nested
        # lookup only happens when there is something to write.
        if value is None:
            return
        node = config[args.ts]
        for key in key_path[:-1]:
            node = node[key]
        node[key_path[-1]] = value

    def as_flag(text):
        # Reuse switches arrive as text; only "true" (any case) enables them.
        return None if text is None else text.lower() == 'true'

    override(args.nni_manager_ip, 'nniManagerIp')
    if args.ts == 'pai':
        override(args.pai_user, 'paiConfig', 'userName')
        override(args.pai_host, 'paiConfig', 'host')
        override(args.pai_token, 'paiConfig', 'token')
        override(as_flag(args.pai_reuse), 'paiConfig', 'reuse')
        override(args.nni_docker_image, 'trial', 'image')
        override(args.nni_manager_nfs_mount_path, 'trial', 'nniManagerNFSMountPath')
        override(args.container_nfs_mount_path, 'trial', 'containerNFSMountPath')
        override(args.pai_storage_config_name, 'trial', 'paiStorageConfigName')
        override(args.vc, 'trial', 'virtualCluster')
    elif args.ts == 'kubeflow':
        override(args.nfs_server, 'kubeflowConfig', 'nfs', 'server')
        override(args.nfs_path, 'kubeflowConfig', 'nfs', 'path')
        override(args.keyvault_vaultname, 'kubeflowConfig', 'keyVault', 'vaultName')
        override(args.keyvault_name, 'kubeflowConfig', 'keyVault', 'name')
        override(args.azs_account, 'kubeflowConfig', 'azureStorage', 'accountName')
        override(args.azs_share, 'kubeflowConfig', 'azureStorage', 'azureShare')
        override(args.nni_docker_image, 'trial', 'worker', 'image')
    elif args.ts == 'frameworkcontroller':
        override(args.nfs_server, 'frameworkcontrollerConfig', 'nfs', 'server')
        override(args.nfs_path, 'frameworkcontrollerConfig', 'nfs', 'path')
        override(args.keyvault_vaultname, 'frameworkcontrollerConfig', 'keyVault', 'vaultName')
        override(args.keyvault_name, 'frameworkcontrollerConfig', 'keyVault', 'name')
        override(args.azs_account, 'frameworkcontrollerConfig', 'azureStorage', 'accountName')
        override(args.azs_share, 'frameworkcontrollerConfig', 'azureStorage', 'azureShare')
        override(args.nni_docker_image, 'trial', 'taskRoles', 0, 'image')
    elif args.ts == 'remote':
        override(args.remote_user, 'machineList', 0, 'username')
        override(args.remote_host, 'machineList', 0, 'ip')
        override(args.remote_port, 'machineList', 0, 'port')
        override(args.remote_pwd, 'machineList', 0, 'passwd')
        override(as_flag(args.remote_reuse), 'remoteConfig', 'reuse')
    elif args.ts == 'adl':
        override(args.nni_docker_image, 'trial', 'image')
        override(args.checkpoint_storage_class, 'trial', 'checkpoint', 'storageClass')
        override(args.checkpoint_storage_size, 'trial', 'checkpoint', 'storageSize')
        override(args.adaptive, 'trial', 'adaptive')
        # The nfs entry is only written when all three options are given.
        if (args.adl_nfs_server is not None
                and args.adl_nfs_path is not None
                and args.adl_nfs_container_mount_path is not None):
            # default keys in nfs is empty, need to initialize
            # FIX: the mount path previously read the misspelled attribute
            # ``args.nadl_fs_container_mount_path`` and raised AttributeError.
            config[args.ts]['trial']['nfs'] = {
                'server': args.adl_nfs_server,
                'path': args.adl_nfs_path,
                'container_mount_path': args.adl_nfs_container_mount_path,
            }
    elif args.ts == 'aml':
        override(args.nni_docker_image, 'trial', 'image')
        override(args.subscription_id, 'amlConfig', 'subscriptionId')
        override(args.resource_group, 'amlConfig', 'resourceGroup')
        override(args.workspace_name, 'amlConfig', 'workspaceName')
        override(args.compute_target, 'amlConfig', 'computeTarget')

    dump_yml_content(TRAINING_SERVICE_FILE, config)
コード例 #15
0
def update_training_service_config(args):
    """Write command-line overrides into the training-service yml file.

    Reads ``TRAINING_SERVICE_FILE``, replaces entries of the ``args.ts``
    section with every option on ``args`` that is not ``None``, and dumps the
    result back to the same file.

    Args:
        args: object exposing ``ts``, ``nni_manager_ip`` and the options of
            the selected training service (paiYarn, pai, kubeflow,
            frameworkcontroller or remote) read below.
    """
    config = get_yml_content(TRAINING_SERVICE_FILE)

    def override(value, *key_path):
        # Options that were not supplied leave the yml untouched; the nested
        # lookup only happens when there is something to write.
        if value is None:
            return
        node = config[args.ts]
        for key in key_path[:-1]:
            node = node[key]
        node[key_path[-1]] = value

    override(args.nni_manager_ip, 'nniManagerIp')
    if args.ts == 'paiYarn':
        override(args.pai_user, 'paiYarnConfig', 'userName')
        override(args.pai_pwd, 'paiYarnConfig', 'passWord')
        override(args.pai_host, 'paiYarnConfig', 'host')
        override(args.nni_docker_image, 'trial', 'image')
        override(args.data_dir, 'trial', 'dataDir')
        override(args.output_dir, 'trial', 'outputDir')
        override(args.vc, 'trial', 'virtualCluster')
    if args.ts == 'pai':
        override(args.pai_user, 'paiConfig', 'userName')
        override(args.pai_host, 'paiConfig', 'host')
        override(args.pai_token, 'paiConfig', 'token')
        if args.pai_reuse is not None:
            # the switch arrives as text; only "true" (any case) enables it
            override(args.pai_reuse.lower() == 'true', 'paiConfig', 'reuse')
        override(args.nni_docker_image, 'trial', 'image')
        override(args.nni_manager_nfs_mount_path, 'trial', 'nniManagerNFSMountPath')
        override(args.container_nfs_mount_path, 'trial', 'containerNFSMountPath')
        override(args.pai_storage_config_name, 'trial', 'paiStorageConfigName')
        override(args.vc, 'trial', 'virtualCluster')
    elif args.ts == 'kubeflow':
        override(args.nfs_server, 'kubeflowConfig', 'nfs', 'server')
        override(args.nfs_path, 'kubeflowConfig', 'nfs', 'path')
        override(args.keyvault_vaultname, 'kubeflowConfig', 'keyVault', 'vaultName')
        override(args.keyvault_name, 'kubeflowConfig', 'keyVault', 'name')
        override(args.azs_account, 'kubeflowConfig', 'azureStorage', 'accountName')
        override(args.azs_share, 'kubeflowConfig', 'azureStorage', 'azureShare')
        override(args.nni_docker_image, 'trial', 'worker', 'image')
    elif args.ts == 'frameworkcontroller':
        override(args.nfs_server, 'frameworkcontrollerConfig', 'nfs', 'server')
        override(args.nfs_path, 'frameworkcontrollerConfig', 'nfs', 'path')
        override(args.keyvault_vaultname, 'frameworkcontrollerConfig', 'keyVault', 'vaultName')
        override(args.keyvault_name, 'frameworkcontrollerConfig', 'keyVault', 'name')
        override(args.azs_account, 'frameworkcontrollerConfig', 'azureStorage', 'accountName')
        override(args.azs_share, 'frameworkcontrollerConfig', 'azureStorage', 'azureShare')
        override(args.nni_docker_image, 'trial', 'taskRoles', 0, 'image')
    elif args.ts == 'remote':
        override(args.remote_user, 'machineList', 0, 'username')
        override(args.remote_host, 'machineList', 0, 'ip')
        override(args.remote_port, 'machineList', 0, 'port')
        override(args.remote_pwd, 'machineList', 0, 'passwd')
        if args.remote_reuse is not None:
            # the switch arrives as text; only "true" (any case) enables it
            override(args.remote_reuse.lower() == 'true', 'remoteConfig', 'reuse')

    dump_yml_content(TRAINING_SERVICE_FILE, config)
コード例 #16
0
ファイル: run_tests.py プロジェクト: zhyj3038/nni
def get_max_values(config_file):
    """Return ``(max duration, max trial count)`` from an experiment config.

    The duration is ``maxExecDuration`` passed through
    ``parse_max_duration_time``; the trial count is ``maxTrialNum`` as stored.
    """
    experiment_config = get_yml_content(config_file)
    max_duration = parse_max_duration_time(experiment_config['maxExecDuration'])
    max_trials = experiment_config['maxTrialNum']
    return max_duration, max_trials
コード例 #17
0
ファイル: run_tests.py プロジェクト: microsoft/nni
def update_training_service_config(config,
                                   training_service,
                                   config_file_path,
                                   nni_source_dir,
                                   reuse_mode='False'):
    """Apply the integration-test training-service settings to ``config``.

    Settings come from ``config/training_service.yml``, or from
    ``config/training_service_v2.yml`` for hybrid mode and for kubeflow /
    frameworkcontroller in reuse mode. ``config`` is handed to
    ``deep_update`` with the section named by ``training_service``; the
    ``all`` section is applied first unless the service is hybrid or
    ``reuse_mode`` is ``'True'``.

    Args:
        config: experiment config mapping; may be modified before merging.
        training_service: name of the section to apply.
        config_file_path: test config path, relative to ``nni_source_dir``.
        nni_source_dir: prefix prepended to ``config_file_path`` when the
            test case file is re-read in remote mode.
        reuse_mode: the string ``'True'`` or ``'False'``; other values match
            neither the legacy nor the reuse branches.
    """
    v2_settings_path = os.path.join('config', 'training_service_v2.yml')
    it_ts_config = get_yml_content(
        os.path.join('config', 'training_service.yml'))
    legacy_mode = reuse_mode == 'False'
    reuse_enabled = reuse_mode == 'True'

    # hack for kubeflow trial config
    if training_service == 'kubeflow':
        if legacy_mode:
            trial = config['trial']
            it_ts_config[training_service]['trial']['worker'][
                'command'] = trial['command']
            del trial['command']
            trial.pop('gpuNum', None)
        elif reuse_enabled:
            it_ts_config = get_yml_content(v2_settings_path)
            print(it_ts_config)
            trial_command = config['trialCommand']
            worker = it_ts_config[training_service]['trainingService']['worker']
            worker['command'] = trial_command
            worker['code_directory'] = config['trialCodeDirectory']

    if training_service == 'frameworkcontroller':
        if legacy_mode:
            trial = config['trial']
            it_ts_config[training_service]['trial']['taskRoles'][0][
                'command'] = trial['command']
            del trial['command']
            trial.pop('gpuNum', None)
        elif reuse_enabled:
            it_ts_config = get_yml_content(v2_settings_path)
            it_ts_config[training_service]['trainingService']['taskRoles'][0][
                'command'] = config['trialCommand']

    if training_service == 'adl':
        # hack for adl trial config, codeDir in adl mode refers to path in container
        code_dir = config['trial']['codeDir']
        if code_dir == '.':
            # replace metric test folders to container folder
            container_dir = '/' + config_file_path[:config_file_path.rfind('/')]
        elif code_dir == '../naive_trial':
            container_dir = '/test/config/naive_trial'
        elif '../../../' in code_dir:
            # replace example folders to container folder
            container_dir = code_dir.replace('../../../', '/')
        else:
            container_dir = code_dir
        it_ts_config[training_service]['trial']['codeDir'] = container_dir
        it_ts_config[training_service]['trial']['command'] = \
            'cd {0} && {1}'.format(container_dir, config['trial']['command'])

    if training_service == 'remote':
        # Keep only the sharedStorage keys for the storage type the test
        # case itself declares; drop the whole entry otherwise.
        testcase_config = get_yml_content(nni_source_dir + config_file_path)
        shared_storage = testcase_config.get('sharedStorage')
        if shared_storage is None:
            storage_type = None
        else:
            storage_type = str(shared_storage.get('storageType')).lower()

        if storage_type == 'nfs':
            it_ts_config[training_service].get('sharedStorage').pop(
                'storageAccountKey')
        elif storage_type == 'azureblob':
            for unused_key in ('nfsServer', 'exportedDirectory'):
                it_ts_config[training_service].get('sharedStorage').pop(
                    unused_key)
        else:
            it_ts_config[training_service].pop('sharedStorage')

    if training_service == 'hybrid':
        it_ts_config = get_yml_content(v2_settings_path)
    elif not reuse_enabled:
        deep_update(config, it_ts_config['all'])
    deep_update(config, it_ts_config[training_service])