def create(context, **kwargs):
    """Create an entity after making sure categories and mutability are set.

    The repository type is taken from the name of the parent command. Both
    ``categories`` and ``mutability`` are passed through the wizard helpers,
    and whatever they return is written back into ``kwargs`` before the
    whole ``kwargs`` dict is handed to the repository's ``create``.

    Raises:
        UsageError: if ``categories`` or ``mutability`` is still falsy after
            its wizard helper has returned.
    """
    repo_type = context.parent.command.name
    entity_name = kwargs['artifact_name']
    wizard_flag = kwargs['wizard']

    check_entity_name(entity_name)
    _verify_project_settings(wizard_flag, context, repo_type, entity_name,
                             check_entity=False)

    categories = wizard_for_field(
        context,
        kwargs['categories'],
        prompt_msg.CATEGORIES_MESSAGE,
        wizard_flag=wizard_flag,
        required=True,
        type=CategoriesType(),
    )
    kwargs['categories'] = categories
    if not categories:
        raise UsageError(
            output_messages['ERROR_MISSING_OPTION'].format('categories'))

    mutability = choice_wizard_for_field(
        context,
        kwargs['mutability'],
        prompt_msg.MUTABILITY_MESSAGE,
        click.Choice(MutabilityType.to_list()),
        default=None,
        wizard_flag=wizard_flag,
    )
    kwargs['mutability'] = mutability
    if not mutability:
        raise UsageError(
            output_messages['ERROR_MISSING_OPTION'].format('mutability'))

    repositories[repo_type].create(kwargs)
def checkout(context, **kwargs):
    """Collect the checkout options and delegate to the repository checkout.

    The repository is selected by the name of the parent command. A sampling
    spec is built only when ``sample_type`` is present and not ``None``. When
    the entity argument does not match ``RGX_TAG_FORMAT`` the version goes
    through the wizard helper; a version that is still ``None`` becomes -1.
    For models and labels the user may be asked whether the related dataset
    should be checked out, and for models whether the related labels should.
    """
    repo_type = context.parent.command.name
    repo = repositories[repo_type]
    entity = kwargs['ml_entity_tag']
    wizard_flag = kwargs['wizard']
    _verify_project_settings(wizard_flag, context, repo_type, entity,
                             check_entity=False)

    sample = None
    sample_type = kwargs.get('sample_type')
    if sample_type is not None:
        sample = {sample_type: kwargs['sampling'], 'seed': kwargs['seed']}

    version = kwargs['version']
    if re.search(RGX_TAG_FORMAT, entity) is None:
        version_prompt = VERSION_TO_BE_DOWNLOADED.format(
            parse_entity_type_to_singular(repo_type))
        version = wizard_for_field(context, version, version_prompt,
                                   wizard_flag=wizard_flag,
                                   type=click.IntRange(0, MAX_INT_VALUE))
    # NOTE(review): the flattened source does not show whether this fallback
    # was nested under the branch above; it is applied unconditionally here.
    # Confirm against the original layout.
    if version is None:
        version = -1

    with_dataset = kwargs.get('with_dataset', False)
    if not with_dataset and repo_type in (MODELS, LABELS):
        with_dataset = request_user_confirmation(
            prompt_msg.CHECKOUT_RELATED_ENTITY.format(
                parse_entity_type_to_singular(DATASETS)),
            wizard_flag=wizard_flag)

    with_labels = kwargs.get('with_labels', False)
    if not with_labels and repo_type == MODELS:
        with_labels = request_user_confirmation(
            prompt_msg.CHECKOUT_RELATED_ENTITY.format(LABELS),
            wizard_flag=wizard_flag)

    options = {
        'version': version,
        'with_dataset': with_dataset,
        'with_labels': with_labels,
        'retry': kwargs['retry'],
        'force': kwargs['force'],
        'bare': kwargs['bare'],
        'fail_limit': kwargs['fail_limit'],
        'full': kwargs['full'],
    }
    repo.checkout(entity, sample, options)
