コード例 #1
0
ファイル: graphics.py プロジェクト: Rqcker/model-card-toolkit
def annotate_dataset_feature_statistics_plots(
        model_card: model_card_module.ModelCard,
        data_stats: Sequence[statistics_pb2.DatasetFeatureStatisticsList]
) -> None:
    """Annotates visualizations for every dataset and feature.

    This function adds a new Dataset object at model_card.model_parameters.data
    for every dataset in data_stats. For every feature, histograms are created
    and encoded as base64 text strings. They can be found in the
    Dataset.graphics field.

    Histogram colors are chosen by the position of each statistics list in
    data_stats: teal for the first, indigo for the second, and the palette is
    reused cyclically after that. Earlier versions zip()-ed data_stats against
    the two-color palette, which silently dropped every entry after the second.

    Args:
      model_card: The model card object. Its model_parameters.data list is
        appended to in place.
      data_stats: A list of DatasetFeatureStatisticsList related to the dataset.
        Falsy entries (e.g. None) are skipped but still occupy a palette slot.
    """
    palette = (_COLOR_PALETTE['material_teal_700'],
               _COLOR_PALETTE['material_indigo_400'])
    for index, stats in enumerate(data_stats):
        if not stats:
            continue
        # Index into the palette modulo its length so every statistics list is
        # plotted; for the first two entries this matches the old zip() pairing.
        color = palette[index % len(palette)]
        for dataset in stats.datasets:
            graphs = []
            for feature in dataset.features:
                graph = _extract_graph_data_from_dataset_feature_statistics(
                    feature, color)
                graph = _draw_histogram(graph)
                # _draw_histogram can yield None; such features are left out.
                if graph is not None:
                    graphs.append(
                        model_card_module.Graphic(name=graph.name,
                                                  image=graph.base64str))
            model_card.model_parameters.data.append(
                model_card_module.Dataset(
                    name=dataset.name,
                    graphics=model_card_module.GraphicsCollection(
                        collection=graphs)))
コード例 #2
0
def annotate_dataset_feature_statistics_plots(
        model_card: model_card_module.ModelCard,
        data_stats: Sequence[statistics_pb2.DatasetFeatureStatisticsList]
) -> None:
    """Annotates visualizations for every dataset and feature in data_stats.

    A new, named Dataset is appended to model_card.model_parameters.data for
    every dataset contained in data_stats. Its graphics collection holds one
    histogram per feature, encoded as a base64 text string.

    Args:
      model_card: The model card object. Its model_parameters.data list is
        appended to in place.
      data_stats: a list of DatasetFeatureStatisticsList related to the dataset.
        Falsy entries (e.g. None) are skipped. Histogram colors alternate
        between teal and indigo by position in this list.
    """
    palette = (_COLOR_PALETTE['material_teal_700'],
               _COLOR_PALETTE['material_indigo_400'])
    for index, stats in enumerate(data_stats):
        if not stats:
            continue
        # Cycle the palette instead of zip()-ing against it: zip() would
        # silently ignore every statistics list past the second one.
        color = palette[index % len(palette)]
        for dataset in stats.datasets:
            # One Dataset per contained dataset, carrying its name, so graphs
            # from different datasets are not merged into one anonymous entry.
            graphs = []
            for feature in dataset.features:
                graph = _generate_graph_from_feature_statistics(feature, color)
                graph = _draw_histogram(graph)
                # _draw_histogram can yield None; such features are left out.
                if graph is not None:
                    graphs.append(
                        model_card_module.Graphic(name=graph.name,
                                                  image=graph.base64str))
            model_card.model_parameters.data.append(
                model_card_module.Dataset(
                    name=dataset.name,
                    graphics=model_card_module.GraphicsCollection(
                        collection=graphs)))
