示例#1
0
def create(context, **kwargs):
    """Create a new entity in the repository chosen by the parent command.

    Validates the artifact name and the project settings, resolves the
    ``categories`` and ``mutability`` options through the wizard helpers and
    hands the completed ``kwargs`` to the matching repository.

    Args:
        context: Command context; ``context.parent.command.name`` selects
            the repository type.
        **kwargs: Command options. ``artifact_name``, ``wizard``,
            ``categories`` and ``mutability`` are read here; the last two
            are overwritten with the resolved values.

    Raises:
        UsageError: If ``categories`` or ``mutability`` is still falsy after
            the wizard helpers ran.
    """
    def _require_option(option):
        # A value left falsy by the wizard counts as a missing option.
        if not kwargs[option]:
            raise UsageError(
                output_messages['ERROR_MISSING_OPTION'].format(option))

    repo_type = context.parent.command.name
    artifact = kwargs['artifact_name']
    use_wizard = kwargs['wizard']

    check_entity_name(artifact)
    _verify_project_settings(use_wizard, context, repo_type, artifact,
                             check_entity=False)

    chosen_categories = wizard_for_field(context,
                                         kwargs['categories'],
                                         prompt_msg.CATEGORIES_MESSAGE,
                                         wizard_flag=use_wizard,
                                         required=True,
                                         type=CategoriesType())
    kwargs['categories'] = chosen_categories
    _require_option('categories')

    mutability_choices = click.Choice(MutabilityType.to_list())
    chosen_mutability = choice_wizard_for_field(context,
                                                kwargs['mutability'],
                                                prompt_msg.MUTABILITY_MESSAGE,
                                                mutability_choices,
                                                default=None,
                                                wizard_flag=use_wizard)
    kwargs['mutability'] = chosen_mutability
    _require_option('mutability')

    repositories[repo_type].create(kwargs)
示例#2
0
def checkout(context, **kwargs):
    """Check out an entity from the repository chosen by the parent command.

    Collects the sampling settings and checkout options from ``kwargs``,
    asking through the wizard helpers for anything that was not supplied,
    then delegates to the repository's ``checkout``.

    Args:
        context: Command context; ``context.parent.command.name`` selects
            the repository type.
        **kwargs: Command options. ``ml_entity_tag``, ``wizard``,
            ``version``, ``retry``, ``force``, ``bare``, ``fail_limit`` and
            ``full`` are required keys; ``sample_type`` (with ``sampling``
            and ``seed``), ``with_dataset`` and ``with_labels`` are optional.
    """
    repo_type = context.parent.command.name
    repo = repositories[repo_type]
    entity = kwargs['ml_entity_tag']
    use_wizard = kwargs['wizard']
    _verify_project_settings(use_wizard, context, repo_type, entity,
                             check_entity=False)

    # A sample description is only built when a sample type was given.
    sample = None
    if kwargs.get('sample_type') is not None:
        sample = {
            kwargs['sample_type']: kwargs['sampling'],
            'seed': kwargs['seed']
        }

    # The version is only asked for when the entity argument does not match
    # RGX_TAG_FORMAT; a version that is still None is passed on as -1.
    version = kwargs['version']
    if not re.search(RGX_TAG_FORMAT, entity):
        version_prompt = VERSION_TO_BE_DOWNLOADED.format(
            parse_entity_type_to_singular(repo_type))
        version = wizard_for_field(context,
                                   version,
                                   version_prompt,
                                   wizard_flag=use_wizard,
                                   type=click.IntRange(0, MAX_INT_VALUE))
    if version is None:
        version = -1

    # Related entities: models and labels may bring their dataset along,
    # and models may additionally bring their labels.
    with_dataset = kwargs.get('with_dataset', False)
    if not with_dataset and repo_type in (MODELS, LABELS):
        with_dataset = request_user_confirmation(
            prompt_msg.CHECKOUT_RELATED_ENTITY.format(
                parse_entity_type_to_singular(DATASETS)),
            wizard_flag=use_wizard)

    with_labels = kwargs.get('with_labels', False)
    if not with_labels and repo_type == MODELS:
        with_labels = request_user_confirmation(
            prompt_msg.CHECKOUT_RELATED_ENTITY.format(LABELS),
            wizard_flag=use_wizard)

    options = {
        'version': version,
        'with_dataset': with_dataset,
        'with_labels': with_labels,
        'retry': kwargs['retry'],
        'force': kwargs['force'],
        'bare': kwargs['bare'],
        'fail_limit': kwargs['fail_limit'],
        'full': kwargs['full'],
    }
    repo.checkout(entity, sample, options)
