def test_export_format(self):
        """Checks the proto and HTML files produced by the default export.

        A card is scaffolded from an MLMD-backed toolkit, given the name
        'My Model', saved with `update_model_card`, and exported with
        `export_format()`. The saved proto must carry that name, and the
        HTML file must equal the string returned by `export_format()`.
        """
        metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
            self.tmp_db_path)
        mlmd_source = src.MlmdSource(
            store=metadata_store, model_uri=testdata_utils.TFX_0_21_MODEL_URI)
        toolkit = model_card_toolkit.ModelCardToolkit(
            output_dir=self.tmpdir, mlmd_source=mlmd_source)
        card = toolkit.scaffold_assets()
        card.model_details.name = 'My Model'
        toolkit.update_model_card(card)
        exported = toolkit.export_format()

        # The serialized proto must round-trip the name assigned above.
        saved_proto_path = os.path.join(self.tmpdir, 'data/model_card.proto')
        self.assertTrue(os.path.exists(saved_proto_path))
        with open(saved_proto_path, 'rb') as proto_file:
            serialized = proto_file.read()
        saved_proto = model_card_pb2.ModelCard()
        saved_proto.ParseFromString(serialized)
        self.assertEqual(saved_proto.model_details.name, 'My Model')

        # The HTML on disk must be exactly what export_format() returned.
        html_path = os.path.join(self.tmpdir, 'model_cards/model_card.html')
        self.assertTrue(os.path.exists(html_path))
        with open(html_path) as html_file:
            html = html_file.read()
        self.assertEqual(html, exported)
        self.assertTrue(html.startswith('<!DOCTYPE html>'))
        self.assertIn('My Model', html)
예제 #2
0
 def test_scaffold_assets_with_store(self):
     """Scaffolds from an MLMD store and checks the populated card.

     The card must have a model name and version name, a non-empty set of
     quantitative-analysis graphics that includes 'average_loss' and
     'post_export_metrics/example_count', and the default HTML and Markdown
     templates must exist under the output directory.
     """
     metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
         self.tmp_db_path)
     toolkit = model_card_toolkit.ModelCardToolkit(
         output_dir=self.tmpdir,
         mlmd_store=metadata_store,
         model_uri=testdata_utils.TFX_0_21_MODEL_URI)
     card = toolkit.scaffold_assets()

     self.assertIsNotNone(card.model_details.name)
     self.assertIsNotNone(card.model_details.version.name)

     graphics = card.quantitative_analysis.graphics.collection
     self.assertNotEmpty(graphics)
     graphic_names = {graphic.name for graphic in graphics}
     self.assertIn('average_loss', graphic_names)
     self.assertIn('post_export_metrics/example_count', graphic_names)

     html_templates = os.listdir(os.path.join(self.tmpdir, 'template/html'))
     self.assertIn('default_template.html.jinja', html_templates)
     md_templates = os.listdir(os.path.join(self.tmpdir, 'template/md'))
     self.assertIn('default_template.md.jinja', md_templates)
예제 #3
0
 def test_init_with_store_no_model_uri(self):
     """Expects ValueError when `mlmd_store` is given without `model_uri`."""
     metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
         self.tmp_db_path)
     expected_pattern = 'If `mlmd_store` is set, `model_uri` should be set.'
     with self.assertRaisesRegex(ValueError, expected_pattern):
         model_card_toolkit.ModelCardToolkit(
             output_dir=self.tmpdir, mlmd_store=metadata_store)
 def test_scaffold_assets_with_invalid_tfdv_source(self):
     """Expects ValueError when a TfdvSource sets both feature filters."""
     expected_message = ('Only one of TfdvSource.features_include and '
                         'TfdvSource.features_exclude should be set.')
     # Everything is built inside the context manager so the assertion holds
     # regardless of which constructor performs the validation.
     with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
         conflicting_tfdv = src.TfdvSource(
             dataset_statistics_paths=['dummy/path'],
             features_include=['brand_confidence'],
             features_exclude=['brand_prominence'])
         model_card_toolkit.ModelCardToolkit(
             source=src.Source(tfdv=conflicting_tfdv))
 def test_scaffold_assets_with_invalid_tfma_source(self):
     """Expects ValueError when a TfmaSource sets both metric filters."""
     expected_message = (
         'Only one of TfmaSource.metrics_include and TfmaSource.metrics_exclude '
         'should be set.')
     # Everything is built inside the context manager so the assertion holds
     # regardless of which constructor performs the validation.
     with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
         conflicting_tfma = src.TfmaSource(
             eval_result_paths=['dummy/path'],
             metrics_include=['false_positive_rate'],
             metrics_exclude=['false_negative_rate'])
         model_card_toolkit.ModelCardToolkit(
             source=src.Source(tfma=conflicting_tfma))
 def test_init_with_store_model_uri_not_found(self):
     """Expects ValueError when the model URI is absent from the store."""
     metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
         self.tmp_db_path)
     missing_uri = 'unknown_model'
     expected_pattern = f'"{missing_uri}" cannot be found in the `store`'
     # The MlmdSource is created inside the context manager because the
     # lookup failure may surface from either constructor.
     with self.assertRaisesRegex(ValueError, expected_pattern):
         unresolved_source = src.MlmdSource(
             store=metadata_store, model_uri=missing_uri)
         model_card_toolkit.ModelCardToolkit(mlmd_source=unresolved_source)
