コード例 #1
0
def _kill_rest_server_and_exit(rest_process, stopped_detail):
    """Kill the REST server process, then exit the interpreter with status 1.

    Args:
        rest_process: process object of the REST server; only ``pid`` is read.
        stopped_detail: text substituted into ``ERROR_INFO`` when the kill
            itself fails.

    Raises:
        Exception: if ``kill_command`` raises; the message is
            ``ERROR_INFO % stopped_detail``.
        SystemExit: with code 1 once the process has been killed.
    """
    try:
        kill_command(rest_process.pid)
    except Exception:
        raise Exception(ERROR_INFO % stopped_detail)
    sys.exit(1)


def launch_experiment(args,
                      experiment_config,
                      mode,
                      config_file_name,
                      experiment_id=None):
    """Start the REST server and then start the experiment.

    Steps, in order: verify that the configured builtin tuner/advisor module
    can be imported, start the REST server, build the search space, check
    that the REST server is running, push the platform configuration
    (skipped when ``mode == 'view'``), start the experiment and finally
    record the experiment id, web UI urls and experiment information.

    Args:
        args: object providing the ``port`` and ``debug`` attributes.
        experiment_config: experiment configuration mapping. It is modified
            in place: ``searchSpace`` is always set, ``debug`` is filled from
            ``args.debug`` when absent (non-view modes only) and
            ``trial.codeDir`` is replaced when ``useAnnotation`` is set.
        mode: forwarded to ``start_rest_server`` and ``set_experiment``;
            the value ``'view'`` disables the debug and platform steps.
        config_file_name: name handed to ``Config`` and to the logging and
            configuration helpers.
        experiment_id: id of the experiment; when ``None`` it is read from
            the ``experiment_id`` field of the ``set_experiment`` response.

    Raises:
        SystemExit: with code 1 when the tuner/advisor import check fails,
            the REST server is not running, or ``set_experiment`` returns a
            falsy response.
        AssertionError: when annotation is used and the generated search
            space is empty.
        Exception: when killing the REST server during error handling fails.
    """
    nni_config = Config(config_file_name)

    # check packages for tuner
    package_name, module_name = None, None
    tuner_config = experiment_config.get('tuner')
    advisor_config = experiment_config.get('advisor')
    if tuner_config and tuner_config.get('builtinTunerName'):
        package_name = tuner_config['builtinTunerName']
        module_name = ModuleName.get(package_name)
    elif advisor_config and advisor_config.get('builtinAdvisorName'):
        package_name = advisor_config['builtinAdvisorName']
        module_name = AdvisorModuleName.get(package_name)
    if package_name and module_name:
        try:
            stdout_full_path, stderr_full_path = get_log_path(config_file_name)
            # Import the module in a child interpreter; its output goes to
            # the experiment log files so it can be shown on failure.
            with open(stdout_full_path, 'a+') as stdout_file, \
                    open(stderr_full_path, 'a+') as stderr_file:
                check_call([sys.executable, '-c', 'import %s' % (module_name)],
                           stdout=stdout_file,
                           stderr=stderr_file)
        except CalledProcessError:
            print_error('some errors happen when import package %s.' %
                        (package_name))
            print_log_content(config_file_name)
            if package_name in PACKAGE_REQUIREMENTS:
                print_error(
                    "If %s is not installed, it should be installed through "
                    "'nnictl package install --name %s'"
                    % (package_name, package_name))
            sys.exit(1)

    # Falsy values (missing key, empty string) are normalised to None.
    log_dir = experiment_config.get('logDir') or None
    log_level = experiment_config.get('logLevel') or None
    # view experiment mode does not need the debug function: when viewing an
    # experiment, no new logs are created
    if (mode != 'view' and log_level not in ['trace', 'debug']
            and (args.debug or experiment_config.get('debug') is True)):
        log_level = 'debug'

    # start rest server
    rest_process, start_time = start_rest_server(
        args.port, experiment_config['trainingServicePlatform'], mode,
        config_file_name, experiment_id, log_dir, log_level)
    nni_config.set_config('restServerPid', rest_process.pid)

    # Deal with annotation
    if experiment_config.get('useAnnotation'):
        path = os.path.join(tempfile.gettempdir(), get_user(), 'nni',
                            'annotation')
        if not os.path.isdir(path):
            os.makedirs(path)
        path = tempfile.mkdtemp(dir=path)
        nas_mode = experiment_config['trial'].get('nasMode', 'classic_mode')
        code_dir = expand_annotations(experiment_config['trial']['codeDir'],
                                      path,
                                      nas_mode=nas_mode)
        experiment_config['trial']['codeDir'] = code_dir
        search_space = generate_search_space(code_dir)
        experiment_config['searchSpace'] = json.dumps(search_space)
        # Explicit raise instead of `assert`, so the check is not stripped
        # when Python runs with -O; the exception type is unchanged.
        if not search_space:
            raise AssertionError(ERROR_INFO % 'Generated search space is empty')
    elif experiment_config.get('searchSpacePath'):
        search_space = get_json_content(
            experiment_config.get('searchSpacePath'))
        experiment_config['searchSpace'] = json.dumps(search_space)
    else:
        experiment_config['searchSpace'] = json.dumps('')

    # check rest server
    running, _ = check_rest_server(args.port)
    if not running:
        print_error('Restful server start failed!')
        print_log_content(config_file_name)
        _kill_rest_server_and_exit(rest_process, 'Rest server stopped!')
    print_normal('Successfully started Restful server!')

    if mode != 'view':
        # set platform configuration
        set_platform_config(experiment_config['trainingServicePlatform'],
                            experiment_config, args.port, config_file_name,
                            rest_process)

    # start a new experiment
    print_normal('Starting experiment...')
    # set debug configuration
    if mode != 'view' and experiment_config.get('debug') is None:
        experiment_config['debug'] = args.debug
    response = set_experiment(experiment_config, mode, args.port,
                              config_file_name)
    if not response:
        print_error('Start experiment failed!')
        print_log_content(config_file_name)
        _kill_rest_server_and_exit(rest_process, 'Restful server stopped!')
    if experiment_id is None:
        experiment_id = json.loads(response.text).get('experiment_id')
    nni_config.set_config('experimentId', experiment_id)

    if experiment_config.get('nniManagerIp'):
        web_ui_url_list = [
            '{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))
        ]
    else:
        web_ui_url_list = get_local_urls(args.port)
    nni_config.set_config('webuiUrl', web_ui_url_list)

    # save experiment information
    nnictl_experiment_config = Experiments()
    nnictl_experiment_config.add_experiment(
        experiment_id, args.port, start_time, config_file_name,
        experiment_config['trainingServicePlatform'])

    print_normal(EXPERIMENT_SUCCESS_INFO %
                 (experiment_id, '   '.join(web_ui_url_list)))