示例#3
0
def check_empty_values(ctx, param, value):
    """Reject a blank option value, or replace it through the wizard.

    Args:
        ctx: Context whose ``params`` mapping may carry a ``wizard`` entry.
        param: The option being validated; its ``name`` is used in the
            error text.
        value: The value supplied for the option, or ``None``.

    Returns:
        tuple: ``(value, False)`` when ``value`` is ``None`` or not blank;
        ``(new_value, True)`` when a blank value was replaced by the result
        of ``wizard_for_field``.

    Raises:
        click.BadParameter: If the value is blank and the wizard is enabled
            neither on the command nor through ``is_wizard_enabled()``.
    """
    # Nothing to check when the option is absent or has visible content.
    if value is None or str(value).strip() != '':
        return value, False

    wizard_on_command = ctx.params.get('wizard', False)
    if not (wizard_on_command or is_wizard_enabled()):
        raise click.BadParameter(output_messages['ERROR_EMPTY_VALUE'])

    option_flag = '--' + param.name
    error_message = output_messages['ERROR_INVALID_VALUE_FOR'] % (
        option_flag, output_messages['ERROR_EMPTY_VALUE'])
    new_value_prompt = '{}\n{}'.format(error_message, prompt_msg.NEW_VALUE)
    replacement = wizard_for_field(ctx,
                                   None,
                                   new_value_prompt,
                                   wizard_flag=wizard_on_command,
                                   default='')
    return replacement, True
示例#4
0
def add(context, **kwargs):
    """Add files to an entity of the repository chosen by the parent command.

    Args:
        context: Command context; ``context.parent.command.name`` selects
            the repository type.
        **kwargs: Command options. ``ml_entity_name``, ``bumpversion``,
            ``fsck`` and ``file_path`` are required keys; ``wizard``,
            ``metric`` and ``metrics_file`` are optional.
    """
    repo_type = context.parent.command.name
    entity_name = kwargs['ml_entity_name']
    use_wizard = kwargs.get('wizard', False)
    _verify_project_settings(use_wizard, context, repo_type, entity_name)

    bump_version = kwargs['bumpversion']
    run_fsck = kwargs['fsck']
    file_path = kwargs['file_path']

    metric = kwargs.get('metric')
    metrics_file_path = kwargs.get('metrics_file')
    # For models without an inline metric, the metrics file goes through
    # the wizard helper instead of being taken as-is.
    if repo_type == MODELS and not metric:
        metrics_file_path = wizard_for_field(context,
                                             metrics_file_path,
                                             prompt_msg.METRIC_FILE,
                                             wizard_flag=use_wizard)

    repo = repositories[repo_type]
    repo.add(entity_name, file_path, bump_version, run_fsck, metric,
             metrics_file_path)
示例#5
0
 def handle_parse_result(self, ctx, opts, args):
     """Enforce the dependency between this option and ``required_option``.

     When every option listed in ``self.required_option`` is present in
     ``opts`` but this option is not, the value is requested through
     ``wizard_for_field`` if the wizard is enabled; otherwise parsing
     fails. When this option is present without all of those options, a
     warning is logged. In every non-failing case the parent
     implementation handles the (possibly updated) ``opts``.

     Raises:
         MissingParameter: If this option is missing, the listed options
             are all present and ``is_wizard_enabled()`` is false.
     """
     has_own_value = self.name in opts
     has_all_dependents = all(
         dependent.replace('-', '_') in opts
         for dependent in self.required_option)
     cli_name = self.name.replace('_', '-')

     if has_all_dependents and not has_own_value:
         msg = output_messages['ERROR_REQUIRED_OPTION_MISSING'].format(
             cli_name, ', '.join(self.required_option), cli_name)
         if not is_wizard_enabled():
             raise MissingParameter(ctx=ctx, param=self, message=msg)
         # Wizard mode: ask for the missing value and continue parsing.
         opts[self.name] = wizard_for_field(ctx, None, msg, required=True)
     elif has_own_value and not has_all_dependents:
         log.warn(output_messages['WARN_USELESS_OPTION'].format(
             cli_name, ', '.join(self.required_option)))

     return super(OptionRequiredIf,
                  self).handle_parse_result(ctx, opts, args)