예제 #7
0
 def test_scaffold_assets(self):
     """Scaffolds with no sources and checks the default templates exist.

     Verifies that the toolkit keeps the `output_dir` it was given and that
     `scaffold_assets()` leaves the default HTML and Markdown Jinja
     templates under that directory.
     """
     output_dir = self.tmpdir
     mct = model_card_toolkit.ModelCardToolkit(output_dir=output_dir)
     self.assertEqual(mct.output_dir, output_dir)
     # Only the on-disk side effects are checked here, so the returned card
     # is intentionally not bound to a name.
     mct.scaffold_assets()
     self.assertIn('default_template.html.jinja',
                   os.listdir(os.path.join(output_dir, 'template/html')))
     self.assertIn('default_template.md.jinja',
                   os.listdir(os.path.join(output_dir, 'template/md')))
    def test_update_model_card_with_valid_model_card(self):
        """Checks that `update_model_card` writes the card's proto to disk.

        The proto parsed back from data/model_card.proto must equal the
        in-memory card's `to_proto()` result.
        """
        toolkit = model_card_toolkit.ModelCardToolkit(output_dir=self.tmpdir)
        card = toolkit.scaffold_assets()
        card.model_details.name = 'My Model'
        toolkit.update_model_card(card)

        saved_proto_path = os.path.join(self.tmpdir, 'data/model_card.proto')
        with open(saved_proto_path, 'rb') as proto_file:
            serialized = proto_file.read()
        saved_proto = model_card_pb2.ModelCard()
        saved_proto.ParseFromString(serialized)
        self.assertEqual(saved_proto, card.to_proto())
예제 #9
0
 def test_update_model_card_with_no_version(self):
     """Saves a card as JSON without assigning `schema_version`.

     The JSON written under the toolkit's own `output_dir` must decode to
     the same value as the card's `to_dict()`.
     """
     toolkit = model_card_toolkit.ModelCardToolkit()
     card = toolkit.scaffold_assets()
     card.model_details.name = 'My Model'
     toolkit.update_model_card_json(card)

     saved_json_path = os.path.join(toolkit.output_dir,
                                    'data/model_card.json')
     with open(saved_json_path) as json_file:
         saved = json.load(json_file)
     self.assertEqual(saved, card.to_dict())
예제 #10
0
 def test_update_model_card_with_valid_json(self):
     """Saves a card with schema version '0.0.1' as JSON and reads it back.

     The JSON written to data/model_card.json must decode to the same value
     as the card's `to_dict()`.
     """
     toolkit = model_card_toolkit.ModelCardToolkit(output_dir=self.tmpdir)
     card = toolkit.scaffold_assets()
     card.schema_version = '0.0.1'
     card.model_details.name = 'My Model'
     toolkit.update_model_card_json(card)

     saved_json_path = os.path.join(self.tmpdir, 'data/model_card.json')
     with open(saved_json_path) as json_file:
         saved = json.load(json_file)
     self.assertEqual(saved, card.to_dict())
    def test_export_format_with_customized_template_and_output_name(self):
        """Exports with an explicit template path and output file name.

        Passes the default_template.html.jinja found under the output
        directory as `template_path` and 'my_model_card.html' as
        `output_file`. The file written under model_cards/ must equal the
        returned string, start with the HTML doctype, and mention the
        model name.
        """
        toolkit = model_card_toolkit.ModelCardToolkit(output_dir=self.tmpdir)
        card = toolkit.scaffold_assets()
        card.model_details.name = 'My Model'
        toolkit.update_model_card(card)

        custom_output_name = 'my_model_card.html'
        exported = toolkit.export_format(
            template_path=os.path.join(
                self.tmpdir, 'template/html/default_template.html.jinja'),
            output_file=custom_output_name)

        exported_path = os.path.join(self.tmpdir, 'model_cards',
                                     custom_output_name)
        self.assertTrue(os.path.exists(exported_path))
        with open(exported_path) as html_file:
            html = html_file.read()
        self.assertEqual(html, exported)
        self.assertTrue(html.startswith('<!DOCTYPE html>'))
        self.assertIn('My Model', html)
