def check_lineage_type(self, data):
    """Check lineage type.

    Args:
        data (dict): Condition dict for the lineage type. The allowed
            operators are enforced by
            ``SearchModelConditionParameter.check_operation`` (presumably
            "in"/"eq" -- verify against that helper).

    Raises:
        ValidationError: If the condition is malformed, or if any requested
            lineage type is not a member of ``LineageType``.
    """
    SearchModelConditionParameter.check_operation(data)
    SearchModelConditionParameter.check_dict_value_type(data, str)
    recv_types = []
    for key, value in data.items():
        if key == "in":
            # Copy the members instead of rebinding to the caller's list:
            # rebinding would drop an earlier "eq" value, and a later
            # append would mutate the request data (or fail on a tuple).
            recv_types.extend(value)
        else:
            recv_types.append(value)
    lineage_types = enum_to_list(LineageType)
    # Element-wise membership instead of set(): "in" members may be
    # unhashable, which would raise TypeError rather than ValidationError.
    if any(recv_type not in lineage_types for recv_type in recv_types):
        raise ValidationError(
            "Given lineage type should be one of %s." % lineage_types)
def _get_lineage_types(self, lineage_type_param):
    """
    Get lineage types.

    Args:
        lineage_type_param (dict): A dict contains "in" or "eq".

    Returns:
        list, lineage type. All members of ``LineageType`` when
        ``lineage_type_param`` is None or empty; a new list built from the
        "in" operand when that key is present and not None; otherwise a
        one-element list holding the "eq" operand.
    """
    # lineage_type_param is None or an empty dict
    if not lineage_type_param:
        return enum_to_list(LineageType)
    in_types = lineage_type_param.get("in")
    if in_types is not None:
        # Always hand back a fresh list: the operand may be a tuple, and
        # returning it directly would let callers mutate the request dict.
        return list(in_types)
    return [lineage_type_param.get("eq")]
