def test_export_format(self):
    """Export the default card from an MLMD-sourced toolkit and check both outputs.

    After scaffolding, renaming the model and saving the card, the test
    verifies that the proto written under ``data/`` carries the new name and
    that the HTML file under ``model_cards/`` matches the string returned by
    ``export_format()``.
    """
    mlmd_source = src.MlmdSource(
        store=testdata_utils.get_tfx_pipeline_metadata_store(self.tmp_db_path),
        model_uri=testdata_utils.TFX_0_21_MODEL_URI)
    toolkit = model_card_toolkit.ModelCardToolkit(
        output_dir=self.tmpdir, mlmd_source=mlmd_source)

    card = toolkit.scaffold_assets()
    card.model_details.name = 'My Model'
    toolkit.update_model_card(card)
    exported = toolkit.export_format()

    # The saved proto must round-trip the edited model name.
    proto_file = os.path.join(self.tmpdir, 'data/model_card.proto')
    self.assertTrue(os.path.exists(proto_file))
    with open(proto_file, 'rb') as proto_fh:
        serialized = proto_fh.read()
    parsed_card = model_card_pb2.ModelCard()
    parsed_card.ParseFromString(serialized)
    self.assertEqual(parsed_card.model_details.name, 'My Model')

    # The rendered file on disk must be exactly what export_format() returned.
    html_file = os.path.join(self.tmpdir, 'model_cards/model_card.html')
    self.assertTrue(os.path.exists(html_file))
    with open(html_file) as html_fh:
        html = html_fh.read()
    self.assertEqual(html, exported)
    self.assertTrue(html.startswith('<!DOCTYPE html>'))
    self.assertIn('My Model', html)
def test_scaffold_assets_with_store(self):
    """Scaffold from an MLMD store and check the card fields and copied templates.

    The scaffolded card must have a model name and version name, a non-empty
    set of quantitative-analysis graphics that includes the two expected
    metric names, and the default HTML and Markdown templates must be present
    under the output directory.
    """
    out_dir = self.tmpdir
    toolkit = model_card_toolkit.ModelCardToolkit(
        output_dir=out_dir,
        mlmd_store=testdata_utils.get_tfx_pipeline_metadata_store(
            self.tmp_db_path),
        model_uri=testdata_utils.TFX_0_21_MODEL_URI)
    card = toolkit.scaffold_assets()

    self.assertIsNotNone(card.model_details.name)
    self.assertIsNotNone(card.model_details.version.name)

    graphics = card.quantitative_analysis.graphics.collection
    self.assertNotEmpty(graphics)
    graphic_names = {item.name for item in graphics}
    for metric_name in ('average_loss', 'post_export_metrics/example_count'):
        self.assertIn(metric_name, graphic_names)

    # (sub-directory of the output dir, template expected inside it)
    expected_templates = (
        ('template/html', 'default_template.html.jinja'),
        ('template/md', 'default_template.md.jinja'),
    )
    for subdir, template_name in expected_templates:
        self.assertIn(template_name,
                      os.listdir(os.path.join(out_dir, subdir)))
def test_init_with_store_no_model_uri(self):
    """Constructing the toolkit with `mlmd_store` but no `model_uri` raises ValueError."""
    metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
        self.tmp_db_path)
    expected_error = 'If `mlmd_store` is set, `model_uri` should be set.'
    with self.assertRaisesRegex(ValueError, expected_error):
        model_card_toolkit.ModelCardToolkit(
            output_dir=self.tmpdir, mlmd_store=metadata_store)
def test_scaffold_assets_with_invalid_tfdv_source(self):
    """A TfdvSource with both `features_include` and `features_exclude` is rejected."""
    expected_message = ('Only one of TfdvSource.features_include and '
                        'TfdvSource.features_exclude should be set.')
    with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
        # Everything is built inside the context: the test does not care
        # which constructor in the chain performs the validation.
        conflicting_tfdv = src.TfdvSource(
            dataset_statistics_paths=['dummy/path'],
            features_include=['brand_confidence'],
            features_exclude=['brand_prominence'])
        model_card_toolkit.ModelCardToolkit(
            source=src.Source(tfdv=conflicting_tfdv))
def test_scaffold_assets_with_invalid_tfma_source(self):
    """A TfmaSource with both `metrics_include` and `metrics_exclude` is rejected."""
    expected_message = (
        'Only one of TfmaSource.metrics_include and TfmaSource.metrics_exclude '
        'should be set.')
    with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
        # Everything is built inside the context: the test does not care
        # which constructor in the chain performs the validation.
        conflicting_tfma = src.TfmaSource(
            eval_result_paths=['dummy/path'],
            metrics_include=['false_positive_rate'],
            metrics_exclude=['false_negative_rate'])
        model_card_toolkit.ModelCardToolkit(
            source=src.Source(tfma=conflicting_tfma))
