def launch_experiment(args, experiment_config, mode, experiment_id, config_version):
    '''Start the REST server and launch an experiment on it in the given mode.

    Steps, in order: check that the configured builtin tuner/advisor module can
    be imported, start the REST server, record the experiment, prepare the
    search space, send the configuration to the server and print the web UI
    URLs. In foreground mode the REST server output is then echoed until the
    stream ends or the user interrupts.

    Parameters
    ----------
    args
        Parsed command line arguments; ``port``, ``foreground`` and ``debug``
        are read.
    experiment_config : dict
        Experiment configuration. It is modified in place: ``searchSpace`` and
        ``debug`` may be set, and ``trial.codeDir`` is replaced when
        ``useAnnotation`` is enabled.
    mode : str
        Launch mode. ``'view'`` skips the foreground/debug handling and the
        platform configuration step; the value is otherwise passed through to
        the REST server and experiment setup helpers.
    experiment_id
        Experiment identifier; when ``None`` it is taken from the server
        response.
    config_version : int
        ``1`` selects the legacy config keys and ``set_experiment_v1``; any
        other value selects the v2 keys and ``set_experiment_v2``.

    Raises
    ------
    CalledProcessError
        If importing the builtin tuner/advisor module fails.
    AssertionError
        If annotation is enabled and the generated search space is empty.
    Exception
        If killing the REST server process fails during failure cleanup.
    SystemExit
        With code 1 if the REST server does not come up or setting the
        experiment fails.
    '''
    # check packages for tuner
    package_name, module_name = None, None
    if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
        package_name = experiment_config['tuner']['builtinTunerName']
        module_name, _ = get_builtin_module_class_name('tuners', package_name)
    elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
        package_name = experiment_config['advisor']['builtinAdvisorName']
        module_name, _ = get_builtin_module_class_name('advisors', package_name)
    if package_name and module_name:
        try:
            # Import the module in a child interpreter so that import errors end up
            # in the experiment's log files instead of crashing this process.
            stdout_full_path, stderr_full_path = get_log_path(experiment_id)
            with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
                check_call([sys.executable, '-c', 'import %s' % (module_name)],
                           stdout=stdout_file, stderr=stderr_file)
        except CalledProcessError:
            print_error('some errors happen when import package %s.' % (package_name))
            print_log_content(experiment_id)
            if package_name in ['SMAC', 'BOHB', 'PPOTuner']:
                print_error(f'The dependencies for {package_name} can be installed through pip install nni[{package_name}]')
            raise

    # The working directory key differs between config versions; both fall back
    # to NNI_HOME_DIR when the key is missing or empty.
    if config_version == 1:
        log_dir = experiment_config.get('logDir') or NNI_HOME_DIR
    else:
        log_dir = experiment_config.get('experimentWorkingDirectory') or NNI_HOME_DIR
    log_level = experiment_config.get('logLevel') or None

    # View mode does not need the debug function: when viewing an experiment
    # no new logs are created.
    foreground = False
    if mode != 'view':
        foreground = args.foreground
        if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
            log_level = 'debug'

    # start rest server
    if config_version == 1:
        platform = experiment_config['trainingServicePlatform']
    elif isinstance(experiment_config['trainingService'], list):
        platform = 'hybrid'
    else:
        platform = experiment_config['trainingService']['platform']

    rest_process, start_time = start_rest_server(args.port, platform,
                                                 mode, experiment_id, foreground, log_dir, log_level)

    # save experiment information
    Experiments().add_experiment(experiment_id, args.port, start_time, platform,
                                 experiment_config.get('experimentName', 'N/A'),
                                 pid=rest_process.pid, logDir=log_dir)

    # Deal with annotation
    if experiment_config.get('useAnnotation'):
        path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
        if not os.path.isdir(path):
            os.makedirs(path)
        path = tempfile.mkdtemp(dir=path)
        nas_mode = experiment_config['trial'].get('nasMode', 'classic_mode')
        code_dir = expand_annotations(experiment_config['trial']['codeDir'], path, nas_mode=nas_mode)
        experiment_config['trial']['codeDir'] = code_dir
        search_space = generate_search_space(code_dir)
        experiment_config['searchSpace'] = search_space
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if not search_space:
            raise AssertionError(ERROR_INFO % 'Generated search space is empty')
    elif config_version == 1:
        if experiment_config.get('searchSpacePath'):
            search_space = get_json_content(experiment_config.get('searchSpacePath'))
            experiment_config['searchSpace'] = search_space
        else:
            experiment_config['searchSpace'] = ''

    # check rest server
    running, _ = check_rest_server(args.port)
    if running:
        print_normal('Successfully started Restful server!')
    else:
        print_error('Restful server start failed!')
        print_log_content(experiment_id)
        try:
            kill_command(rest_process.pid)
        except Exception:
            raise Exception(ERROR_INFO % 'Rest server stopped!')
        sys.exit(1)

    if config_version == 1 and mode != 'view':
        # set platform configuration
        set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,
                            experiment_id, rest_process)

    # start a new experiment
    print_normal('Starting experiment...')
    # set debug configuration
    if mode != 'view' and experiment_config.get('debug') is None:
        experiment_config['debug'] = args.debug
    if config_version == 1:
        response = set_experiment_v1(experiment_config, mode, args.port, experiment_id)
    else:
        response = set_experiment_v2(experiment_config, mode, args.port, experiment_id)
    if response:
        if experiment_id is None:
            experiment_id = json.loads(response.text).get('experiment_id')
    else:
        print_error('Start experiment failed!')
        print_log_content(experiment_id)
        try:
            kill_command(rest_process.pid)
        except Exception:
            raise Exception(ERROR_INFO % 'Restful server stopped!')
        sys.exit(1)

    if experiment_config.get('nniManagerIp'):
        web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
    else:
        web_ui_url_list = get_local_urls(args.port)
    Experiments().update_experiment(experiment_id, 'webuiUrl', web_ui_url_list)

    print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
    if mode != 'view' and args.foreground:
        try:
            # readline() returns b'' only at end of stream; stop there instead of
            # spinning forever and printing empty lines once the output is closed.
            for raw_line in iter(rest_process.stdout.readline, b''):
                print(raw_line.strip().decode('utf-8'))
        except KeyboardInterrupt:
            kill_command(rest_process.pid)
            print_normal('Stopping experiment...')