コード例 #3
0
ファイル: graphics.py プロジェクト: Rqcker/model-card-toolkit
def annotate_eval_result_plots(model_card: model_card_module.ModelCard,
                               eval_result: tfma.EvalResult) -> None:
    """Annotates visualizations for every metric in eval_result.

    This function generates barcharts for sliced metrics, encoded as base64 text
    strings, and appends them to
    model_card.quantitative_analysis.graphics.collection.

    One graph is attempted per (metric, slice key) pair. Pairs are visited in
    sorted order so the appended graphics are ordered identically on every run.
    The previous code iterated sets of strings, whose order depends on
    PYTHONHASHSEED.

    Args:
      model_card: The model card object. Its
        quantitative_analysis.graphics.collection is extended in place.
      eval_result: A `tfma.EvalResult`.
    """
    # Gather all metric names and every slice key other than 'Overall'.
    metrics = set()
    slices_keys = set()
    for slicing_metric in eval_result.slicing_metrics:
        slices_key, _ = stringify_slice_key(slicing_metric[0])
        if slices_key != 'Overall':
            slices_keys.add(slices_key)
        metrics_by_output = slicing_metric[1]
        for output_name in metrics_by_output:
            for sub_key in metrics_by_output[output_name]:
                metrics.update(metrics_by_output[output_name][sub_key].keys())

    # Without any explicit slice, fall back to a single '' slice key.
    if not slices_keys:
        slices_keys.add('')

    # Generate barcharts in a deterministic (sorted) order.
    graphs = []
    for metric in sorted(metrics, key=str):
        for slices_key in sorted(slices_keys):
            graph = _extract_graph_data_from_slicing_metrics(
                eval_result.slicing_metrics, metric, slices_key)
            graph = _draw_histogram(graph)
            # _draw_histogram can yield None; such pairs are skipped.
            if graph is not None:
                graphs.append(graph)

    # Annotate model_card with the generated graphs.
    model_card.quantitative_analysis.graphics.collection.extend([
        model_card_module.Graphic(name=graph.name, image=graph.base64str)
        for graph in graphs
    ])
コード例 #4
0
def annotate_dataset_feature_statistics_plots(
        model_card: model_card_module.ModelCard,
        train_stats: statistics_pb2.DatasetFeatureStatisticsList = None,
        eval_stats: statistics_pb2.DatasetFeatureStatisticsList = None
) -> None:
    """Annotates histogram visualizations for the train and eval statistics.

    For each statistics list that is provided, a histogram is generated for
    every feature of every dataset it contains and encoded as a base64 text
    string. Training graphs (teal) are appended to
    model_card.model_parameters.data.train.graphics.collection. Eval graphs
    (indigo) are appended to model_card.model_parameters.data.eval.graphics.collection.

    Args:
      model_card: The model card object, updated in place.
      train_stats: a DatasetFeatureStatisticsList corresponding to the training
        dataset. Skipped when falsy (e.g. None).
      eval_stats: a DatasetFeatureStatisticsList corresponding to the eval
        dataset. Skipped when falsy (e.g. None).
    """
    destinations = (model_card.model_parameters.data.train,
                    model_card.model_parameters.data.eval)
    palette = (_COLOR_PALETTE['material_teal_700'],
               _COLOR_PALETTE['material_indigo_400'])
    for stats, destination, color in zip((train_stats, eval_stats),
                                         destinations, palette):
        if not stats:
            continue
        drawn = []
        for dataset in stats.datasets:
            for feature in dataset.features:
                graph = _generate_graph_from_feature_statistics(feature, color)
                if graph is None:
                    continue
                # The return value is ignored; the graph object itself is kept.
                _draw_histogram(graph)
                drawn.append(graph)
        destination.graphics.collection.extend([
            model_card_module.Graphic(name=item.name, image=item.base64str)
            for item in drawn
        ])
コード例 #5
0
def annotate_eval_result_plots(model_card: model_card_module.ModelCard,
                               eval_result: tfma.EvalResult) -> None:
    """Annotate visualizations for every metric in eval_result.

    The visualizations are barcharts encoded as base64 text strings. One chart
    is attempted per (metric, slice key) pair. The charts are appended to
    model_card.quantitative_analysis.graphics.collection in sorted order. Sets
    of strings iterate in a PYTHONHASHSEED-dependent order, so sorting keeps
    the output stable.

    Args:
      model_card: The model card object, updated in place.
      eval_result: a `tfma.EvalResult`.
    """
    # TODO(b/159058592): replace with `metrics = eval_result.get_metrics()`
    metrics = set()
    slices_keys = set()
    for slicing_metric in eval_result.slicing_metrics:
        slices_key, _ = stringify_slice_key(slicing_metric[0])
        if slices_key != 'Overall':
            slices_keys.add(slices_key)
        metrics_by_output = slicing_metric[1]
        for output_name in metrics_by_output:
            for sub_key in metrics_by_output[output_name]:
                metrics.update(metrics_by_output[output_name][sub_key].keys())

    # Without any explicit slice, fall back to a single '' slice key.
    if not slices_keys:
        slices_keys.add('')

    graphs = []
    for metric in sorted(metrics, key=str):
        for slices_key in sorted(slices_keys):
            graph = _generate_graph_from_slicing_metrics(
                eval_result.slicing_metrics, metric, slices_key)
            if graph is not None:
                # The return value is ignored; the graph object itself is kept.
                _draw_histogram(graph)
                graphs.append(graph)

    model_card.quantitative_analysis.graphics.collection.extend([
        model_card_module.Graphic(name=graph.name, image=graph.base64str)
        for graph in graphs
    ])