def test_init_with_store_model_uri_not_found(self):
    """A model URI that is absent from the MLMD store raises ValueError."""
    metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
        self.tmp_db_path)
    missing_uri = 'unknown_model'
    expected_error = f'"{missing_uri}" cannot be found in the `store`'
    with self.assertRaisesRegex(ValueError, expected_error):
        # The source is built inside the context as well, so the error is
        # caught whichever constructor raises it.
        model_card_toolkit.ModelCardToolkit(
            mlmd_source=src.MlmdSource(
                store=metadata_store, model_uri=missing_uri))
def test_scaffold_assets(self):
    """Scaffolding without any source copies the default templates.

    Checks that the toolkit keeps the `output_dir` it was given and that
    `scaffold_assets()` leaves the default HTML and Markdown templates under
    that directory.
    """
    output_dir = self.tmpdir
    mct = model_card_toolkit.ModelCardToolkit(output_dir=output_dir)
    self.assertEqual(mct.output_dir, output_dir)

    # Only the on-disk side effects are checked, so the returned card is not
    # bound to a name (this removes the former unused-variable pragma).
    mct.scaffold_assets()

    self.assertIn('default_template.html.jinja',
                  os.listdir(os.path.join(output_dir, 'template/html')))
    self.assertIn('default_template.md.jinja',
                  os.listdir(os.path.join(output_dir, 'template/md')))
def test_update_model_card_with_valid_model_card(self):
    """`update_model_card` writes a proto equal to the card's own `to_proto()`."""
    toolkit = model_card_toolkit.ModelCardToolkit(output_dir=self.tmpdir)
    card = toolkit.scaffold_assets()
    card.model_details.name = 'My Model'
    toolkit.update_model_card(card)

    # Read the saved proto back and compare it with a fresh serialization.
    proto_file = os.path.join(self.tmpdir, 'data/model_card.proto')
    stored_card = model_card_pb2.ModelCard()
    with open(proto_file, 'rb') as proto_fh:
        stored_card.ParseFromString(proto_fh.read())
    self.assertEqual(stored_card, card.to_proto())
def test_update_model_card_with_no_version(self):
    """A card whose schema version is left untouched is saved as matching JSON.

    The toolkit is created without an explicit output directory, so the JSON
    path is derived from `mct.output_dir`.
    """
    toolkit = model_card_toolkit.ModelCardToolkit()
    card = toolkit.scaffold_assets()
    card.model_details.name = 'My Model'
    toolkit.update_model_card_json(card)

    saved_json = os.path.join(toolkit.output_dir, 'data/model_card.json')
    with open(saved_json) as json_fh:
        on_disk = json.load(json_fh)
    self.assertEqual(on_disk, card.to_dict())
def test_update_model_card_with_valid_json(self):
    """A card with schema version '0.0.1' is saved as JSON equal to `to_dict()`."""
    toolkit = model_card_toolkit.ModelCardToolkit(output_dir=self.tmpdir)
    card = toolkit.scaffold_assets()
    card.schema_version = '0.0.1'
    card.model_details.name = 'My Model'
    toolkit.update_model_card_json(card)

    saved_json = os.path.join(self.tmpdir, 'data/model_card.json')
    with open(saved_json) as json_fh:
        on_disk = json.load(json_fh)
    self.assertEqual(on_disk, card.to_dict())
def test_export_format_with_customized_template_and_output_name(self):
    """`export_format` honours an explicit template path and output file name.

    The template passed in is the default HTML template that scaffolding
    copied into the output directory; the output is written under
    ``model_cards/`` with the requested file name.
    """
    toolkit = model_card_toolkit.ModelCardToolkit(output_dir=self.tmpdir)
    card = toolkit.scaffold_assets()
    card.model_details.name = 'My Model'
    toolkit.update_model_card(card)

    custom_name = 'my_model_card.html'
    exported = toolkit.export_format(
        template_path=os.path.join(
            self.tmpdir, 'template/html/default_template.html.jinja'),
        output_file=custom_name)

    # The rendered file on disk must be exactly what export_format() returned.
    html_file = os.path.join(self.tmpdir, 'model_cards', custom_name)
    self.assertTrue(os.path.exists(html_file))
    with open(html_file) as html_fh:
        html = html_fh.read()
    self.assertEqual(html, exported)
    self.assertTrue(html.startswith('<!DOCTYPE html>'))
    self.assertIn('My Model', html)