コード例 #2
0
ファイル: launcher.py プロジェクト: sauravsrijan/nni
def _kill_rest_server_and_exit(rest_process, stopped_detail):
    """Kill the REST server process, then exit the interpreter with status 1.

    Args:
        rest_process: process object of the REST server; only ``pid`` is read.
        stopped_detail: text substituted into ``ERROR_INFO`` when the kill
            itself fails.

    Raises:
        Exception: if ``kill_command`` raises; the message is
            ``ERROR_INFO % stopped_detail``.
        SystemExit: with code 1 once the process has been killed.
    """
    try:
        kill_command(rest_process.pid)
    except Exception:
        raise Exception(ERROR_INFO % stopped_detail)
    sys.exit(1)


def launch_experiment(args,
                      experiment_config,
                      mode,
                      config_file_name,
                      experiment_id=None):
    """Start the REST server and then start the experiment.

    Steps, in order: on Windows check the PowerShell execution policy,
    verify that the configured builtin tuner/advisor module can be imported,
    start the REST server, build the search space, check that the REST
    server is running, push the configuration of the selected training
    service platform, start the experiment and finally record the
    experiment id, web UI urls and experiment information.

    Args:
        args: object providing the ``port`` and ``debug`` attributes.
        experiment_config: experiment configuration mapping. It is modified
            in place: ``searchSpace`` is always set, ``debug`` is filled from
            ``args.debug`` when absent and ``trial.codeDir`` is replaced when
            ``useAnnotation`` is set.
        mode: forwarded to ``start_rest_server`` and ``set_experiment``.
        config_file_name: name handed to ``Config`` and to the logging and
            configuration helpers.
        experiment_id: id of the experiment; when ``None`` it is read from
            the ``experiment_id`` field of the ``set_experiment`` response.

    Raises:
        SystemExit: with code 1 when the PowerShell policy is ``Restricted``,
            the tuner/advisor import check fails, the REST server is not
            running, a platform configuration step fails, or
            ``set_experiment`` returns a falsy response.
        AssertionError: when annotation is used and the generated search
            space is empty.
        Exception: when killing the REST server during error handling fails.
    """
    # Imported locally so this function does not depend on the module-level
    # import list, which only needs to provide check_call/check_output.
    from subprocess import DEVNULL

    nni_config = Config(config_file_name)

    # check execution policy in powershell
    if sys.platform == 'win32':
        execution_policy = check_output(
            ['powershell.exe', 'Get-ExecutionPolicy']).decode('ascii').strip()
        if execution_policy == 'Restricted':
            print_error(
                'PowerShell execution policy error, please run PowerShell as '
                'administrator with this command first:\r\n'
                "'Set-ExecutionPolicy -ExecutionPolicy Unrestricted'")
            sys.exit(1)

    # check packages for tuner
    package_name, module_name = None, None
    tuner_config = experiment_config.get('tuner')
    advisor_config = experiment_config.get('advisor')
    if tuner_config and tuner_config.get('builtinTunerName'):
        package_name = tuner_config['builtinTunerName']
        module_name = ModuleName.get(package_name)
    elif advisor_config and advisor_config.get('builtinAdvisorName'):
        package_name = advisor_config['builtinAdvisorName']
        module_name = AdvisorModuleName.get(package_name)
    if package_name and module_name:
        try:
            # The child's output is never read, so it is discarded with
            # DEVNULL. Using PIPE with check_call can block the child once
            # the OS pipe buffer is full because nothing drains the pipe.
            check_call([sys.executable, '-c', 'import %s' % (module_name)],
                       stdout=DEVNULL,
                       stderr=DEVNULL)
        except CalledProcessError:
            print_error(
                "%s should be installed through "
                "'nnictl package install --name %s'"
                % (package_name, package_name))
            sys.exit(1)

    # Falsy values (missing key, empty string) are normalised to None.
    log_dir = experiment_config.get('logDir') or None
    log_level = experiment_config.get('logLevel') or None
    if log_level not in ['trace', 'debug'] and args.debug:
        log_level = 'debug'

    # start rest server
    rest_process, start_time = start_rest_server(
        args.port, experiment_config['trainingServicePlatform'], mode,
        config_file_name, experiment_id, log_dir, log_level)
    nni_config.set_config('restServerPid', rest_process.pid)

    # Deal with annotation
    if experiment_config.get('useAnnotation'):
        path = os.path.join(tempfile.gettempdir(), get_user(), 'nni',
                            'annotation')
        if not os.path.isdir(path):
            os.makedirs(path)
        path = tempfile.mkdtemp(dir=path)
        code_dir = expand_annotations(experiment_config['trial']['codeDir'],
                                      path)
        experiment_config['trial']['codeDir'] = code_dir
        search_space = generate_search_space(code_dir)
        experiment_config['searchSpace'] = json.dumps(search_space)
        # Explicit raise instead of `assert`, so the check is not stripped
        # when Python runs with -O; the exception type is unchanged.
        if not search_space:
            raise AssertionError(ERROR_INFO % 'Generated search space is empty')
    elif experiment_config.get('searchSpacePath'):
        search_space = get_json_content(
            experiment_config.get('searchSpacePath'))
        experiment_config['searchSpace'] = json.dumps(search_space)
    else:
        experiment_config['searchSpace'] = json.dumps('')

    # check rest server
    running, _ = check_rest_server(args.port)
    if not running:
        print_error('Restful server start failed!')
        print_log_content(config_file_name)
        _kill_rest_server_and_exit(rest_process, 'Rest server stopped!')
    print_normal('Successfully started Restful server!')

    # Push the configuration of the selected training service platform.
    # The platform names are mutually exclusive, so at most one branch runs.
    platform = experiment_config['trainingServicePlatform']
    if platform == 'remote':
        # set remote config
        print_normal('Setting remote config...')
        config_result, err_msg = set_remote_config(experiment_config,
                                                   args.port, config_file_name)
        if config_result:
            print_normal('Successfully set remote config!')
        else:
            print_error('Failed! Error is: {}'.format(err_msg))
            _kill_rest_server_and_exit(rest_process, 'Rest server stopped!')
    elif platform == 'local':
        # set local config
        print_normal('Setting local config...')
        if set_local_config(experiment_config, args.port, config_file_name):
            print_normal('Successfully set local config!')
        else:
            print_error('Set local config failed!')
            _kill_rest_server_and_exit(rest_process, 'Rest server stopped!')
    elif platform == 'pai':
        # set pai config
        print_normal('Setting pai config...')
        config_result, err_msg = set_pai_config(experiment_config, args.port,
                                                config_file_name)
        if config_result:
            print_normal('Successfully set pai config!')
        else:
            if err_msg:
                print_error('Failed! Error is: {}'.format(err_msg))
            _kill_rest_server_and_exit(rest_process, 'Restful server stopped!')
    elif platform == 'kubeflow':
        # set kubeflow config
        print_normal('Setting kubeflow config...')
        config_result, err_msg = set_kubeflow_config(experiment_config,
                                                     args.port,
                                                     config_file_name)
        if config_result:
            print_normal('Successfully set kubeflow config!')
        else:
            if err_msg:
                print_error('Failed! Error is: {}'.format(err_msg))
            _kill_rest_server_and_exit(rest_process, 'Restful server stopped!')
    elif platform == 'frameworkcontroller':
        # set frameworkcontroller config
        print_normal('Setting frameworkcontroller config...')
        config_result, err_msg = set_frameworkcontroller_config(
            experiment_config, args.port, config_file_name)
        if config_result:
            print_normal('Successfully set frameworkcontroller config!')
        else:
            if err_msg:
                print_error('Failed! Error is: {}'.format(err_msg))
            _kill_rest_server_and_exit(rest_process, 'Restful server stopped!')

    # start a new experiment
    print_normal('Starting experiment...')
    # set debug configuration
    if experiment_config.get('debug') is None:
        experiment_config['debug'] = args.debug
    response = set_experiment(experiment_config, mode, args.port,
                              config_file_name)
    if not response:
        print_error('Start experiment failed!')
        print_log_content(config_file_name)
        _kill_rest_server_and_exit(rest_process, 'Restful server stopped!')
    if experiment_id is None:
        experiment_id = json.loads(response.text).get('experiment_id')
    nni_config.set_config('experimentId', experiment_id)

    if experiment_config.get('nniManagerIp'):
        web_ui_url_list = [
            '{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))
        ]
    else:
        web_ui_url_list = get_local_urls(args.port)
    nni_config.set_config('webuiUrl', web_ui_url_list)

    # save experiment information
    nnictl_experiment_config = Experiments()
    nnictl_experiment_config.add_experiment(
        experiment_id, args.port, start_time, config_file_name,
        experiment_config['trainingServicePlatform'])

    print_normal(EXPERIMENT_SUCCESS_INFO %
                 (experiment_id, '   '.join(web_ui_url_list)))