def check_lineage_type(self, data):
        """
        Check that every requested lineage type is a known ``LineageType``.

        Args:
            data (dict): Comparison dict for the lineage type condition; the
                "in" key maps to a list/tuple of types, other keys map to a
                single type string.

        Raises:
            ValidationError: If the operation or value types are invalid, or
                any requested lineage type is not one of ``LineageType``.
        """
        SearchModelConditionParameter.check_operation(data)
        SearchModelConditionParameter.check_dict_value_type(data, str)
        recv_types = []
        for key, value in data.items():
            if key == "in":
                # Extend rather than rebind: rebinding would alias (and later
                # mutate) the caller's list, drop values collected from keys
                # seen earlier, and fail on tuples, which have no append().
                recv_types.extend(value)
            else:
                recv_types.append(value)

        lineage_types = enum_to_list(LineageType)
        # Membership test instead of set(): "in" elements are not type-checked
        # above, and an unhashable element must produce a ValidationError,
        # not a TypeError.
        if any(recv_type not in lineage_types for recv_type in recv_types):
            raise ValidationError("Given lineage type should be one of %s." % lineage_types)
Beispiel #2
0
    def _get_lineage_types(self, lineage_type_param):
        """
        Get lineage types.

        Args:
            lineage_type_param (dict): A dict contains "in" or "eq", or None.

        Returns:
            list, lineage type. All ``LineageType`` values when the param is
            None or empty; a new list of the "in" values when "in" is given;
            otherwise a one-element list holding the "eq" value (None if the
            key is absent).

        """
        # lineage_type_param is None or an empty dict
        if not lineage_type_param:
            return enum_to_list(LineageType)

        in_types = lineage_type_param.get("in")
        if in_types is not None:
            # Return a fresh list: the "in" value may be a tuple, and callers
            # must not be able to mutate the request through the result.
            return list(in_types)

        return [lineage_type_param.get("eq")]