def check_empty_values(ctx, param, value):
    """Reject blank option values, or ask for a replacement via the wizard.

    Returns:
        A ``(value, replaced)`` tuple. ``replaced`` is ``False`` when the
        incoming value is ``None`` or non-blank and is returned untouched,
        and ``True`` when the value came from the wizard helper.

    Raises:
        click.BadParameter: if the value is blank and the wizard is enabled
            neither through ``ctx.params['wizard']`` nor through
            ``is_wizard_enabled()``.
    """
    # A missing value is not an "empty" one; only blank strings are handled.
    if value is None or str(value).strip() != '':
        return value, False

    wizard_locally_on = ctx.params.get('wizard', False)
    if not (wizard_locally_on or is_wizard_enabled()):
        raise click.BadParameter(output_messages['ERROR_EMPTY_VALUE'])

    option_label = '--' + param.name
    error_message = output_messages['ERROR_INVALID_VALUE_FOR'] % (
        option_label, output_messages['ERROR_EMPTY_VALUE'])
    prompt_text = '{}\n{}'.format(error_message, prompt_msg.NEW_VALUE)
    replacement = wizard_for_field(ctx, None, prompt_text,
                                   wizard_flag=wizard_locally_on, default='')
    return replacement, True
def add(context, **kwargs):
    """Add files to an entity, asking for a metrics file when appropriate.

    The repository type comes from the name of the parent command. When no
    ``metric`` was supplied and the repository type is ``MODELS``, the
    ``metrics_file`` value is passed through the wizard helper before the
    call is delegated to the repository's ``add``.
    """
    repo_type = context.parent.command.name
    entity_name = kwargs['ml_entity_name']
    # The wizard option may be absent from kwargs; treat that as disabled.
    wizard_flag = kwargs.get('wizard', False)
    _verify_project_settings(wizard_flag, context, repo_type, entity_name)

    bump_version = kwargs['bumpversion']
    run_fsck = kwargs['fsck']
    file_path = kwargs['file_path']
    metric = kwargs.get('metric')
    metrics_file_path = kwargs.get('metrics_file')

    if repo_type == MODELS and not metric:
        metrics_file_path = wizard_for_field(context, metrics_file_path,
                                             prompt_msg.METRIC_FILE,
                                             wizard_flag=wizard_flag)

    repositories[repo_type].add(entity_name, file_path, bump_version,
                                run_fsck, metric, metrics_file_path)
def handle_parse_result(self, ctx, opts, args):
    """Enforce that this option accompanies the options it depends on.

    If every option listed in ``self.required_option`` is present in
    ``opts`` but this option is not, the value is requested through the
    wizard helper when the wizard is enabled; otherwise an error is raised.
    If this option is present without all of the listed options, a warning
    is logged. Parsing then continues in the parent implementation.

    Raises:
        MissingParameter: if this option is missing, all listed options are
            present, and ``is_wizard_enabled()`` is false.
    """
    has_this_option = self.name in opts
    # Listed names use dashes, while keys in ``opts`` use underscores.
    has_all_dependents = all(
        dependent.replace('-', '_') in opts
        for dependent in self.required_option)
    option_name = self.name.replace('_', '-')

    if has_all_dependents and not has_this_option:
        msg = output_messages['ERROR_REQUIRED_OPTION_MISSING'].format(
            option_name, ', '.join(self.required_option), option_name)
        if not is_wizard_enabled():
            raise MissingParameter(ctx=ctx, param=self, message=msg)
        opts[self.name] = wizard_for_field(ctx, None, msg, required=True)
    elif has_this_option and not has_all_dependents:
        log.warn(output_messages['WARN_USELESS_OPTION'].format(
            option_name, ', '.join(self.required_option)))

    return super(OptionRequiredIf, self).handle_parse_result(ctx, opts, args)
