def setUp(self):
  """Opens the TFX pipeline MLMD store shared by the tests.

  The database path is kept on `self.tmp_db_path` because the test methods
  open their own store handles from that attribute. When the path was only a
  local variable, every `self.tmp_db_path` lookup raised AttributeError.
  """
  super(TestDataUtilsTest, self).setUp()
  # Stored on the instance (not a local) so tests can call
  # `testdata_utils.get_tfx_pipeline_metadata_store(self.tmp_db_path)`.
  self.tmp_db_path = os.path.join(absltest.get_default_test_tmpdir(),
                                  'test.db')
  self.store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  self.assertIsNotNone(self.store)
def test_scaffold_assets_with_store(self):
  """Scaffolds assets from a store passed via `mlmd_store` and `model_uri`.

  Checks three things:
  - Model details are populated.
  - The graphics include `average_loss` and
    `post_export_metrics/example_count`.
  - The default HTML and Markdown templates land in the output directory.
  """
  out_dir = self.tmpdir
  toolkit = model_card_toolkit.ModelCardToolkit(
      output_dir=out_dir,
      mlmd_store=testdata_utils.get_tfx_pipeline_metadata_store(
          self.tmp_db_path),
      model_uri=testdata_utils.TFX_0_21_MODEL_URI)
  card = toolkit.scaffold_assets()

  self.assertIsNotNone(card.model_details.name)
  self.assertIsNotNone(card.model_details.version.name)

  graphics = card.quantitative_analysis.graphics.collection
  self.assertNotEmpty(graphics)
  # Collect the names once and reuse them for both membership checks.
  graphic_names = {graphic.name for graphic in graphics}
  self.assertIn('average_loss', graphic_names)
  self.assertIn('post_export_metrics/example_count', graphic_names)

  self.assertIn('default_template.html.jinja',
                os.listdir(os.path.join(out_dir, 'template/html')))
  self.assertIn('default_template.md.jinja',
                os.listdir(os.path.join(out_dir, 'template/md')))
def test_export_format(self):
  """Exports a renamed model card scaffolded from an `MlmdSource`.

  After updating and exporting, the proto at `data/model_card.proto` must
  carry the new name. The HTML written to `model_cards/model_card.html` must
  equal the string returned by `export_format`.
  """
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  toolkit = model_card_toolkit.ModelCardToolkit(
      output_dir=self.tmpdir,
      mlmd_source=src.MlmdSource(
          store=metadata_store,
          model_uri=testdata_utils.TFX_0_21_MODEL_URI))
  card = toolkit.scaffold_assets()
  card.model_details.name = 'My Model'
  toolkit.update_model_card(card)
  exported_html = toolkit.export_format()

  # The serialized proto must reflect the updated name.
  proto_file = os.path.join(self.tmpdir, 'data/model_card.proto')
  self.assertTrue(os.path.exists(proto_file))
  parsed_card = model_card_pb2.ModelCard()
  with open(proto_file, 'rb') as handle:
    parsed_card.ParseFromString(handle.read())
  self.assertEqual(parsed_card.model_details.name, 'My Model')

  # The rendered HTML on disk must match the returned string.
  html_file = os.path.join(self.tmpdir, 'model_cards/model_card.html')
  self.assertTrue(os.path.exists(html_file))
  with open(html_file) as handle:
    written_html = handle.read()
  self.assertEqual(written_html, exported_html)
  self.assertTrue(written_html.startswith('<!DOCTYPE html>'))
  self.assertIn('My Model', written_html)
def test_get_stats_artifacts_for_model(self):
  """The TFX 0.21 model resolves to exactly one statistics artifact id."""
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  stats_artifacts = tfx_util.get_stats_artifacts_for_model(
      metadata_store, testdata_utils.TFX_0_21_MODEL_ARTIFACT_ID)
  self.assertCountEqual(
      [artifact.id for artifact in stats_artifacts],
      [testdata_utils.TFX_0_21_STATS_ARTIFACT_ID])