class SearchModelConditionParameter(Schema):
    """
    Define the search model condition parameter schema.

    Each searchable attribute is a dict mapping a comparison operator
    ("eq", "lt", "gt", "le", "ge", "in") to a value. The ``@validates`` hooks
    check value types per attribute, and the ``@pre_load`` hook
    ``check_comparision`` checks attribute names and operators.
    """
    summary_dir = fields.Dict()
    loss_function = fields.Dict()
    train_dataset_path = fields.Dict()
    train_dataset_count = fields.Dict()
    test_dataset_path = fields.Dict()
    test_dataset_count = fields.Dict()
    network = fields.Dict()
    optimizer = fields.Dict()
    learning_rate = fields.Dict()
    epoch = fields.Dict()
    batch_size = fields.Dict()
    loss = fields.Dict()
    model_size = fields.Dict()
    # Accepted range: 0 < limit <= 100.
    limit = fields.Int(validate=lambda n: 0 < n <= 100)
    # Accepted range: 0 <= offset <= 100000.
    offset = fields.Int(validate=lambda n: 0 <= n <= 100000)
    sorted_name = fields.Str()
    sorted_type = fields.Str(allow_none=True)
    lineage_type = fields.Str(validate=OneOf(enum_to_list(LineageType)),
                              allow_none=True)

    @staticmethod
    def check_dict_value_type(data, value_type):
        """
        Check dict value type and int scope.

        Args:
            data (dict): Maps a comparison operator to its value.
            value_type (type): Expected type of every non-"in" value. When it
                is ``int``, bools are rejected and values must lie in
                ``[0, pow(2, 63) - 1]``.

        Raises:
            ValidationError: If the "in" value is not a list or tuple, or any
                other value has the wrong type or is out of range.
        """
        for key, value in data.items():
            if key == "in":
                if not isinstance(value, (list, tuple)):
                    raise ValidationError(
                        "In operation's value must be list or tuple.")
                continue

            # bool is a subclass of int, so it must be rejected explicitly.
            if not isinstance(value, value_type) or \
                    (value_type is int and isinstance(value, bool)):
                raise ValidationError("Wrong value type.")

            if value_type is int:
                # Separate bound checks so each error names the bound that
                # was actually violated.
                if value < 0:
                    raise ValidationError("Int value should be >= 0.")
                if value > pow(2, 63) - 1:
                    raise ValidationError(
                        "Int value should <= pow(2, 63) - 1.")

    @staticmethod
    def check_param_value_type(data):
        """
        Check input param's value type.

        Args:
            data (dict): Maps a comparison operator to its value.

        Raises:
            ValidationError: If the "in" value is not a list or tuple, or any
                other value is not an int or float (bools are rejected).
        """
        for key, value in data.items():
            if key == "in":
                if not isinstance(value, (list, tuple)):
                    raise ValidationError(
                        "In operation's value must be list or tuple.")
            # bool is a subclass of int, so it must be rejected explicitly.
            elif isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValidationError("Wrong value type.")

    @validates("loss")
    def check_loss(self, data):
        """Check loss."""
        SearchModelConditionParameter.check_param_value_type(data)

    @validates("learning_rate")
    def check_learning_rate(self, data):
        """Check learning_rate."""
        SearchModelConditionParameter.check_param_value_type(data)

    @validates("loss_function")
    def check_loss_function(self, data):
        """Check loss_function: non-"in" values must be str."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("train_dataset_path")
    def check_train_dataset_path(self, data):
        """Check train_dataset_path: non-"in" values must be str."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("train_dataset_count")
    def check_train_dataset_count(self, data):
        """Check train_dataset_count: non-"in" values must be bounded ints."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("test_dataset_path")
    def check_test_dataset_path(self, data):
        """Check test_dataset_path: non-"in" values must be str."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("test_dataset_count")
    def check_test_dataset_count(self, data):
        """Check test_dataset_count: non-"in" values must be bounded ints."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("network")
    def check_network(self, data):
        """Check network: non-"in" values must be str."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("optimizer")
    def check_optimizer(self, data):
        """Check optimizer: non-"in" values must be str."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("epoch")
    def check_epoch(self, data):
        """Check epoch: non-"in" values must be bounded ints."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("batch_size")
    def check_batch_size(self, data):
        """Check batch_size: non-"in" values must be bounded ints."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("model_size")
    def check_model_size(self, data):
        """Check model_size: non-"in" values must be bounded ints."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("summary_dir")
    def check_summary_dir(self, data):
        """Check summary_dir: non-"in" values must be str."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @pre_load
    def check_comparision(self, data, **kwargs):
        """
        Check comparision for all parameters in schema.

        Args:
            data (dict): Raw search condition mapping an attribute name to a
                comparison dict.
            **kwargs: Extra keyword arguments supplied to the load hook;
                unused.

        Returns:
            dict, ``data`` unchanged.

        Raises:
            LineageParamValueError: If an attribute or comparison operator is
                not supported.
            LineageParamTypeError: If an attribute's condition is not a dict.
            MindInsightException: If a ``metric_*`` condition fails
                ``check_param_value_type``.
        """
        # Scalar parameters validated by their own field definitions.
        plain_params = ("limit", "offset", "sorted_name", "sorted_type",
                        "lineage_type")
        compare_operators = ("eq", "lt", "gt", "le", "ge", "in")

        for attr, condition in data.items():
            if attr in plain_params:
                continue

            if not isinstance(attr, str):
                raise LineageParamValueError(
                    'The search attribute not supported.')

            if attr not in FIELD_MAPPING and not attr.startswith('metric_'):
                raise LineageParamValueError(
                    'The search attribute not supported.')

            if not isinstance(condition, dict):
                raise LineageParamTypeError(
                    "The search_condition element {} should be dict.".format(
                        attr))

            for key in condition.keys():
                if key not in compare_operators:
                    raise LineageParamValueError(
                        "The compare condition should be in "
                        "('eq', 'lt', 'gt', 'le', 'ge', 'in').")

            if attr.startswith('metric_'):
                # A bare "metric_" prefix does not name any metric.
                if attr == 'metric_':
                    raise LineageParamValueError(
                        'The search attribute not supported.')
                try:
                    SearchModelConditionParameter.check_param_value_type(
                        condition)
                except ValidationError as err:
                    raise MindInsightException(
                        error=LineageErrors.LINEAGE_PARAM_METRIC_ERROR,
                        message=LineageErrorMsg.LINEAGE_METRIC_ERROR.value.
                        format(attr)) from err
        return data