コード例 #6
0
    def test_scaffold_assets_with_source(self, output_file_format: str,
                                         artifacts: bool):
        """Scaffolds a model card from TFMA/TFDV/model sources and checks it.

        Sources are built either from MLMD artifacts (artifacts=True, backed by
        a fake in-memory database config) or from plain file paths. The
        'average_loss' metric is excluded, and only 'feature_name1' is included.
        """
        mlmd_store = None
        if artifacts:
            store_config = metadata_store_pb2.ConnectionConfig()
            store_config.fake_database.SetInParent()
            mlmd_store = mlmd.MetadataStore(store_config)

        train_dataset_name = 'Dataset-Split-train'
        train_features = ['feature_name1']
        eval_dataset_name = 'Dataset-Split-eval'
        eval_features = ['feature_name2']

        tfma_path = os.path.join(self.tmpdir, 'tfma')
        tfdv_path = os.path.join(self.tmpdir, 'tfdv')
        pushed_model_path = os.path.join(self.tmpdir, 'pushed_model')
        self._write_tfma(tfma_path, output_file_format, mlmd_store)
        self._write_tfdv(tfdv_path, train_dataset_name, train_features,
                         eval_dataset_name, eval_features, mlmd_store)

        if not artifacts:
            tfma_src = src.TfmaSource(eval_result_paths=[tfma_path],
                                      metrics_exclude=['average_loss'])
            tfdv_src = src.TfdvSource(dataset_statistics_paths=[tfdv_path],
                                      features_include=['feature_name1'])
            model_src = src.ModelSource(pushed_model_path=pushed_model_path)
        else:
            evaluation_artifacts = mlmd_store.get_artifacts_by_type(
                standard_artifacts.ModelEvaluation.TYPE_NAME)
            statistics_artifacts = mlmd_store.get_artifacts_by_type(
                standard_artifacts.ExampleStatistics.TYPE_NAME)
            pushed_model = standard_artifacts.PushedModel()
            pushed_model.uri = pushed_model_path
            tfma_src = src.TfmaSource(
                model_evaluation_artifacts=evaluation_artifacts,
                metrics_exclude=['average_loss'])
            tfdv_src = src.TfdvSource(
                example_statistics_artifacts=statistics_artifacts,
                features_include=['feature_name1'])
            model_src = src.ModelSource(pushed_model_artifact=pushed_model)

        combined_source = src.Source(tfma=tfma_src, tfdv=tfdv_src,
                                     model=model_src)
        mc = model_card_toolkit.ModelCardToolkit(
            source=combined_source).scaffold_assets()

        with self.subTest(name='quantitative_analysis'):

            def to_protos(items):
                return [item.to_proto() for item in items]

            expected_performance_metrics = [
                model_card.PerformanceMetric(
                    type='post_export_metrics/example_count', value='2.0')
            ]
            self.assertCountEqual(
                to_protos(mc.quantitative_analysis.performance_metrics),
                to_protos(expected_performance_metrics))
            self.assertLen(mc.quantitative_analysis.graphics.collection, 1)

        with self.subTest(name='model_parameters.data'):
            self.assertLen(mc.model_parameters.data, 2)  # train and eval
            # Every graphic must carry an image. The image is then cleared so
            # the equality checks below compare names only.
            for dataset in mc.model_parameters.data:
                for graphic in dataset.graphics.collection:
                    self.assertIsNotNone(
                        graphic.image,
                        msg=
                        f'No image found for graphic: {dataset.name} {graphic.name}'
                    )
                    graphic.image = None
            expected_train = model_card.Dataset(
                name=train_dataset_name,
                graphics=model_card.GraphicsCollection(collection=[
                    model_card.Graphic(name='counts | feature_name1')
                ]))
            self.assertIn(expected_train, mc.model_parameters.data)
            # feature_name2 is filtered out by features_include.
            excluded_eval = model_card.Dataset(
                name=eval_dataset_name,
                graphics=model_card.GraphicsCollection(collection=[
                    model_card.Graphic(name='counts | feature_name2')
                ]))
            self.assertNotIn(excluded_eval, mc.model_parameters.data)

        with self.subTest(name='model_details.path'):
            self.assertEqual(mc.model_details.path, pushed_model_path)