コード例 #1
0
 def setUp(self):
     """Builds the TFX pipeline MLMD store fixture for each test.

     The store is backed by a 'test.db' file in the default absltest tmpdir
     and is exposed to tests as `self.store`.
     """
     super(TestDataUtilsTest, self).setUp()
     db_file = os.path.join(absltest.get_default_test_tmpdir(), 'test.db')
     store = testdata_utils.get_tfx_pipeline_metadata_store(db_file)
     self.store = store
     self.assertIsNotNone(store)
コード例 #2
0
 def test_scaffold_assets_with_store(self):
     """Scaffolds a model card from an MLMD store and checks its contents.

     Verifies that model name/version are set, that the 'average_loss' and
     'post_export_metrics/example_count' graphics are present, and that the
     default HTML and Markdown templates are written under output_dir.
     """
     output_dir = self.tmpdir
     store = testdata_utils.get_tfx_pipeline_metadata_store(
         self.tmp_db_path)
     toolkit = model_card_toolkit.ModelCardToolkit(
         output_dir=output_dir,
         mlmd_store=store,
         model_uri=testdata_utils.TFX_0_21_MODEL_URI)
     card = toolkit.scaffold_assets()
     self.assertIsNotNone(card.model_details.name)
     self.assertIsNotNone(card.model_details.version.name)

     graphics = card.quantitative_analysis.graphics.collection
     self.assertNotEmpty(graphics)
     graphic_names = {graphic.name for graphic in graphics}
     for metric_name in ('average_loss', 'post_export_metrics/example_count'):
         self.assertIn(metric_name, graphic_names)

     for template_dir, template_name in (
             ('template/html', 'default_template.html.jinja'),
             ('template/md', 'default_template.md.jinja')):
         self.assertIn(template_name,
                       os.listdir(os.path.join(output_dir, template_dir)))
コード例 #3
0
    def test_export_format(self):
        """Exports an updated model card and checks the proto and HTML outputs.

        After scaffolding, the model name is set to 'My Model'. The test then
        checks that the serialized ModelCard proto at data/model_card.proto
        carries that name, and that model_cards/model_card.html matches the
        value returned by export_format().
        """
        mlmd_source = src.MlmdSource(
            store=testdata_utils.get_tfx_pipeline_metadata_store(
                self.tmp_db_path),
            model_uri=testdata_utils.TFX_0_21_MODEL_URI)
        mct = model_card_toolkit.ModelCardToolkit(
            output_dir=self.tmpdir, mlmd_source=mlmd_source)
        model_card = mct.scaffold_assets()
        model_card.model_details.name = 'My Model'
        mct.update_model_card(model_card)
        exported_html = mct.export_format()

        proto_path = os.path.join(self.tmpdir, 'data/model_card.proto')
        self.assertTrue(os.path.exists(proto_path))
        parsed_card = model_card_pb2.ModelCard()
        with open(proto_path, 'rb') as proto_file:
            parsed_card.ParseFromString(proto_file.read())
        self.assertEqual(parsed_card.model_details.name, 'My Model')

        html_path = os.path.join(self.tmpdir, 'model_cards/model_card.html')
        self.assertTrue(os.path.exists(html_path))
        with open(html_path) as html_file:
            html = html_file.read()
        self.assertEqual(html, exported_html)
        self.assertTrue(html.startswith('<!DOCTYPE html>'))
        self.assertIn('My Model', html)
コード例 #4
0
 def test_get_stats_artifacts_for_model(self):
   """Checks the fixture model resolves to exactly its one stats artifact."""
   store = testdata_utils.get_tfx_pipeline_metadata_store(self.tmp_db_path)
   expected_ids = [testdata_utils.TFX_0_21_STATS_ARTIFACT_ID]
   stats_artifacts = tfx_util.get_stats_artifacts_for_model(
       store, testdata_utils.TFX_0_21_MODEL_ARTIFACT_ID)
   self.assertCountEqual([artifact.id for artifact in stats_artifacts],
                         expected_ids)