def test_filter_features(self):
  """Covers include/exclude filtering and both invalid argument combos.

  Passing both `features_include` and `features_exclude`, or neither, must
  raise ValueError.
  """
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  stats_artifacts = metadata_store.get_artifacts_by_id(
      [testdata_utils.TFX_0_21_STATS_ARTIFACT_ID])
  # First dataset of the first stats proto read from the last artifact's URI.
  dataset_stats = tfx_util.read_stats_protos(
      stats_artifacts[-1].uri)[0].datasets[0]
  # Two disjoint slices of the known feature list, split at index 27.
  first_slice = _DATASET_FEATURES[:27]
  second_slice = _DATASET_FEATURES[27:]

  def kept_feature_names(**filter_kwargs):
    # Name (first path step) of every feature kept by `filter_features`.
    filtered = tfx_util.filter_features(dataset_stats, **filter_kwargs)
    return [feature.path.step[0] for feature in filtered.features]

  with self.subTest(name='features_include'):
    self.assertSameElements(
        first_slice, kept_feature_names(features_include=first_slice))

  with self.subTest(name='features_exclude'):
    self.assertSameElements(
        second_slice, kept_feature_names(features_exclude=first_slice))

  with self.subTest(
      name='both features_include and features_exclude (invalid)'):
    with self.assertRaises(ValueError):
      tfx_util.filter_features(
          dataset_stats,
          features_include=first_slice,
          features_exclude=second_slice)

  with self.subTest(
      name='neither features_include nor features_exclude (invalid)'):
    with self.assertRaises(ValueError):
      tfx_util.filter_features(dataset_stats)
def test_read_stats_proto_with_invalid_split(self):
  """Reading statistics for the split name 'invalid_split' yields None."""
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  stats_artifacts = metadata_store.get_artifacts_by_id(
      [testdata_utils.TFX_0_21_STATS_ARTIFACT_ID])
  self.assertLen(stats_artifacts, 1)
  self.assertIsNone(
      tfx_util.read_stats_proto(stats_artifacts[-1].uri, 'invalid_split'))
def test_read_stats_protos(self):
  """Reading all stats protos from the single stats artifact yields two."""
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  stats_artifacts = metadata_store.get_artifacts_by_id(
      [testdata_utils.TFX_0_21_STATS_ARTIFACT_ID])
  self.assertLen(stats_artifacts, 1)
  all_split_stats = tfx_util.read_stats_protos(stats_artifacts[-1].uri)
  # Expected splits: Split-eval and Split-train.
  self.assertLen(all_split_stats, 2)
def test_read_metrics_eval_result(self):
  """An eval result can be loaded from the last metrics artifact's URI."""
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  metrics_artifacts = metadata_store.get_artifacts_by_id(
      testdata_utils.TFX_0_21_METRICS_ARTIFACT_IDS)
  last_metrics_uri = metrics_artifacts[-1].uri
  self.assertIsNotNone(tfx_util.read_metrics_eval_result(last_metrics_uri))
def test_init_with_store_no_model_uri(self):
  """Passing `mlmd_store` without `model_uri` raises ValueError."""
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  # Matched as a regular expression by assertRaisesRegex.
  expected_error = 'If `mlmd_store` is set, `model_uri` should be set.'
  with self.assertRaisesRegex(ValueError, expected_error):
    model_card_toolkit.ModelCardToolkit(
        output_dir=self.tmpdir, mlmd_store=metadata_store)
def test_init_with_store_model_uri_not_found(self):
  """A model URI that is absent from the store raises ValueError."""
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  unknown_model = 'unknown_model'
  expected_error = f'"{unknown_model}" cannot be found in the `store`'
  # Both constructors stay inside the context: either may raise the error.
  with self.assertRaisesRegex(ValueError, expected_error):
    model_card_toolkit.ModelCardToolkit(
        mlmd_source=src.MlmdSource(
            store=metadata_store, model_uri=unknown_model))
def test_read_stats_proto(self):
  """Statistics can be read for both the 'train' and 'eval' splits."""
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  stats_artifacts = metadata_store.get_artifacts_by_id(
      [testdata_utils.TFX_0_21_STATS_ARTIFACT_ID])
  self.assertLen(stats_artifacts, 1)
  stats_uri = stats_artifacts[-1].uri
  # Same order as before: train first, then eval.
  for split_name in ('train', 'eval'):
    self.assertIsNotNone(tfx_util.read_stats_proto(stats_uri, split_name))