def test_export_format(self):
    """Export the default HTML card after saving the card as JSON.

    Uses the `mlmd_store`/`model_uri` constructor arguments, stores the card
    via `update_model_card_json`, and checks that the HTML file under
    ``model_cards/`` matches the string returned by ``export_format()``.
    """
    toolkit = model_card_toolkit.ModelCardToolkit(
        output_dir=self.tmpdir,
        mlmd_store=testdata_utils.get_tfx_pipeline_metadata_store(
            self.tmp_db_path),
        model_uri=testdata_utils.TFX_0_21_MODEL_URI)

    card = toolkit.scaffold_assets()
    card.schema_version = '0.0.1'
    card.model_details.name = 'My Model'
    toolkit.update_model_card_json(card)
    exported = toolkit.export_format()

    # The rendered file on disk must be exactly what export_format() returned.
    html_file = os.path.join(self.tmpdir, 'model_cards/model_card.html')
    self.assertTrue(os.path.exists(html_file))
    with open(html_file) as html_fh:
        html = html_fh.read()
    self.assertEqual(html, exported)
    self.assertTrue(html.startswith('<!DOCTYPE html>'))
    self.assertIn('My Model', html)
def test_scaffold_assets_with_store(self, mock_annotate_data_stats,
                                    mock_annotate_eval_results):
    """Scaffold from an MLMD source and count calls to the annotation hooks.

    Args:
      mock_annotate_data_stats: Mock whose `call_count` must equal the
        expected number of statistics artifacts (2). Presumably injected by a
        patch decorator that is not visible in this view.
      mock_annotate_eval_results: Mock whose `call_count` must equal the
        expected number of evaluation artifacts (1). Same caveat as above.
    """
    expected_stats_calls = 2
    expected_eval_calls = 1
    out_dir = self.tmpdir

    mlmd_source = src.MlmdSource(
        store=testdata_utils.get_tfx_pipeline_metadata_store(self.tmp_db_path),
        model_uri=testdata_utils.TFX_0_21_MODEL_URI)
    toolkit = model_card_toolkit.ModelCardToolkit(
        output_dir=out_dir, mlmd_source=mlmd_source)
    card = toolkit.scaffold_assets()

    self.assertIsNotNone(card.model_details.name)
    self.assertIsNotNone(card.model_details.version.name)

    # (sub-directory of the output dir, template expected inside it)
    expected_templates = (
        ('template/html', 'default_template.html.jinja'),
        ('template/md', 'default_template.md.jinja'),
    )
    for subdir, template_name in expected_templates:
        self.assertIn(template_name,
                      os.listdir(os.path.join(out_dir, subdir)))

    self.assertEqual(mock_annotate_data_stats.call_count,
                     expected_stats_calls)
    self.assertEqual(mock_annotate_eval_results.call_count,
                     expected_eval_calls)
def test_export_format_before_scaffold_assets(self):
    """Calling `export_format()` before `scaffold_assets()` raises ValueError."""
    # Build the toolkit outside the assertion context so that a ValueError
    # raised by the constructor fails the test instead of satisfying it; only
    # the export_format() call is expected to raise.
    mct = model_card_toolkit.ModelCardToolkit()
    with self.assertRaises(ValueError):
        mct.export_format()