def create_experiment(args):
    '''Create and start an experiment from the config file named by ``args.config``.

    A config containing ``trainingServicePlatform`` is treated as legacy (v1)
    format. It is converted to the v2 format when possible (the converted
    config is printed so the user can copy it); otherwise the legacy launcher
    is used and the process exits afterwards. Any other config is loaded
    directly as an ``ExperimentConfig``.

    Parameters
    ----------
    args
        Parsed command line arguments; ``config``, ``port``, ``debug``,
        ``url_prefix`` and ``foreground`` are read.

    Raises
    ------
    SystemExit
        With code 1 if the config file does not exist, does not contain a
        mapping, or cannot be converted from the legacy format; with the
        default exit status after delegating to the legacy launcher.
    AssertionError
        If annotation is enabled and the generated search space is empty.
    '''
    # to make it clear what are inside args
    config_file = Path(args.config)
    port = args.port
    debug = args.debug
    url_prefix = args.url_prefix
    foreground = args.foreground

    # it should finally be done in nnictl main function
    # but for now don't break routines without logging support
    init_logger_for_command_line()
    logging.getLogger('nni').setLevel(logging.INFO)

    if not config_file.is_file():
        _logger.error(f'"{config_file}" is not a valid file.')
        raise SystemExit(1)

    with config_file.open() as config_stream:
        config_content = yaml.safe_load(config_stream)

    # An empty file loads as None and a top-level list/scalar is not a config
    # either; report it instead of failing on `.get` below.
    if not isinstance(config_content, dict):
        _logger.error(f'"{config_file}" does not contain a valid experiment config.')
        raise SystemExit(1)

    v1_platform = config_content.get('trainingServicePlatform')
    if v1_platform:
        can_convert = True
        if v1_platform == 'adl':
            can_convert = False
        if v1_platform in ['kubeflow', 'frameworkcontroller']:
            reuse = config_content.get(v1_platform + 'Config', {}).get('reuse')
            # if user does not explicitly specify it, convert to reuse mode
            can_convert = (reuse != False)

        if not can_convert:
            legacy_launcher.create_experiment(args)
            raise SystemExit

        try:
            v2_config = convert.to_v2(config_content)
        except Exception:
            _logger.error(
                'You are using legacy config format with incorrect fields or values, '
                'to get more accurate error message please update it to the new format.'
            )
            _logger.error('Reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html')
            raise SystemExit(1)
        _logger.warning('You are using legacy config file, please update it to latest format:')
        # use `print` here because logging will add timestamp and make it hard to copy paste
        print(Fore.YELLOW + '=' * 80 + Fore.RESET)
        print(yaml.dump(v2_config, sort_keys=False).strip())
        print(Fore.YELLOW + '=' * 80 + Fore.RESET)
        print(Fore.YELLOW + 'Reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html' + Fore.RESET)

        # The base path is set only while the config object is built; always
        # unset it again, even when construction raises.
        utils.set_base_path(config_file.parent)
        try:
            config = ExperimentConfig(**v2_config)
        finally:
            utils.unset_base_path()

    else:
        config = ExperimentConfig.load(config_file)

    if config.use_annotation:
        path = Path(tempfile.gettempdir(), getuser(), 'nni', 'annotation')
        path.mkdir(parents=True, exist_ok=True)
        path = tempfile.mkdtemp(dir=path)
        code_dir = expand_annotations(config.trial_code_directory, path)
        config.trial_code_directory = code_dir
        config.search_space = generate_search_space(code_dir)
        assert config.search_space, 'ERROR: Generated search space is empty'
        # The annotations have been expanded into an explicit search space above.
        config.use_annotation = False

    exp = Experiment(config)
    exp.url_prefix = url_prefix
    run_mode = RunMode.Foreground if foreground else RunMode.Detach
    exp.start(port, debug, run_mode)

    _logger.info(f'To stop experiment run "nnictl stop {exp.id}" or "nnictl stop --all"')
    _logger.info('Reference: https://nni.readthedocs.io/en/stable/Tutorial/Nnictl.html')