예제 #12
0
    def test_export_format(self):
        """Checks the HTML produced by the default export of a JSON card.

        A card scaffolded from an MLMD store is given schema version
        '0.0.1' and the name 'My Model', saved with
        `update_model_card_json`, and exported with `export_format()`.
        The HTML file must equal the returned string.
        """
        metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
            self.tmp_db_path)
        toolkit = model_card_toolkit.ModelCardToolkit(
            output_dir=self.tmpdir,
            mlmd_store=metadata_store,
            model_uri=testdata_utils.TFX_0_21_MODEL_URI)
        card = toolkit.scaffold_assets()
        card.schema_version = '0.0.1'
        card.model_details.name = 'My Model'
        toolkit.update_model_card_json(card)
        exported = toolkit.export_format()

        html_path = os.path.join(self.tmpdir, 'model_cards/model_card.html')
        self.assertTrue(os.path.exists(html_path))
        with open(html_path) as html_file:
            html = html_file.read()
        self.assertEqual(html, exported)
        self.assertTrue(html.startswith('<!DOCTYPE html>'))
        self.assertIn('My Model', html)
 def test_scaffold_assets_with_store(self, mock_annotate_data_stats,
                                     mock_annotate_eval_results):
     """Scaffolds from an MlmdSource and counts the annotation calls.

     Args:
       mock_annotate_data_stats: mock whose `call_count` is expected to be
         2 after scaffolding. Presumably injected by a patch decorator that
         is not visible here -- confirm against the full file.
       mock_annotate_eval_results: mock whose `call_count` is expected to
         be 1 after scaffolding; same caveat applies.
     """
     expected_stats_calls = 2
     expected_eval_calls = 1
     metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
         self.tmp_db_path)
     mlmd_source = src.MlmdSource(
         store=metadata_store, model_uri=testdata_utils.TFX_0_21_MODEL_URI)
     toolkit = model_card_toolkit.ModelCardToolkit(
         output_dir=self.tmpdir, mlmd_source=mlmd_source)
     card = toolkit.scaffold_assets()

     self.assertIsNotNone(card.model_details.name)
     self.assertIsNotNone(card.model_details.version.name)

     html_templates = os.listdir(os.path.join(self.tmpdir, 'template/html'))
     self.assertIn('default_template.html.jinja', html_templates)
     md_templates = os.listdir(os.path.join(self.tmpdir, 'template/md'))
     self.assertIn('default_template.md.jinja', md_templates)

     self.assertEqual(mock_annotate_data_stats.call_count,
                      expected_stats_calls)
     self.assertEqual(mock_annotate_eval_results.call_count,
                      expected_eval_calls)
 def test_export_format_before_scaffold_assets(self):
     """Expects ValueError when exporting from a freshly built toolkit."""
     with self.assertRaises(ValueError):
         fresh_toolkit = model_card_toolkit.ModelCardToolkit()
         fresh_toolkit.export_format()
    def test_scaffold_assets_with_source(self, output_file_format: str,
                                         artifacts: bool):
        """Scaffolds from explicit TFMA, TFDV and model sources.

        Args:
          output_file_format: forwarded unchanged to `self._write_tfma`.
          artifacts: when true, a fake-database MLMD store is created and the
            sources are built from artifacts read back from it; otherwise
            the sources are built from the written file paths.

        The sources exclude the 'average_loss' metric and include only
        'feature_name1'. The sub-tests check the resulting quantitative
        analysis, the dataset graphics, and the model path.
        """
        mlmd_store = None
        if artifacts:
            connection_config = metadata_store_pb2.ConnectionConfig()
            connection_config.fake_database.SetInParent()
            mlmd_store = mlmd.MetadataStore(connection_config)

        train_split = 'Dataset-Split-train'
        train_split_features = ['feature_name1']
        eval_split = 'Dataset-Split-eval'
        eval_split_features = ['feature_name2']

        eval_result_dir = os.path.join(self.tmpdir, 'tfma')
        stats_dir = os.path.join(self.tmpdir, 'tfdv')
        model_dir = os.path.join(self.tmpdir, 'pushed_model')
        self._write_tfma(eval_result_dir, output_file_format, mlmd_store)
        self._write_tfdv(stats_dir, train_split, train_split_features,
                         eval_split, eval_split_features, mlmd_store)

        if not artifacts:
            tfma_source = src.TfmaSource(eval_result_paths=[eval_result_dir],
                                         metrics_exclude=['average_loss'])
            tfdv_source = src.TfdvSource(dataset_statistics_paths=[stats_dir],
                                         features_include=['feature_name1'])
            model_source = src.ModelSource(pushed_model_path=model_dir)
        else:
            evaluation_artifacts = mlmd_store.get_artifacts_by_type(
                standard_artifacts.ModelEvaluation.TYPE_NAME)
            statistics_artifacts = mlmd_store.get_artifacts_by_type(
                standard_artifacts.ExampleStatistics.TYPE_NAME)
            pushed_model = standard_artifacts.PushedModel()
            pushed_model.uri = model_dir
            tfma_source = src.TfmaSource(
                model_evaluation_artifacts=evaluation_artifacts,
                metrics_exclude=['average_loss'])
            tfdv_source = src.TfdvSource(
                example_statistics_artifacts=statistics_artifacts,
                features_include=['feature_name1'])
            model_source = src.ModelSource(pushed_model_artifact=pushed_model)

        combined_source = src.Source(
            tfma=tfma_source, tfdv=tfdv_source, model=model_source)
        toolkit = model_card_toolkit.ModelCardToolkit(source=combined_source)
        card = toolkit.scaffold_assets()

        with self.subTest(name='quantitative_analysis'):
            expected_metrics = [
                model_card.PerformanceMetric(
                    type='post_export_metrics/example_count', value='2.0')
            ]
            # Compared as protos, mirroring how both sides are converted
            # with to_proto() before the count-insensitive comparison.
            actual_protos = [
                metric.to_proto()
                for metric in card.quantitative_analysis.performance_metrics
            ]
            expected_protos = [
                metric.to_proto() for metric in expected_metrics
            ]
            self.assertCountEqual(actual_protos, expected_protos)
            self.assertLen(card.quantitative_analysis.graphics.collection, 1)

        with self.subTest(name='model_parameters.data'):
            datasets = card.model_parameters.data
            self.assertLen(datasets, 2)  # train and eval
            for dataset in datasets:
                for graphic in dataset.graphics.collection:
                    self.assertIsNotNone(
                        graphic.image,
                        msg=
                        f'No image found for graphic: {dataset.name} {graphic.name}'
                    )
                    # Clear the image so the membership checks below compare
                    # datasets by name and graphic name only.
                    graphic.image = None

            included_train_dataset = model_card.Dataset(
                name=train_split,
                graphics=model_card.GraphicsCollection(collection=[
                    model_card.Graphic(name='counts | feature_name1')
                ]))
            self.assertIn(included_train_dataset, datasets)

            filtered_eval_dataset = model_card.Dataset(
                name=eval_split,
                graphics=model_card.GraphicsCollection(collection=[
                    model_card.Graphic(name='counts | feature_name2')
                ]))
            self.assertNotIn(filtered_eval_dataset, datasets)

        with self.subTest(name='model_details.path'):
            self.assertEqual(card.model_details.path, model_dir)
예제 #16
0
 def test_update_model_card_with_invalid_schema_version(self):
     """Expects ValueError when saving a card with schema version '100.0.0'."""
     toolkit = model_card_toolkit.ModelCardToolkit()
     unsupported_card = model_card_module.ModelCard(schema_version='100.0.0')
     expected_pattern = 'Cannot find schema version'
     with self.assertRaisesRegex(ValueError, expected_pattern):
         toolkit.update_model_card_json(unsupported_card)
 def test_scaffold_assets_with_empty_source(self):
     """Scaffolds with a default-constructed Source; passes if nothing raises."""
     empty_source = src.Source()
     toolkit = model_card_toolkit.ModelCardToolkit(source=empty_source)
     toolkit.scaffold_assets()