Exemplo n.º 1
0
def create_experiment(args):
    '''Start a new experiment from the config file named by ``args.config``.

    An 8-character experiment id is sampled (without repetition) from ASCII
    letters and digits. The config file is loaded and routed as follows:

    * it contains ``trainingServicePlatform``: it is validated with
      ``_validate_v1``; if the platform is one of ``k8s_training_services``
      the loaded config is launched unchanged with schema 1, otherwise it is
      converted with ``convert.to_v2`` and launched with schema 2;
    * otherwise: the result of ``_validate_v2`` is launched with schema 2.

    If the config path does not exist, or the launch raises, an error is
    printed and the process exits with status 1.
    '''
    id_alphabet = string.ascii_letters + string.digits
    experiment_id = ''.join(random.sample(id_alphabet, 8))

    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        print_error('Please set correct config path!')
        exit(1)
    config_yml = get_yml_content(config_path)

    # Decide which schema version to launch with and which config goes with it.
    if 'trainingServicePlatform' not in config_yml:
        schema, launch_config = 2, _validate_v2(config_yml, config_path)
    else:
        _validate_v1(config_yml, config_path)
        if config_yml['trainingServicePlatform'] in k8s_training_services:
            schema, launch_config = 1, config_yml
        else:
            from nni.experiment.config import convert
            schema, launch_config = 2, convert.to_v2(config_yml).json()

    try:
        launch_experiment(args, launch_config, 'new', experiment_id, schema)
    except Exception as launch_error:
        # Kill whatever process got recorded under this experiment id before
        # the failure (presumably the REST server, going by the original
        # variable name) so that it is not left running.
        recorded = Experiments().get_all_experiments().get(experiment_id, {})
        rest_server_pid = recorded.get('pid')
        if rest_server_pid:
            kill_command(rest_server_pid)
        print_error(launch_error)
        exit(1)
Exemplo n.º 2
0
def create_experiment(args):
    '''Start a new experiment from the config file named by ``args.config``.

    The loaded config is first validated by constructing an
    ``ExperimentConfig`` (V2 schema). If that fails, it is validated with
    ``validate_all_content`` (V1 schema) and converted with ``convert.to_v2``.
    When both validations fail, both errors are printed and the process exits
    with status 1.

    If the resulting training service platform is one of
    ``k8s_training_services`` the original loaded config is launched with
    schema 1; otherwise the V2 config is launched with schema 2. A failed
    launch kills the process recorded under the experiment id (if any), prints
    the error and exits with status 1.
    '''
    experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        print_error('Please set correct config path!')
        exit(1)
    config_yml = get_yml_content(config_path)

    try:
        config = ExperimentConfig(_base_path=Path(config_path).parent, **config_yml)
        config_v2 = config.json()
    except Exception as error_v2:
        print_warning('Validation with V2 schema failed. Trying to convert from V1 format...')
        try:
            validate_all_content(config_yml, config_path)
        except Exception as error_v1:
            print_error(f'Convert from v1 format failed: {repr(error_v1)}')
            print_error(f'Config in v2 format validation failed: {repr(error_v2)}')
            exit(1)
        from nni.experiment.config import convert
        config_v2 = convert.to_v2(config_yml).json()

    try:
        # ``config_v2`` is indexed as a mapping, so its 'trainingService' entry
        # is expected to be a plain dict. ``getattr`` on a dict never finds a
        # 'platform' key, which made the k8s branch below unreachable; read the
        # key for dicts and keep attribute access for any other type.
        # NOTE(review): assumes ``.json()`` yields nested dicts - confirm.
        training_service = config_v2['trainingService']
        if isinstance(training_service, dict):
            platform = training_service.get('platform')
        else:
            platform = getattr(training_service, 'platform', None)

        if platform in k8s_training_services:
            launch_experiment(args, config_yml, 'new', experiment_id, 1)
        else:
            launch_experiment(args, config_v2, 'new', experiment_id, 2)
    except Exception as exception:
        # Do not leave a process recorded for this experiment running after a
        # failed launch.
        restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
        if restServerPid:
            kill_command(restServerPid)
        print_error(exception)
        exit(1)