示例#6
0
def check_integer_value(ctx, param, value):
    """Convert an option value to ``int``, or replace it through the wizard.

    Args:
        ctx: Context whose ``params`` mapping may carry a ``wizard`` entry.
        param: The option being validated; its ``name`` is used in the
            error text.
        value: The value supplied for the option, or ``None``.

    Returns:
        tuple: ``(None, False)`` when ``value`` is ``None``;
        ``(int(value), False)`` when the conversion succeeds;
        ``(new_value, True)`` when an invalid value was replaced by the
        result of ``wizard_for_field``.

    Raises:
        click.BadParameter: If the conversion raises ``ValueError`` and the
            wizard is enabled neither on the command nor through
            ``is_wizard_enabled()``.
    """
    if value is None:
        return value, False

    try:
        converted = int(value)
    except ValueError:
        wizard_on_command = ctx.params.get('wizard', False)
        if not (wizard_on_command or is_wizard_enabled()):
            raise click.BadParameter(
                output_messages['ERROR_NOT_INTEGER_VALUE'].format(value))

        option_flag = '--' + param.name
        error_message = output_messages['ERROR_INVALID_VALUE_FOR'] % (
            option_flag,
            output_messages['ERROR_NOT_INTEGER_VALUE'].format(value))
        new_value_prompt = '{}\n{}'.format(error_message,
                                           prompt_msg.NEW_VALUE)
        replacement = wizard_for_field(ctx,
                                       None,
                                       new_value_prompt,
                                       wizard_flag=wizard_on_command,
                                       type=int,
                                       default='')
        return replacement, True
    else:
        return converted, False
示例#7
0
def commit(context, **kwargs):
    """Commit staged changes of an entity, optionally linking other entities.

    Resolves the version and the commit message through the wizard helper,
    gathers the related dataset/labels for model and labels repositories,
    and delegates to the repository's ``commit``.

    Args:
        context: Command context; ``context.parent.command.name`` selects
            the repository type and ``context.exit()`` is called when the
            repository reports nothing to commit.
        **kwargs: Command options. ``ml_entity_name``, ``fsck``, ``version``
            and ``message`` are required keys; ``wizard`` is optional. For
            model and labels repositories the linked-dataset key is read,
            and for model repositories the labels key as well.
    """
    use_wizard = kwargs.get('wizard', False)
    repo_type = context.parent.command.name
    entity_name = kwargs['ml_entity_name']
    _verify_project_settings(use_wizard, context, repo_type, entity_name)
    run_fsck = kwargs['fsck']

    repo = repositories[repo_type]
    if not repo.has_data_to_commit(entity_name):
        context.exit()

    # The last known version is both shown in the prompt and used as the
    # prompt's default.
    previous_version = get_last_entity_version(repo_type, entity_name)
    version = wizard_for_field(
        context,
        kwargs['version'],
        prompt_msg.COMMIT_VERSION.format(
            parse_entity_type_to_singular(repo_type), previous_version),
        wizard_flag=use_wizard,
        type=click.IntRange(0, MAX_INT_VALUE),
        default=previous_version)
    message = wizard_for_field(context,
                               kwargs['message'],
                               prompt_msg.COMMIT_MESSAGE,
                               wizard_flag=use_wizard)

    links = {}
    dataset_option = parse_entity_type_to_singular(DATASETS)

    # Models and labels can both be linked to a dataset; only the wording
    # of the confirmation question differs between the two.
    if repo_type in (MODELS, LABELS):
        dataset_link = kwargs[dataset_option]
        if dataset_link is not None:
            links[EntityType.DATASETS.value] = dataset_link
        else:
            if repo_type == MODELS:
                question = prompt_msg.WANT_LINK_TO_MODEL_ENTITY.format(
                    dataset_option, parse_entity_type_to_singular(MODELS))
            else:
                question = prompt_msg.WANT_LINK_TO_LABEL_ENTITY.format(
                    dataset_option)
            if request_user_confirmation(question, wizard_flag=use_wizard):
                links[EntityType.DATASETS.value] = wizard_for_field(
                    context,
                    dataset_link,
                    prompt_msg.DEFINE_LINKED_DATASET,
                    required=True,
                    wizard_flag=use_wizard)

    # Only models can additionally be linked to labels.
    if repo_type == MODELS:
        labels_link = kwargs[EntityType.LABELS.value]
        if labels_link is not None:
            links[EntityType.LABELS.value] = labels_link
        elif request_user_confirmation(
                prompt_msg.WANT_LINK_TO_MODEL_ENTITY.format(LABELS),
                wizard_flag=use_wizard):
            links[EntityType.LABELS.value] = wizard_for_field(
                context,
                labels_link,
                prompt_msg.DEFINE_LINKED_LABELS,
                required=True,
                wizard_flag=use_wizard)

    repo.commit(entity_name, links, version, run_fsck, message)