def test_generate_model_card_for_model(self):
  """Model details are taken from the trainer execution's properties.

  The card's name, version name and references must match the trainer's
  `module_file`, `checksum_md5` and `pipeline_name` properties. The model's
  dataset artifact must also be retrievable from the store.
  """
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  card = tfx_util.generate_model_card_for_model(
      metadata_store, testdata_utils.TFX_0_21_MODEL_ARTIFACT_ID)
  trainer_executions = metadata_store.get_executions_by_id(
      [testdata_utils.TFX_0_21_TRAINER_ID])
  self.assertNotEmpty(trainer_executions)

  trainer_props = trainer_executions[-1].properties
  details = card.model_details
  self.assertEqual(details.name, trainer_props['module_file'].string_value)
  self.assertEqual(details.version.name,
                   trainer_props['checksum_md5'].string_value)
  self.assertIn(trainer_props['pipeline_name'].string_value,
                [ref.reference for ref in details.references])

  dataset_artifacts = metadata_store.get_artifacts_by_id(
      [testdata_utils.TFX_0_21_MODEL_DATASET_ID])
  self.assertNotEmpty(dataset_artifacts)
def test_export_format(self):
  """Exports HTML after updating the scaffolded card with a new name.

  The card is updated through `update_model_card_json`. The HTML written to
  `model_cards/model_card.html` must equal the string returned by
  `export_format` and contain the new model name.
  """
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  toolkit = model_card_toolkit.ModelCardToolkit(
      output_dir=self.tmpdir,
      mlmd_store=metadata_store,
      model_uri=testdata_utils.TFX_0_21_MODEL_URI)
  card = toolkit.scaffold_assets()
  card.schema_version = '0.0.1'
  card.model_details.name = 'My Model'
  toolkit.update_model_card_json(card)
  exported_html = toolkit.export_format()

  html_file = os.path.join(self.tmpdir, 'model_cards/model_card.html')
  self.assertTrue(os.path.exists(html_file))
  with open(html_file) as handle:
    written_html = handle.read()
  self.assertEqual(written_html, exported_html)
  self.assertTrue(written_html.startswith('<!DOCTYPE html>'))
  self.assertIn('My Model', written_html)
def test_scaffold_assets_with_store(self, mock_annotate_data_stats,
                                    mock_annotate_eval_results):
  """Scaffolding from an `MlmdSource` annotates stats and eval artifacts.

  NOTE(review): the two mock arguments are presumably injected by patch
  decorators that are not visible here. Confirm their order against the
  decorators.
  """
  expected_stats_calls = 2
  expected_eval_calls = 1
  out_dir = self.tmpdir
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  source = src.MlmdSource(
      store=metadata_store, model_uri=testdata_utils.TFX_0_21_MODEL_URI)
  toolkit = model_card_toolkit.ModelCardToolkit(
      output_dir=out_dir, mlmd_source=source)
  card = toolkit.scaffold_assets()

  self.assertIsNotNone(card.model_details.name)
  self.assertIsNotNone(card.model_details.version.name)
  self.assertIn('default_template.html.jinja',
                os.listdir(os.path.join(out_dir, 'template/html')))
  self.assertIn('default_template.md.jinja',
                os.listdir(os.path.join(out_dir, 'template/md')))
  self.assertEqual(mock_annotate_data_stats.call_count, expected_stats_calls)
  self.assertEqual(mock_annotate_eval_results.call_count, expected_eval_calls)
def test_get_stats_artifacts_for_model_with_invalid_model(self):
  """A non-Model artifact id (TFX_0_21_MODEL_DATASET_ID) raises ValueError."""
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  non_model_id = testdata_utils.TFX_0_21_MODEL_DATASET_ID
  with self.assertRaisesRegex(ValueError, 'not an instance of Model'):
    tfx_util.get_stats_artifacts_for_model(metadata_store, non_model_id)
def test_get_stats_artifacts_for_model_with_model_not_found(self):
  """The id of a fresh, unsaved Artifact proto raises ValueError."""
  metadata_store = testdata_utils.get_tfx_pipeline_metadata_store(
      self.tmp_db_path)
  # The id field of a newly constructed Artifact proto is never assigned.
  unsaved_id = metadata_store_pb2.Artifact().id
  with self.assertRaisesRegex(ValueError, 'model_id cannot be found'):
    tfx_util.get_stats_artifacts_for_model(metadata_store, unsaved_id)