コード例 #5
0
  def test_filter_features(self):
    """Exercises tfx_util.filter_features include/exclude modes and misuse.

    _DATASET_FEATURES is split at index 27. Including the first part must
    yield the same set of features, and excluding it must yield the rest.
    Passing both filters, or neither, must raise ValueError.
    """
    store = testdata_utils.get_tfx_pipeline_metadata_store(self.tmp_db_path)
    stats_artifacts = store.get_artifacts_by_id(
        [testdata_utils.TFX_0_21_STATS_ARTIFACT_ID])
    dataset_stats = tfx_util.read_stats_protos(
        stats_artifacts[-1].uri)[0].datasets[0]

    first_part = _DATASET_FEATURES[:27]
    second_part = _DATASET_FEATURES[27:]

    def feature_names(**filter_kwargs):
      # First path step of every feature kept by filter_features.
      filtered = tfx_util.filter_features(dataset_stats, **filter_kwargs)
      return [feature.path.step[0] for feature in filtered.features]

    with self.subTest(name='features_include'):
      self.assertSameElements(
          first_part, feature_names(features_include=first_part))
    with self.subTest(name='features_exclude'):
      self.assertSameElements(
          second_part, feature_names(features_exclude=first_part))
    with self.subTest(
        name='both features_include and features_exclude (invalid)'):
      with self.assertRaises(ValueError):
        tfx_util.filter_features(
            dataset_stats,
            features_include=first_part,
            features_exclude=second_part)
    with self.subTest(
        name='neither features_include nor features_exclude (invalid)'):
      with self.assertRaises(ValueError):
        tfx_util.filter_features(dataset_stats)
コード例 #6
0
 def test_read_stats_proto_with_invalid_split(self):
   """Checks read_stats_proto returns None for an unknown split name."""
   store = testdata_utils.get_tfx_pipeline_metadata_store(self.tmp_db_path)
   stats_artifacts = store.get_artifacts_by_id(
       [testdata_utils.TFX_0_21_STATS_ARTIFACT_ID])
   self.assertLen(stats_artifacts, 1)
   self.assertIsNone(
       tfx_util.read_stats_proto(stats_artifacts[-1].uri, 'invalid_split'))
コード例 #7
0
 def test_read_stats_protos(self):
   """Checks read_stats_protos returns two stats protos for the fixture."""
   store = testdata_utils.get_tfx_pipeline_metadata_store(self.tmp_db_path)
   stats_artifacts = store.get_artifacts_by_id(
       [testdata_utils.TFX_0_21_STATS_ARTIFACT_ID])
   self.assertLen(stats_artifacts, 1)
   split_stats = tfx_util.read_stats_protos(stats_artifacts[-1].uri)
   # One entry each for Split-eval and Split-train.
   self.assertLen(split_stats, 2)
コード例 #8
0
 def test_read_metrics_eval_result(self):
     """Checks the last fixture metrics artifact yields a non-None eval result.

     Fails with an assertion, rather than an IndexError, if the store returns
     no artifacts for the fixture's metrics artifact IDs.
     """
     store = testdata_utils.get_tfx_pipeline_metadata_store(
         self.tmp_db_path)
     metrics = store.get_artifacts_by_id(
         testdata_utils.TFX_0_21_METRICS_ARTIFACT_IDS)
     # Guard the [-1] index below so a missing fixture produces a clear
     # assertion failure instead of an IndexError.
     self.assertNotEmpty(metrics)
     eval_result = tfx_util.read_metrics_eval_result(metrics[-1].uri)
     self.assertIsNotNone(eval_result)
コード例 #9
0
 def test_init_with_store_no_model_uri(self):
     """Passing `mlmd_store` without `model_uri` must raise ValueError."""
     store = testdata_utils.get_tfx_pipeline_metadata_store(
         self.tmp_db_path)
     expected_message = 'If `mlmd_store` is set, `model_uri` should be set.'
     with self.assertRaisesRegex(ValueError, expected_message):
         model_card_toolkit.ModelCardToolkit(
             output_dir=self.tmpdir, mlmd_store=store)
コード例 #10
0
 def test_init_with_store_model_uri_not_found(self):
     """A model_uri missing from the store must raise ValueError."""
     store = testdata_utils.get_tfx_pipeline_metadata_store(
         self.tmp_db_path)
     unknown_model = 'unknown_model'
     expected_message = f'"{unknown_model}" cannot be found in the `store`'
     with self.assertRaisesRegex(ValueError, expected_message):
         model_card_toolkit.ModelCardToolkit(
             mlmd_source=src.MlmdSource(store=store, model_uri=unknown_model))
コード例 #11
0
 def test_read_stats_proto(self):
   """Checks read_stats_proto finds both the 'train' and 'eval' splits."""
   store = testdata_utils.get_tfx_pipeline_metadata_store(self.tmp_db_path)
   stats_artifacts = store.get_artifacts_by_id(
       [testdata_utils.TFX_0_21_STATS_ARTIFACT_ID])
   self.assertLen(stats_artifacts, 1)
   stats_uri = stats_artifacts[-1].uri
   for split_name in ('train', 'eval'):
     self.assertIsNotNone(tfx_util.read_stats_proto(stats_uri, split_name))
