Пример #1
0
def wait_training_finish(timeout: int, wait: bool, mt_id: str,
                         mt_client: ModelTrainingClient):
    """
    Block until a model training finishes, streaming its logs meanwhile.

    Returns immediately when ``wait`` is false. Otherwise the training status
    is polled, sleeping ``DEFAULT_WAIT_TIMEOUT`` seconds between polls. While
    the training is neither in the success nor in the failed state (and its
    state is not empty), its logs are followed and printed.

    :param timeout: Overall time limit in seconds; must be positive when
        ``wait`` is set
    :param wait: Whether to wait for the training at all
    :param mt_id: Model Training name
    :param mt_client: Model Training Client used to poll the training status
    :raises Exception: if ``timeout`` is not positive, if the training ends in
        the failed state, or if ``timeout`` seconds elapse before it succeeds
    """
    if not wait:
        return

    started_at = time.time()
    if timeout <= 0:
        raise Exception(
            'Invalid --timeout argument: should be positive integer')

    # Logs are read through a copy of the client with different timeout
    # settings: a tuple of the original timeout and LOG_READ_TIMEOUT_SECONDS
    # (presumably a (connect, read) pair -- confirm against the HTTP client).
    log_client = ModelTrainingClient.construct_from_other(mt_client)
    log_client.timeout = mt_client.timeout, LOG_READ_TIMEOUT_SECONDS

    click.echo("Logs streaming...")

    # NOTE(review): the deadline is only checked between polls, so a followed
    # log stream that keeps producing output can run past ``timeout``.
    while time.time() - started_at <= timeout:
        try:
            training = mt_client.get(mt_id)
            state = training.status.state
            if state == TRAINING_SUCCESS_STATE:
                took = round(time.time() - started_at)
                click.echo(
                    f'Model {mt_id} was trained. Training took {took} seconds'
                )
                return
            if state == TRAINING_FAILED_STATE:
                # Plain Exception is not among the caught types below,
                # so this propagates to the caller.
                raise Exception(f'Model training {mt_id} was failed.')
            if state == "":
                click.echo(
                    f"Can't determine the state of {training.id}. Sleeping...")
            else:
                for msg in log_client.log(training.id, follow=True):
                    print_logs(msg)
        except (WrongHttpStatusCode, HTTPException, RequestException,
                APIConnectionException) as err:
            # API / network errors are logged and the poll is retried.
            LOGGER.info(
                'Callback have not confirmed completion of the operation. Exception: %s',
                str(err))

        LOGGER.debug('Sleep before next request')
        time.sleep(DEFAULT_WAIT_TIMEOUT)

    raise Exception(TIMEOUT_ERROR_MESSAGE)
Пример #2
0
def run(client: ModelTrainingClient, train_id: str, manifest_file: List[str],
        manifest_dir: List[str], output_dir: str):
    """
    \b
    Start a training process locally.
    \b
    Usage example:
        * odahuflowctl local train run --id examples-git
    \f
    Resolve a model training and its toolchain, then start it locally.

    The training ``train_id`` and all toolchain integrations are first looked
    up among the resources parsed from the given manifest files and
    directories. Anything not found there is fetched from the API server.

    :param client: Model training HTTP client; also used to construct the
        toolchain integration client for the API fallback
    :param train_id: ID of the model training to run
    :param manifest_file: Resource files to search for the training and
        toolchain integrations
    :param manifest_dir: Directories to search for the training and
        toolchain integrations
    :param output_dir: Output directory, passed through to ``start_train``
    """
    entities: List[OdahuflowCloudResourceUpdatePair] = []
    for file_path in manifest_file:
        entities.extend(parse_resources_file(file_path).changes)

    for dir_path in manifest_dir:
        entities.extend(parse_resources_dir(dir_path))

    mt: Optional[ModelTraining] = None

    # Index every toolchain integration by ID (the training may reference any
    # of them) and pick out the requested training. If the manifests define
    # the same training ID more than once, the last definition wins.
    toolchains: Dict[str, ToolchainIntegration] = {}
    for entity in (pair.resource for pair in entities):
        if isinstance(entity, ToolchainIntegration):
            toolchains[entity.id] = entity
        elif isinstance(entity, ModelTraining) and entity.id == train_id:
            mt = entity

    if mt is None:
        click.echo(
            f'{train_id} training not found. Trying to retrieve it from API server'
        )
        mt = client.get(train_id)

    toolchain_id = mt.spec.toolchain
    toolchain = toolchains.get(toolchain_id)
    if toolchain is None:
        # Report the ID that was looked up: `toolchain` itself is None here,
        # so interpolating it would print "None toolchain not found".
        click.echo(
            f'{toolchain_id} toolchain not found. Trying to retrieve it from API server'
        )
        toolchain = ToolchainIntegrationClient.construct_from_other(
            client).get(toolchain_id)

    trainer = K8sTrainer(
        model_training=mt,
        toolchain_integration=toolchain,
    )

    start_train(trainer, output_dir)
Пример #3
0
def get(client: ModelTrainingClient, train_id: str, output_format: str):
    """
    \b
    Get trainings.
    The command without id argument retrieve all trainings.
    \b
    Get all trainings in json format:
        odahuflowctl train get --output-format json
    \b
    Get training with "git-repo" id:
        odahuflowctl train get --id git-repo
    \b
    Using jsonpath:
        odahuflowctl train get -o 'jsonpath=[*].spec.reference'
    \f
    Print a single model training, or every model training, in the
    requested output format.

    :param client: Model training HTTP client
    :param train_id: Model training ID; when empty, all trainings are fetched
    :param output_format: Output format, forwarded to ``format_output``
    """
    if train_id:
        trainings = [client.get(train_id)]
    else:
        trainings = client.get_all()

    format_output(trainings, output_format)