Exemplo n.º 3
0
def create_experiment(args):
    """Create and start an experiment from the config file given in ``args``.

    Reads ``args.config``, ``args.port``, ``args.debug``, ``args.url_prefix``
    and ``args.foreground``.

    Config handling:

    * If the YAML content has a truthy ``trainingServicePlatform`` it is
      treated as a legacy config. Platform ``adl``, and ``kubeflow`` /
      ``frameworkcontroller`` configs whose ``reuse`` equals ``False``, are
      handed to ``legacy_launcher.create_experiment`` and the process exits.
      Every other legacy config is converted with ``convert.to_v2``, the
      converted config is printed for the user to copy, and an
      ``ExperimentConfig`` is built from it.
    * Otherwise the file is loaded with ``ExperimentConfig.load``.

    When ``config.use_annotation`` is set, annotations are expanded into a
    fresh temporary directory under ``<tmp>/<user>/nni/annotation`` and the
    search space is generated from the expanded code.

    The process exits with status 1 when the config path is not a file, when
    the file does not contain a YAML mapping, or when a legacy config cannot
    be converted.

    Raises:
        AssertionError: if the search space generated from annotations is empty.
    """
    # to make it clear what are inside args
    config_file = Path(args.config)
    port = args.port
    debug = args.debug
    url_prefix = args.url_prefix
    foreground = args.foreground

    # it should finally be done in nnictl main function
    # but for now don't break routines without logging support
    init_logger_for_command_line()
    logging.getLogger('nni').setLevel(logging.INFO)

    if not config_file.is_file():
        _logger.error(f'"{config_file}" is not a valid file.')
        exit(1)

    with config_file.open() as config_stream:
        config_content = yaml.safe_load(config_stream)

    # An empty file yields None and a top-level list/scalar has no ``.get``;
    # report that explicitly instead of failing with an AttributeError below.
    if not isinstance(config_content, dict):
        _logger.error(f'"{config_file}" does not contain a YAML mapping.')
        exit(1)

    v1_platform = config_content.get('trainingServicePlatform')
    if v1_platform:
        can_convert = True
        if v1_platform == 'adl':
            can_convert = False
        if v1_platform in ['kubeflow', 'frameworkcontroller']:
            reuse = config_content.get(v1_platform + 'Config', {}).get('reuse')
            # if user does not explicitly specify it, convert to reuse mode
            # (``!=`` is kept on purpose: it preserves the original equality
            # semantics for every value of ``reuse``)
            can_convert = reuse != False  # noqa: E712

        if not can_convert:
            legacy_launcher.create_experiment(args)
            exit()

        try:
            v2_config = convert.to_v2(config_content)
        except Exception:
            _logger.error(
                'You are using legacy config format with incorrect fields or values, '
                'to get more accurate error message please update it to the new format.'
            )
            _logger.error(
                'Reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html'
            )
            exit(1)
        _logger.warning(
            'You are using legacy config file, please update it to latest format:'
        )
        # use `print` here because logging will add timestamp and make it hard to copy paste
        print(Fore.YELLOW + '=' * 80 + Fore.RESET)
        print(yaml.dump(v2_config, sort_keys=False).strip())
        print(Fore.YELLOW + '=' * 80 + Fore.RESET)
        print(
            Fore.YELLOW +
            'Reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html'
            + Fore.RESET)

        utils.set_base_path(config_file.parent)
        try:
            config = ExperimentConfig(**v2_config)
        finally:
            # always undo set_base_path, even when validation raises
            utils.unset_base_path()

    else:
        config = ExperimentConfig.load(config_file)

    if config.use_annotation:
        annotation_root = Path(tempfile.gettempdir(), getuser(), 'nni', 'annotation')
        annotation_root.mkdir(parents=True, exist_ok=True)
        expanded_dir = tempfile.mkdtemp(dir=annotation_root)
        code_dir = expand_annotations(config.trial_code_directory, expanded_dir)
        config.trial_code_directory = code_dir
        config.search_space = generate_search_space(code_dir)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # the exception type and message are unchanged.
        if not config.search_space:
            raise AssertionError('ERROR: Generated search space is empty')
        config.use_annotation = False

    exp = Experiment(config)
    exp.url_prefix = url_prefix
    run_mode = RunMode.Foreground if foreground else RunMode.Detach
    exp.start(port, debug, run_mode)

    _logger.info(
        f'To stop experiment run "nnictl stop {exp.id}" or "nnictl stop --all"'
    )
    _logger.info(
        'Reference: https://nni.readthedocs.io/en/stable/Tutorial/Nnictl.html')