コード例 #12
0
  def test_generate_model_card_for_model(self):
    """Checks model details generated from the fixture's trainer execution.

    The card's name, version name and references are compared against the
    trainer execution's `module_file`, `checksum_md5` and `pipeline_name`
    properties. The model's dataset artifact must also be retrievable.
    """
    store = testdata_utils.get_tfx_pipeline_metadata_store(self.tmp_db_path)
    model_card = tfx_util.generate_model_card_for_model(
        store, testdata_utils.TFX_0_21_MODEL_ARTIFACT_ID)
    trainers = store.get_executions_by_id([testdata_utils.TFX_0_21_TRAINER_ID])
    self.assertNotEmpty(trainers)
    trainer_props = trainers[-1].properties
    details = model_card.model_details
    self.assertEqual(details.name, trainer_props['module_file'].string_value)
    self.assertEqual(details.version.name,
                     trainer_props['checksum_md5'].string_value)
    references = [ref.reference for ref in details.references]
    self.assertIn(trainer_props['pipeline_name'].string_value, references)

    datasets = store.get_artifacts_by_id(
        [testdata_utils.TFX_0_21_MODEL_DATASET_ID])
    self.assertNotEmpty(datasets)
コード例 #13
0
    def test_export_format(self):
        """Exports an updated model card and checks the rendered HTML file.

        The scaffolded card gets schema version '0.0.1' and the name
        'My Model', is pushed back via update_model_card_json and is then
        exported. model_cards/model_card.html must equal export_format()'s
        return value.
        """
        mct = model_card_toolkit.ModelCardToolkit(
            output_dir=self.tmpdir,
            mlmd_store=testdata_utils.get_tfx_pipeline_metadata_store(
                self.tmp_db_path),
            model_uri=testdata_utils.TFX_0_21_MODEL_URI)
        card = mct.scaffold_assets()
        card.schema_version = '0.0.1'
        card.model_details.name = 'My Model'
        mct.update_model_card_json(card)
        exported_html = mct.export_format()

        html_path = os.path.join(self.tmpdir, 'model_cards/model_card.html')
        self.assertTrue(os.path.exists(html_path))
        with open(html_path) as html_file:
            html = html_file.read()
        self.assertEqual(html, exported_html)
        self.assertTrue(html.startswith('<!DOCTYPE html>'))
        self.assertIn('My Model', html)
コード例 #14
0
 def test_scaffold_assets_with_store(self, mock_annotate_data_stats,
                                     mock_annotate_eval_results):
     """Scaffolds assets from an MLMD source and counts annotator calls.

     The two mock arguments are presumably injected by patch decorators not
     visible here. The data-stats mock is expected to be called twice and
     the eval-results mock once. The default HTML and Markdown templates
     must also be written under output_dir.
     """
     expected_stat_calls = 2
     expected_eval_calls = 1
     output_dir = self.tmpdir
     store = testdata_utils.get_tfx_pipeline_metadata_store(
         self.tmp_db_path)
     source = src.MlmdSource(
         store=store, model_uri=testdata_utils.TFX_0_21_MODEL_URI)
     toolkit = model_card_toolkit.ModelCardToolkit(
         output_dir=output_dir, mlmd_source=source)
     card = toolkit.scaffold_assets()
     self.assertIsNotNone(card.model_details.name)
     self.assertIsNotNone(card.model_details.version.name)
     for template_dir, template_name in (
             ('template/html', 'default_template.html.jinja'),
             ('template/md', 'default_template.md.jinja')):
         self.assertIn(template_name,
                       os.listdir(os.path.join(output_dir, template_dir)))
     self.assertEqual(mock_annotate_data_stats.call_count,
                      expected_stat_calls)
     self.assertEqual(mock_annotate_eval_results.call_count,
                      expected_eval_calls)
コード例 #15
0
 def test_get_stats_artifacts_for_model_with_invalid_model(self):
   """Passing the fixture's dataset artifact ID (not a Model) must fail."""
   store = testdata_utils.get_tfx_pipeline_metadata_store(self.tmp_db_path)
   dataset_id = testdata_utils.TFX_0_21_MODEL_DATASET_ID
   with self.assertRaisesRegex(ValueError, 'not an instance of Model'):
     tfx_util.get_stats_artifacts_for_model(store, dataset_id)
コード例 #16
0
 def test_get_stats_artifacts_for_model_with_model_not_found(self):
   """An unset model artifact ID must raise ValueError.

   The ID used is the default `id` of a freshly constructed, empty
   metadata_store_pb2.Artifact.
   """
   store = testdata_utils.get_tfx_pipeline_metadata_store(self.tmp_db_path)
   # Build the lookup ID outside the assertRaises context so that only the
   # call under test can satisfy the expected ValueError.
   missing_model_id = metadata_store_pb2.Artifact().id
   with self.assertRaisesRegex(ValueError, 'model_id cannot be found'):
     tfx_util.get_stats_artifacts_for_model(store, missing_model_id)