示例#8
0
def storage_add(context, **kwargs):
    """Register a storage, asking for type-specific settings as needed.

    The storage type is resolved first (defaulting to the S3H value in the
    choice prompt); the remaining settings depend on that type and are
    passed to ``admin.storage_add``.

    Args:
        context: Command context, passed to the project check and to the
            wizard helpers.
        **kwargs: Command options. ``wizard``, ``type``, ``bucket_name``
            and ``credentials`` are required keys; ``global`` is optional.
            ``endpoint_url`` and ``region`` are read for S3H, and
            ``username``, ``private_key``, ``port`` and ``endpoint_url``
            for SFTPH.
    """
    check_project_exists(context)
    use_wizard = kwargs['wizard']

    def ask(current, message, **extra):
        # Shared wrapper so every field is requested the same way.
        return wizard_for_field(context,
                                current,
                                message,
                                wizard_flag=use_wizard,
                                **extra)

    storage_type = choice_wizard_for_field(
        context,
        kwargs['type'],
        prompt_msg.STORAGE_TYPE_MESSAGE,
        click.Choice(MultihashStorageType.to_list()),
        default=StorageType.S3H.value,
        wizard_flag=use_wizard)
    kwargs['type'] = storage_type
    bucket = kwargs['bucket_name']
    is_global = kwargs.get('global', False)

    if storage_type == StorageType.S3H.value:
        credentials = ask(kwargs['credentials'], CREDENTIALS_PROFILE_MESSAGE)
        endpoint = ask(kwargs['endpoint_url'], ENDPOINT_MESSAGE)
        region = ask(kwargs['region'], REGION_MESSAGE)
        admin.storage_add(storage_type,
                          bucket,
                          credentials,
                          global_conf=is_global,
                          endpoint_url=endpoint,
                          region=region)
        return

    if storage_type == StorageType.GDRIVEH.value:
        credentials = ask(kwargs['credentials'], CREDENTIALS_PATH_MESSAGE)
        admin.storage_add(storage_type,
                          bucket,
                          credentials,
                          global_conf=is_global)
        return

    if storage_type == StorageType.SFTPH.value:
        sftp_configs = {
            'username': ask(kwargs['username'], USERNAME_SFTPH),
            'private_key': ask(kwargs['private_key'], PRIVATE_KEY_SFTPH),
            'port': ask(kwargs['port'], PORT_SFTPH, default=22, type=int),
        }
        # Credentials are forwarded as given here; only the endpoint goes
        # through the wizard helper.
        credentials = kwargs['credentials']
        endpoint = ask(kwargs['endpoint_url'], SFTPH_ENDPOINT_MESSAGE)
        admin.storage_add(storage_type,
                          bucket,
                          credentials,
                          global_conf=is_global,
                          endpoint_url=endpoint,
                          sftp_configs=sftp_configs)
        return

    # Any other storage type: no extra prompts.
    admin.storage_add(storage_type,
                      bucket,
                      kwargs['credentials'],
                      global_conf=is_global)