예제 #1
0
def update(json_dict: Optional[Dict[Text, Any]] = None) -> Dict[Text, Any]:
    """Updates a Model Card JSON dictionary to the latest schema version.

    If you have a JSON string, you can use it with this function as follows:

    ```python
    json_dict = json.loads(json_text)
    updated_json_dict = json_util.update(json_dict)
    ```

    Args:
      json_dict: A dictionary representing a Model Card JSON object.

    Returns:
      The Model Card as a dictionary. If the input already validates against
      schema v0.0.2, the very same object is returned unmodified; otherwise
      the result of `_update_from_v1_to_v2(json_dict)` is returned.

    Raises:
      ValidationError: If `json_dict` does not validate against v0.0.2 and
        the v0.0.1 validation performed by `_update_from_v1_to_v2` fails too.
    """
    # Keep the try body down to the single call whose failure we interpret
    # as "not yet v0.0.2".
    try:
        validation.validate_json_schema(json_dict, "0.0.2")
    except jsonschema.ValidationError:
        # Reaching this branch means the object did NOT validate as v0.0.2.
        logging.info("JSON object does not match schema 0.0.2; updating.")
        return _update_from_v1_to_v2(json_dict)
    logging.info("JSON object already matches schema 0.0.2.")
    return json_dict  # pytype: disable=bad-return-type
예제 #2
0
    def _from_json(self, json_dict: Dict[Text, Any]) -> "ModelCard":
        """Populate this ModelCard from a JSON dictionary.

        The dictionary is validated first, then the card is cleared and every
        entry of `json_dict` is copied onto the matching attribute, so any
        previously set values are overwritten.

        WARNING: This method's interface may change in the future, do not use
        for critical workflows.

        Args:
          json_dict: A JSON dict from which to populate fields in the model
            card schema.

        Returns:
          self

        Raises:
          ValidationError: If `json_dict` does not follow the model card JSON
            schema.
          ValueError: If `json_dict` contains a key that is not an attribute
            of the corresponding field object.
        """
        def _fill(source: Dict[Text, Any],
                  target: BaseModelCardField) -> BaseModelCardField:
            # Recursively copy `source` entries onto attributes of `target`.
            for key, raw in source.items():
                # Schema-version entries are metadata, not card fields.
                if key.startswith(_SCHEMA_VERSION_STRING):
                    continue
                if not hasattr(target, key):
                    raise ValueError(
                        "BaseModelCardField %s has no such field named '%s.'" %
                        (target, key))

                if isinstance(raw, dict):
                    # Nested object: fill the sub-field already on `target`.
                    parsed = _fill(raw, getattr(target, key))
                elif isinstance(raw, list):
                    # Dict items become fresh instances of the list's element
                    # type (first type argument of the attribute annotation);
                    # primitive items are kept as they are.
                    parsed = [
                        _fill(
                            item,
                            target.__annotations__[key].__args__[0]())  # pytype: disable=attribute-error
                        if isinstance(item, dict) else item
                        for item in raw
                    ]
                else:
                    parsed = raw
                setattr(target, key, parsed)
            return target

        validation.validate_json_schema(json_dict)
        self.clear()
        _fill(json_dict, self)
        return self
예제 #3
0
    def update_model_card_json(self, model_card: ModelCard) -> None:
        """Validates `model_card` and writes its JSON form to the MCT assets.

        When `model_card.schema_version` is empty, the latest schema version
        reported by `validation.get_latest_schema_version()` is assigned to
        the card before it is validated.

        Args:
          model_card: The updated model card that users want to write back.

        Raises:
          Error: when the given model_card is invalid w.r.t. the schema.
        """
        version_missing = not model_card.schema_version
        if version_missing:
            # Stamp the card so validation and the written JSON agree.
            model_card.schema_version = validation.get_latest_schema_version()

        card_as_dict = model_card.to_dict()
        validation.validate_json_schema(card_as_dict, model_card.schema_version)

        # Only reached when validation did not raise.
        destination = self._mcta_json_file
        self._write_file(destination, model_card.to_json())