def test_search_space_generator():
    '''Check the search space generated from the annotated code against the stored expectation.'''
    search_space = annotation.generate_search_space(cwd / '_generated/annotated')
    # Use a context manager so the expectation file is closed deterministically.
    with (cwd / 'testcase/searchspace.json').open() as expected_file:
        expected = json.load(expected_file)
    assert search_space == expected
def launch_experiment(args, experiment_config, mode, experiment_id):
    '''Start the REST server and launch an experiment on it in the given mode.

    Steps, in order: check that the configured builtin tuner/advisor module can
    be imported, start the REST server, prepare the search space, configure the
    training platform, record the experiment, send the configuration to the
    server and print the web UI URLs. In foreground mode the REST server output
    is then echoed until the stream ends or the user interrupts.

    Parameters
    ----------
    args
        Parsed command line arguments; ``port``, ``foreground`` and ``debug``
        are read.
    experiment_config : dict
        Experiment configuration; ``trainingServicePlatform`` and
        ``experimentName`` are indexed directly. It is modified in place:
        ``searchSpace`` is set to a JSON string, ``debug`` may be set, and
        ``trial.codeDir`` is replaced when ``useAnnotation`` is enabled.
    mode : str
        Launch mode. ``'view'`` skips the foreground/debug handling and the
        platform configuration step; the value is otherwise passed through to
        the REST server and experiment setup helpers.
    experiment_id
        Experiment identifier; when ``None`` it is taken from the server
        response.

    Raises
    ------
    AssertionError
        If annotation is enabled and the generated search space is empty.
    Exception
        If killing the REST server process fails during failure cleanup.
    SystemExit
        With code 1 if the builtin tuner/advisor module cannot be imported,
        the REST server does not come up, or setting the experiment fails.
    '''
    nni_config = Config(experiment_id)
    # check packages for tuner
    package_name, module_name = None, None
    if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
        package_name = experiment_config['tuner']['builtinTunerName']
        module_name, _ = get_builtin_module_class_name('tuners', package_name)
    elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
        package_name = experiment_config['advisor']['builtinAdvisorName']
        module_name, _ = get_builtin_module_class_name('advisors', package_name)
    if package_name and module_name:
        try:
            # Import the module in a child interpreter so that import errors end up
            # in the experiment's log files instead of crashing this process.
            stdout_full_path, stderr_full_path = get_log_path(experiment_id)
            with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
                check_call([sys.executable, '-c', 'import %s' % (module_name)],
                           stdout=stdout_file, stderr=stderr_file)
        except CalledProcessError:
            print_error('some errors happen when import package %s.' % (package_name))
            print_log_content(experiment_id)
            if package_name in INSTALLABLE_PACKAGE_META:
                print_error('If %s is not installed, it should be installed through '
                            '\'nnictl package install --name %s\'' % (package_name, package_name))
            sys.exit(1)

    log_dir = experiment_config.get('logDir') or None
    log_level = experiment_config.get('logLevel') or None

    # View mode does not need the debug function: when viewing an experiment
    # no new logs are created.
    foreground = False
    if mode != 'view':
        foreground = args.foreground
        if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
            log_level = 'debug'

    # start rest server
    rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'],
                                                 mode, experiment_id, foreground, log_dir, log_level)
    nni_config.set_config('restServerPid', rest_process.pid)

    # Deal with annotation
    if experiment_config.get('useAnnotation'):
        path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
        if not os.path.isdir(path):
            os.makedirs(path)
        path = tempfile.mkdtemp(dir=path)
        nas_mode = experiment_config['trial'].get('nasMode', 'classic_mode')
        code_dir = expand_annotations(experiment_config['trial']['codeDir'], path, nas_mode=nas_mode)
        experiment_config['trial']['codeDir'] = code_dir
        search_space = generate_search_space(code_dir)
        experiment_config['searchSpace'] = json.dumps(search_space)
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if not search_space:
            raise AssertionError(ERROR_INFO % 'Generated search space is empty')
    elif experiment_config.get('searchSpacePath'):
        search_space = get_json_content(experiment_config.get('searchSpacePath'))
        experiment_config['searchSpace'] = json.dumps(search_space)
    else:
        experiment_config['searchSpace'] = json.dumps('')

    # check rest server
    running, _ = check_rest_server(args.port)
    if running:
        print_normal('Successfully started Restful server!')
    else:
        print_error('Restful server start failed!')
        print_log_content(experiment_id)
        try:
            kill_command(rest_process.pid)
        except Exception:
            raise Exception(ERROR_INFO % 'Rest server stopped!')
        sys.exit(1)

    if mode != 'view':
        # set platform configuration
        set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,
                            experiment_id, rest_process)

    # start a new experiment
    print_normal('Starting experiment...')
    # save experiment information
    nnictl_experiment_config = Experiments()
    nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
                                            experiment_config['trainingServicePlatform'],
                                            experiment_config['experimentName'],
                                            pid=rest_process.pid, logDir=log_dir)
    # set debug configuration
    if mode != 'view' and experiment_config.get('debug') is None:
        experiment_config['debug'] = args.debug
    response = set_experiment(experiment_config, mode, args.port, experiment_id)
    if response:
        if experiment_id is None:
            experiment_id = json.loads(response.text).get('experiment_id')
    else:
        print_error('Start experiment failed!')
        print_log_content(experiment_id)
        try:
            kill_command(rest_process.pid)
        except Exception:
            raise Exception(ERROR_INFO % 'Restful server stopped!')
        sys.exit(1)

    if experiment_config.get('nniManagerIp'):
        web_ui_url_list = ['{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
    else:
        web_ui_url_list = get_local_urls(args.port)
    nni_config.set_config('webuiUrl', web_ui_url_list)

    print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
    if mode != 'view' and args.foreground:
        try:
            # readline() returns b'' only at end of stream; stop there instead of
            # spinning forever and printing empty lines once the output is closed.
            for raw_line in iter(rest_process.stdout.readline, b''):
                print(raw_line.strip().decode('utf-8'))
        except KeyboardInterrupt:
            kill_command(rest_process.pid)
            print_normal('Stopping experiment...')