def check_integer_value(ctx, param, value):
    """Convert an option value to ``int``, or ask for a replacement.

    Returns:
        A ``(value, replaced)`` tuple. ``replaced`` is ``False`` when the
        value is ``None`` (returned as is) or converts cleanly with
        ``int()``, and ``True`` when the value came from the wizard helper.

    Raises:
        click.BadParameter: if ``int(value)`` raises ``ValueError`` and the
            wizard is enabled neither through ``ctx.params['wizard']`` nor
            through ``is_wizard_enabled()``.
    """
    if value is None:
        return value, False

    try:
        converted = int(value)
    except ValueError:
        wizard_locally_on = ctx.params.get('wizard', False)
        if not (wizard_locally_on or is_wizard_enabled()):
            raise click.BadParameter(
                output_messages['ERROR_NOT_INTEGER_VALUE'].format(value))

        option_label = '--' + param.name
        error_message = output_messages['ERROR_INVALID_VALUE_FOR'] % (
            option_label,
            output_messages['ERROR_NOT_INTEGER_VALUE'].format(value))
        prompt_text = '{}\n{}'.format(error_message, prompt_msg.NEW_VALUE)
        replacement = wizard_for_field(ctx, None, prompt_text,
                                       wizard_flag=wizard_locally_on,
                                       type=int, default='')
        return replacement, True

    return converted, False
def commit(context, **kwargs):
    """Commit an entity, prompting for version, message and linked entities.

    The repository type comes from the name of the parent command. If the
    repository reports no data to commit, ``context.exit()`` is called.
    Otherwise the version and message are passed through the wizard helper,
    with the last entity version as the default version. For models, a
    dataset link and a labels link are each taken from ``kwargs`` when
    given, or requested after the user confirms. For labels, only the
    dataset link is handled. The result is passed to the repository's
    ``commit``.
    """
    # The wizard option may be absent from kwargs; treat that as disabled.
    wizard_flag = kwargs.get('wizard', False)
    repo_type = context.parent.command.name
    entity_name = kwargs['ml_entity_name']
    _verify_project_settings(wizard_flag, context, repo_type, entity_name)
    run_fsck = kwargs['fsck']

    if not repositories[repo_type].has_data_to_commit(entity_name):
        context.exit()

    last_version = get_last_entity_version(repo_type, entity_name)
    requested_version = kwargs['version']
    version_prompt = prompt_msg.COMMIT_VERSION.format(
        parse_entity_type_to_singular(repo_type), last_version)
    version = wizard_for_field(context, requested_version, version_prompt,
                               wizard_flag=wizard_flag,
                               type=click.IntRange(0, MAX_INT_VALUE),
                               default=last_version)
    msg = wizard_for_field(context, kwargs['message'],
                           prompt_msg.COMMIT_MESSAGE,
                           wizard_flag=wizard_flag)

    related_entities = {}
    dataset_option = parse_entity_type_to_singular(DATASETS)

    if repo_type == MODELS:
        # Dataset link: an explicit value wins; otherwise ask before prompting.
        dataset_value = kwargs[dataset_option]
        if dataset_value is not None:
            related_entities[EntityType.DATASETS.value] = dataset_value
        elif request_user_confirmation(
                prompt_msg.WANT_LINK_TO_MODEL_ENTITY.format(
                    dataset_option, parse_entity_type_to_singular(MODELS)),
                wizard_flag=wizard_flag):
            related_entities[EntityType.DATASETS.value] = wizard_for_field(
                context, dataset_value, prompt_msg.DEFINE_LINKED_DATASET,
                required=True, wizard_flag=wizard_flag)

        # Labels link: same policy as for the dataset link.
        labels_value = kwargs[EntityType.LABELS.value]
        if labels_value is not None:
            related_entities[EntityType.LABELS.value] = labels_value
        elif request_user_confirmation(
                prompt_msg.WANT_LINK_TO_MODEL_ENTITY.format(LABELS),
                wizard_flag=wizard_flag):
            related_entities[EntityType.LABELS.value] = wizard_for_field(
                context, labels_value, prompt_msg.DEFINE_LINKED_LABELS,
                required=True, wizard_flag=wizard_flag)
    elif repo_type == LABELS:
        dataset_value = kwargs[dataset_option]
        if dataset_value is not None:
            related_entities[EntityType.DATASETS.value] = dataset_value
        elif request_user_confirmation(
                prompt_msg.WANT_LINK_TO_LABEL_ENTITY.format(dataset_option),
                wizard_flag=wizard_flag):
            related_entities[EntityType.DATASETS.value] = wizard_for_field(
                context, dataset_value, prompt_msg.DEFINE_LINKED_DATASET,
                required=True, wizard_flag=wizard_flag)

    repositories[repo_type].commit(entity_name, related_entities, version,
                                   run_fsck, msg)