예제 #4
0
 def test_template_test_files(self, file_name):
     # Loads the packaged template test file `file_name` and checks that it
     # validates against the model card JSON schema.
     #
     # pkgutil.get_data() expects a '/'-separated resource name on every
     # platform, so the name is built with '/' rather than os.path.join
     # (which would produce '\\' separators on Windows).
     resource_name = "/".join(("template", "test", file_name))
     raw_json = pkgutil.get_data("model_card_toolkit", resource_name)
     validation.validate_json_schema(json.loads(raw_json))
예제 #5
0
 def test_validate_json_schema_invalid_version(self):
     # Requesting schema version "0.0.3" is expected to raise ValueError.
     unsupported_version = "0.0.3"
     self.assertRaises(
         ValueError,
         validation.validate_json_schema,
         _MODEL_CARD_V1_DICT,
         schema_version=unsupported_version)
예제 #6
0
 def test_validate_json_schema_invalid_dict(self):
     # A dict holding only a `model_name` entry is expected to be rejected
     # with jsonschema.ValidationError.
     malformed_card = dict(model_name="the_greatest_model")
     self.assertRaises(
         jsonschema.ValidationError,
         validation.validate_json_schema,
         malformed_card)
예제 #7
0
 def test_validate_json_schema(self):
     # Each fixture must validate against its own schema version; any
     # exception raised by the validator fails the test.
     fixtures = (
         (_MODEL_CARD_V1_DICT, "0.0.1"),
         (_MODEL_CARD_V2_DICT, "0.0.2"),
     )
     for card_dict, version in fixtures:
         validation.validate_json_schema(card_dict, schema_version=version)
예제 #8
0
def _update_from_v1_to_v2(json_dict: Dict[Text, Any]) -> Dict[Text, Any]:
    """Updates a Model Card v0.0.1 JSON dictionary to the v0.0.2 layout.

    The dictionary is rewritten in place and the same object is returned.

    Args:
      json_dict: A dictionary representing a Model Card v0.0.1 JSON object.

    Returns:
      `json_dict` itself, with `schema_version` set to
      `validation.get_latest_schema_version()` and the model_details,
      model_parameters and considerations sections restructured.

    Raises:
      ValidationError: If `json_dict` does not follow the model card JSON
        v0.0.1 schema.
    """

    # Validate input args schema
    validation.validate_json_schema(json_dict, "0.0.1")

    # Update schema version
    json_dict["schema_version"] = validation.get_latest_schema_version()

    # Update model_details. Every key is looked up with .get() so that a card
    # lacking `model_details`, or any of these fields, is skipped instead of
    # raising KeyError; empty values are left untouched.
    model_details = json_dict.get("model_details") or {}
    if model_details.get("license"):
        model_details["licenses"] = [{
            "custom_text": model_details.pop("license")
        }]
    if model_details.get("references"):
        model_details["references"] = [{
            "reference": reference
        } for reference in model_details["references"]]
    if model_details.get("citation"):
        model_details["citations"] = [{
            "citation": model_details.pop("citation")
        }]

    # Update model_parameters: the {"train": ..., "eval": ...} mapping becomes
    # a list of datasets, each given a default name when it has none.
    model_parameters = json_dict.get("model_parameters", {})
    if "data" in model_parameters:
        old_data = model_parameters["data"]
        new_data = []
        for split, default_name in (("train", "Training Set"),
                                    ("eval", "Validation Set")):
            if split in old_data:
                dataset = old_data[split]
                dataset.setdefault("name", default_name)
                new_data.append(dataset)
        model_parameters["data"] = new_data

    # Update considerations: every entry of these lists is wrapped in a
    # {"description": ...} object.
    considerations = json_dict.get("considerations", {})
    for section in ("use_cases", "users", "limitations", "tradeoffs"):
        if section in considerations:
            considerations[section] = [{
                "description": entry
            } for entry in considerations[section]]

    return json_dict
예제 #9
0
 def test_validate_json_schema(self):
     # The module-level fixture dict must validate against the module-level
     # schema version; any exception from the validator fails the test.
     card_dict, version = _MODEL_CARD_DICT, _SCHEMA_VERSION
     validation.validate_json_schema(card_dict, version)