class SearchModelConditionParameter(Schema):
    """Define the search model condition parameter schema.

    Each ``fields.Dict`` attribute below (and any ``metric_*`` key) holds a
    condition dict whose keys are comparison operators: "eq", "lt", "gt",
    "le", "ge" or "in". ``limit``, ``offset``, ``sorted_name``,
    ``sorted_type`` and ``lineage_type`` are scalar options that
    ``check_comparision`` skips.
    """
    summary_dir = fields.Dict()
    loss_function = fields.Dict()
    train_dataset_path = fields.Dict()
    train_dataset_count = fields.Dict()
    test_dataset_path = fields.Dict()
    test_dataset_count = fields.Dict()
    network = fields.Dict()
    optimizer = fields.Dict()
    learning_rate = fields.Dict()
    epoch = fields.Dict()
    batch_size = fields.Dict()
    loss = fields.Dict()
    model_size = fields.Dict()
    limit = fields.Int(validate=lambda n: 0 < n <= 100)
    offset = fields.Int(validate=lambda n: 0 <= n <= 100000)
    sorted_name = fields.Str()
    sorted_type = fields.Str(allow_none=True)
    lineage_type = fields.Str(
        validate=OneOf(enum_to_list(LineageType)), allow_none=True)

    @staticmethod
    def _check_scalar_type(value, value_type):
        """Validate one operand against ``value_type`` (and int scope)."""
        if not isinstance(value, value_type):
            raise ValidationError("Wrong value type.")
        if value_type is int:
            # bool is a subclass of int, so it must be rejected explicitly.
            if isinstance(value, bool):
                raise ValidationError("Wrong value type.")
            # NOTE(review): this message is also raised for negative values.
            if value < 0 or value > pow(2, 63) - 1:
                raise ValidationError(
                    "Int value should <= pow(2, 63) - 1.")

    @staticmethod
    def _check_number(value):
        """Validate that one operand is an int or float but not a bool."""
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValidationError("Wrong value type.")

    @staticmethod
    def check_dict_value_type(data, value_type):
        """Check dict value type and int scope.

        Args:
            data (dict): Condition dict mapping an operator to its operand.
            value_type (type): Required type of every scalar operand.

        Raises:
            ValidationError: If the "in" operand is not a list or tuple, or
                if any scalar operand (including each member of an "in"
                operand) has the wrong type or, for ints, is out of range.
        """
        for key, value in data.items():
            if key == "in":
                if not isinstance(value, (list, tuple)):
                    raise ValidationError(
                        "In operation's value must be list or tuple.")
                # "in" members stand in for scalar operands of the same
                # field, so they must pass the same checks.
                for item in value:
                    SearchModelConditionParameter._check_scalar_type(
                        item, value_type)
            else:
                SearchModelConditionParameter._check_scalar_type(
                    value, value_type)

    @staticmethod
    def check_param_value_type(data):
        """Check input param's value type.

        Every scalar operand, and every member of an "in" operand, must be
        an int or float (bool is rejected).

        Raises:
            ValidationError: If the "in" operand is not a list or tuple, or
                if any operand is not numeric.
        """
        for key, value in data.items():
            if key == "in":
                if not isinstance(value, (list, tuple)):
                    raise ValidationError(
                        "In operation's value must be list or tuple.")
                for item in value:
                    SearchModelConditionParameter._check_number(item)
            else:
                SearchModelConditionParameter._check_number(value)

    @validates("loss")
    def check_loss(self, data):
        """Check loss."""
        SearchModelConditionParameter.check_param_value_type(data)

    @validates("learning_rate")
    def check_learning_rate(self, data):
        """Check learning_rate."""
        SearchModelConditionParameter.check_param_value_type(data)

    @validates("loss_function")
    def check_loss_function(self, data):
        """Check loss_function (string operands)."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("train_dataset_path")
    def check_train_dataset_path(self, data):
        """Check train_dataset_path (string operands)."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("train_dataset_count")
    def check_train_dataset_count(self, data):
        """Check train_dataset_count (int operands)."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("test_dataset_path")
    def check_test_dataset_path(self, data):
        """Check test_dataset_path (string operands)."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("test_dataset_count")
    def check_test_dataset_count(self, data):
        """Check test_dataset_count (int operands)."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("network")
    def check_network(self, data):
        """Check network (string operands)."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("optimizer")
    def check_optimizer(self, data):
        """Check optimizer (string operands)."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("epoch")
    def check_epoch(self, data):
        """Check epoch (int operands)."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("batch_size")
    def check_batch_size(self, data):
        """Check batch_size (int operands)."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("model_size")
    def check_model_size(self, data):
        """Check model_size (int operands)."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("summary_dir")
    def check_summary_dir(self, data):
        """Check summary_dir (string operands)."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @pre_load
    def check_comparision(self, data, **kwargs):
        """Check comparision for all parameters in schema.

        Runs before field loading. The scalar options are skipped. Every
        other key must be a string that is either in ``FIELD_MAPPING`` or
        starts with "metric_". Its value must be a dict that uses only the
        supported comparison operators.

        Raises:
            LineageParamValueError: Unsupported attribute or operator.
            LineageParamTypeError: A condition is not a dict.
            MindInsightException: A metric condition has a non-numeric
                operand.
        """
        scalar_options = (
            "limit", "offset", "sorted_name", "sorted_type", "lineage_type"
        )
        operators = ("eq", "lt", "gt", "le", "ge", "in")
        for attr, condition in data.items():
            if attr in scalar_options:
                continue
            if not isinstance(attr, str):
                raise LineageParamValueError(
                    'The search attribute not supported.')
            if attr not in FIELD_MAPPING and not attr.startswith('metric_'):
                raise LineageParamValueError(
                    'The search attribute not supported.')
            if not isinstance(condition, dict):
                raise LineageParamTypeError(
                    "The search_condition element {} should be dict.".format(
                        attr))
            for key in condition:
                if key not in operators:
                    raise LineageParamValueError(
                        "The compare condition should be in "
                        "('eq', 'lt', 'gt', 'le', 'ge', 'in').")
            if attr.startswith('metric_'):
                # A bare "metric_" prefix names no metric.
                if len(attr) == 7:
                    raise LineageParamValueError(
                        'The search attribute not supported.')
                try:
                    SearchModelConditionParameter.check_param_value_type(
                        condition)
                except ValidationError as err:
                    raise MindInsightException(
                        error=LineageErrors.LINEAGE_PARAM_METRIC_ERROR,
                        message=LineageErrorMsg.LINEAGE_METRIC_ERROR.value.
                        format(attr)) from err
        return data