def storage_add(context, **kwargs):
    """Register a storage, prompting for the fields its type requires.

    The storage type goes through the choice wizard helper (default: the
    S3H type) and is written back into ``kwargs['type']``. Depending on the
    resulting type, credentials, endpoint URL, region, or the SFTP username,
    private key and port are passed through the wizard helper before
    ``admin.storage_add`` is called. Any other type is forwarded with the
    values from ``kwargs`` unchanged.
    """
    check_project_exists(context)
    wizard_flag = kwargs['wizard']

    def _ask(current_value, message, **extra):
        # Shared plumbing: every prompt uses the same context and wizard flag.
        return wizard_for_field(context, current_value, message,
                                wizard_flag=wizard_flag, **extra)

    storage_type = choice_wizard_for_field(
        context, kwargs['type'], prompt_msg.STORAGE_TYPE_MESSAGE,
        click.Choice(MultihashStorageType.to_list()),
        default=StorageType.S3H.value, wizard_flag=wizard_flag)
    kwargs['type'] = storage_type

    if storage_type == StorageType.S3H.value:
        bucket_name = kwargs['bucket_name']
        credentials = _ask(kwargs['credentials'], CREDENTIALS_PROFILE_MESSAGE)
        global_conf = kwargs.get('global', False)
        endpoint_url = _ask(kwargs['endpoint_url'], ENDPOINT_MESSAGE)
        region = _ask(kwargs['region'], REGION_MESSAGE)
        admin.storage_add(storage_type, bucket_name, credentials,
                          global_conf=global_conf,
                          endpoint_url=endpoint_url,
                          region=region)
    elif storage_type == StorageType.GDRIVEH.value:
        bucket_name = kwargs['bucket_name']
        credentials = _ask(kwargs['credentials'], CREDENTIALS_PATH_MESSAGE)
        admin.storage_add(storage_type, bucket_name, credentials,
                          global_conf=kwargs.get('global', False))
    elif storage_type == StorageType.SFTPH.value:
        # SFTP-specific prompts come first, then the endpoint prompt.
        sftp_configs = {
            'username': _ask(kwargs['username'], USERNAME_SFTPH),
            'private_key': _ask(kwargs['private_key'], PRIVATE_KEY_SFTPH),
            'port': _ask(kwargs['port'], PORT_SFTPH, default=22, type=int),
        }
        bucket_name = kwargs['bucket_name']
        credentials = kwargs['credentials']
        global_conf = kwargs.get('global', False)
        endpoint_url = _ask(kwargs['endpoint_url'], SFTPH_ENDPOINT_MESSAGE)
        admin.storage_add(storage_type, bucket_name, credentials,
                          global_conf=global_conf,
                          endpoint_url=endpoint_url,
                          sftp_configs=sftp_configs)
    else:
        admin.storage_add(storage_type, kwargs['bucket_name'],
                          kwargs['credentials'],
                          global_conf=kwargs.get('global', False))