def test_scaffold_assets_with_source(self, output_file_format: str, artifacts: bool):
    """Scaffold a card from explicit TFMA, TFDV and model sources and check it.

    Evaluation results and dataset statistics are first written under the
    test's temporary directory via `self._write_tfma` / `self._write_tfdv`.
    The sources are then built either from MLMD artifacts or from plain paths,
    and the scaffolded card is checked for the filtered performance metrics,
    the filtered dataset graphics and the model path.

    Args:
      output_file_format: Forwarded to `self._write_tfma`.
      artifacts: If true, a metadata store is created and the sources are
        built from artifacts read back from it; otherwise the sources are
        built directly from the filesystem paths.
    """
    if artifacts:
        # Configure the connection to use the `fake_database` backend.
        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.fake_database.SetInParent()
        mlmd_store = mlmd.MetadataStore(connection_config)
    else:
        mlmd_store = None

    train_dataset_name = 'Dataset-Split-train'
    train_features = ['feature_name1']
    eval_dataset_name = 'Dataset-Split-eval'
    eval_features = ['feature_name2']

    tfma_path = os.path.join(self.tmpdir, 'tfma')
    tfdv_path = os.path.join(self.tmpdir, 'tfdv')
    pushed_model_path = os.path.join(self.tmpdir, 'pushed_model')
    # NOTE(review): the helpers presumably also register artifacts in
    # `mlmd_store` when it is not None (they are queried from it below) --
    # confirm against the helper definitions.
    self._write_tfma(tfma_path, output_file_format, mlmd_store)
    self._write_tfdv(tfdv_path, train_dataset_name, train_features,
                     eval_dataset_name, eval_features, mlmd_store)

    if artifacts:
        # Artifact-based sources: read the artifacts back from the store.
        model_evaluation_artifacts = mlmd_store.get_artifacts_by_type(
            standard_artifacts.ModelEvaluation.TYPE_NAME)
        example_statistics_artifacts = mlmd_store.get_artifacts_by_type(
            standard_artifacts.ExampleStatistics.TYPE_NAME)
        pushed_model_artifact = standard_artifacts.PushedModel()
        pushed_model_artifact.uri = pushed_model_path
        tfma_src = src.TfmaSource(
            model_evaluation_artifacts=model_evaluation_artifacts,
            metrics_exclude=['average_loss'])
        tfdv_src = src.TfdvSource(
            example_statistics_artifacts=example_statistics_artifacts,
            features_include=['feature_name1'])
        model_src = src.ModelSource(
            pushed_model_artifact=pushed_model_artifact)
    else:
        # Path-based sources with the same include/exclude filters.
        tfma_src = src.TfmaSource(eval_result_paths=[tfma_path],
                                  metrics_exclude=['average_loss'])
        tfdv_src = src.TfdvSource(dataset_statistics_paths=[tfdv_path],
                                  features_include=['feature_name1'])
        model_src = src.ModelSource(pushed_model_path=pushed_model_path)

    mc = model_card_toolkit.ModelCardToolkit(source=src.Source(
        tfma=tfma_src, tfdv=tfdv_src, model=model_src)).scaffold_assets()

    with self.subTest(name='quantitative_analysis'):
        # Metrics are compared via their proto form.
        list_to_proto = lambda lst: [x.to_proto() for x in lst]
        # 'average_loss' is excluded above, so only the example count is
        # expected to remain.
        expected_performance_metrics = [
            model_card.PerformanceMetric(
                type='post_export_metrics/example_count', value='2.0')
        ]
        self.assertCountEqual(
            list_to_proto(mc.quantitative_analysis.performance_metrics),
            list_to_proto(expected_performance_metrics))
        self.assertLen(mc.quantitative_analysis.graphics.collection, 1)

    with self.subTest(name='model_parameters.data'):
        self.assertLen(mc.model_parameters.data, 2)  # train and eval
        for dataset in mc.model_parameters.data:
            for graphic in dataset.graphics.collection:
                self.assertIsNotNone(
                    graphic.image,
                    msg=f'No image found for graphic: {dataset.name} {graphic.name}'
                )
                graphic.image = None  # ignore graphic.image for below assertions
        # Only 'feature_name1' is included, so the train dataset keeps its
        # graphic while no eval dataset with a 'feature_name2' graphic exists.
        self.assertIn(
            model_card.Dataset(
                name=train_dataset_name,
                graphics=model_card.GraphicsCollection(collection=[
                    model_card.Graphic(name='counts | feature_name1')
                ])), mc.model_parameters.data)
        self.assertNotIn(
            model_card.Dataset(
                name=eval_dataset_name,
                graphics=model_card.GraphicsCollection(collection=[
                    model_card.Graphic(name='counts | feature_name2')
                ])), mc.model_parameters.data)

    with self.subTest(name='model_details.path'):
        self.assertEqual(mc.model_details.path, pushed_model_path)
def test_update_model_card_with_invalid_schema_version(self):
    """Saving a card with schema version '100.0.0' as JSON raises ValueError."""
    toolkit = model_card_toolkit.ModelCardToolkit()
    bad_version_card = model_card_module.ModelCard(schema_version='100.0.0')
    expected_error = 'Cannot find schema version'
    with self.assertRaisesRegex(ValueError, expected_error):
        toolkit.update_model_card_json(bad_version_card)
def test_scaffold_assets_with_empty_source(self):
    """Scaffolding with a default-constructed `Source` completes without raising."""
    # Smoke test: there are no assertions; any exception fails the test.
    empty_source = src.Source()
    toolkit = model_card_toolkit.ModelCardToolkit(source=empty_source)
    toolkit.scaffold_assets()