示例#1
0
    def update_model_card_json(
            self, model_card: model_card_module.ModelCard) -> None:
        """Checks `model_card` against its schema, then saves it as JSON.

        A model card without a `schema_version` is first stamped with the
        highest semantic version found among the sub-directories of
        `_SCHEMA_DIR`. The card is validated before anything is written, so
        an invalid card never reaches the JSON file.

        Args:
          model_card: The updated model card to write back to the MCT
            assets. Its `schema_version` is set in place when it was empty.

        Raises:
          jsonschema.ValidationError: if the model card does not conform to
            the schema for its version.
        """
        if not model_card.schema_version:

            def _version_of(entry):
                # `name[1:]` drops the first character of the directory name
                # (presumably a 'v' prefix -- confirm against the schema
                # directory layout) so the rest parses as a semantic version.
                return semantic_version.Version(entry.name[1:])

            schema_dirs = [
                entry for entry in os.scandir(_SCHEMA_DIR) if entry.is_dir()
            ]
            newest_dir = max(schema_dirs, key=_version_of)
            model_card.schema_version = newest_dir.name[1:]

        # Validation comes first: a failure here aborts before the write.
        card_schema = self._find_model_card_schema(model_card.schema_version)
        card_as_dict = model_card.to_dict()
        jsonschema.validate(card_as_dict, card_schema)

        # Only a card that passed validation is persisted.
        serialized_card = model_card.to_json()
        self._write_file(self._mcta_json_file, serialized_card)
示例#2
0
    def scaffold_assets(self) -> ModelCard:
        """Generates the Model Card Toolkit assets.

        The assets are the model card JSON file and the customizable model
        card UI templates listed in `_UI_TEMPLATES`. Everything is written
        through `self._write_file` (which presumably creates the assets
        directory when it is missing -- confirm in `_write_file`).

        When the toolkit holds a store (`self._store` is truthy), the model
        card is generated from that store, and plots are attached to it for
        every readable evaluation result and for the train/eval statistics
        of every stats artifact. Without a store, a default `ModelCard` is
        written instead.

        Returns:
          The ModelCard that was written to the JSON file.

        Raises:
          FileNotFoundError: if a UI template cannot be loaded from the
            `model_card_toolkit` package.
        """
        model_card = ModelCard()

        store = self._store
        if store:
            model_id = self._artifact_with_model_uri.id
            model_card = tfx_util.generate_model_card_for_model(
                store, model_id)
            metrics_artifacts = tfx_util.get_metrics_artifacts_for_model(
                store, model_id)
            stats_artifacts = tfx_util.get_stats_artifacts_for_model(
                store, model_id)

            for artifact in metrics_artifacts:
                eval_result = tfx_util.read_metrics_eval_result(artifact.uri)
                if eval_result is None:
                    # Nothing could be read for this artifact; skip its plots.
                    continue
                graphics.annotate_eval_result_plots(model_card, eval_result)

            for artifact in stats_artifacts:
                stats_uri = artifact.uri
                # Read order matters for argument order: train, then eval.
                split_stats = [
                    tfx_util.read_stats_proto(stats_uri, split_name)
                    for split_name in ('train', 'eval')
                ]
                graphics.annotate_dataset_feature_statistics_plots(
                    model_card, *split_stats)

        # Persist the (possibly auto-populated) model card as JSON.
        self._write_file(self._mcta_json_file, model_card.to_json())

        # Copy each packaged UI template into the output directory.
        for template_path in _UI_TEMPLATES:
            raw_template = pkgutil.get_data('model_card_toolkit',
                                            template_path)
            if raw_template is None:
                raise FileNotFoundError(f"Cannot find file: '{template_path}'")
            template_text = raw_template.decode('utf8')
            destination = os.path.join(self.output_dir, template_path)
            self._write_file(destination, template_text)

        return model_card
示例#3
0
    def update_model_card_json(self, model_card: ModelCard) -> None:
        """Checks `model_card` against its schema, then saves it as JSON.

        A model card without a `schema_version` is first stamped with the
        value of `validation.get_latest_schema_version()`. The card is
        validated before anything is written to the MCT assets.

        Args:
          model_card: The updated model card to write back. Its
            `schema_version` is set in place when it was empty.

        Raises:
          Error: when the given model_card is invalid w.r.t. the schema
            (raised from `validation.validate_json_schema`; the concrete
            exception type is not visible here).
        """
        version_missing = not model_card.schema_version
        if version_missing:
            latest_version = validation.get_latest_schema_version()
            model_card.schema_version = latest_version

        # Validation comes first: a failure here aborts before the write.
        card_as_dict = model_card.to_dict()
        validation.validate_json_schema(card_as_dict,
                                        model_card.schema_version)

        serialized_card = model_card.to_json()
        self._write_file(self._mcta_json_file, serialized_card)