def test_update_model_card_with_valid_model_card_as_proto(self):
        """Checks that update_model_card persists a proto ModelCard unchanged.

        A ModelCard proto with only model_details.name set is handed to
        ModelCardToolkit.update_model_card. The proto read back from
        <tmpdir>/data/model_card.proto must equal the input.
        """
        expected_card = model_card_pb2.ModelCard()
        expected_card.model_details.name = 'My Model'

        toolkit = model_card_toolkit.ModelCardToolkit(output_dir=self.tmpdir)
        toolkit.update_model_card(expected_card)

        # Read the serialized bytes back and decode them into a fresh proto.
        written_path = os.path.join(self.tmpdir, 'data/model_card.proto')
        with open(written_path, 'rb') as proto_file:
            serialized = proto_file.read()
        roundtripped_card = model_card_pb2.ModelCard()
        roundtripped_card.ParseFromString(serialized)

        self.assertEqual(roundtripped_card, expected_card)
Beispiel #2
0
    def test_merge_from_proto_and_to_proto_with_all_fields(self):
        """Round-trips _FULL_PROTO through merge_from_proto and to_proto."""
        expected = text_format.Parse(_FULL_PROTO, model_card_pb2.ModelCard())

        py_card = model_card.ModelCard()
        py_card.merge_from_proto(expected)

        self.assertEqual(expected, py_card.to_proto())

    def test_export_format(self):
        """Checks that export_format writes both the proto and the HTML card.

        A toolkit backed by the TFX test MLMD store scaffolds a card. The test
        then sets the model name, updates the card and exports it. The proto
        under data/ must carry the new name. The HTML under model_cards/ must
        equal export_format's return value, start with a doctype and mention
        the model name.
        """
        mlmd_store = testdata_utils.get_tfx_pipeline_metadata_store(
            self.tmp_db_path)
        source = src.MlmdSource(
            store=mlmd_store, model_uri=testdata_utils.TFX_0_21_MODEL_URI)
        toolkit = model_card_toolkit.ModelCardToolkit(
            output_dir=self.tmpdir, mlmd_source=source)

        card = toolkit.scaffold_assets()
        card.model_details.name = 'My Model'
        toolkit.update_model_card(card)
        exported_html = toolkit.export_format()

        # The serialized proto must exist and reflect the updated name.
        proto_path = os.path.join(self.tmpdir, 'data/model_card.proto')
        self.assertTrue(os.path.exists(proto_path))
        written_proto = model_card_pb2.ModelCard()
        with open(proto_path, 'rb') as proto_file:
            written_proto.ParseFromString(proto_file.read())
        self.assertEqual(written_proto.model_details.name, 'My Model')

        # The rendered HTML must exist and match what export_format returned.
        html_path = os.path.join(self.tmpdir, 'model_cards/model_card.html')
        self.assertTrue(os.path.exists(html_path))
        with open(html_path) as html_file:
            rendered = html_file.read()
        self.assertEqual(rendered, exported_html)
        self.assertTrue(rendered.startswith('<!DOCTYPE html>'))
        self.assertIn('My Model', rendered)
Beispiel #4
0
 def _read_proto_file(self, path: str) -> Optional[ModelCard]:
   """Read a serialized model card proto from `path`.

   Args:
     path: Filesystem path of a binary-serialized `model_card_pb2.ModelCard`.

   Returns:
     A `ModelCard` populated via `copy_from_proto`, or None if no file exists
     at `path`.
   """
   # Open directly instead of checking os.path.exists first. With
   # check-then-open, a file removed between the two calls makes this raise
   # FileNotFoundError instead of returning None.
   try:
     with open(path, 'rb') as f:
       serialized = f.read()
   except FileNotFoundError:
     return None
   model_card_proto = model_card_pb2.ModelCard()
   model_card_proto.ParseFromString(serialized)
   return ModelCard().copy_from_proto(model_card_proto)
Beispiel #5
0
    def test_from_json_to_proto(self):
        """Checks that _FULL_JSON, loaded via _from_json, becomes _FULL_PROTO."""
        parsed_json = json.loads(_FULL_JSON)
        py_card = model_card.ModelCard()._from_json(parsed_json)
        proto_from_json = py_card.to_proto()

        expected_proto = text_format.Parse(
            _FULL_PROTO, model_card_pb2.ModelCard())
        self.assertEqual(expected_proto, proto_from_json)
Beispiel #6
0
    def test_from_proto_to_json(self):
        """Checks proto-to-JSON conversion via merge_ and copy_from_proto.

        The same ModelCard instance is reused. The copy_from_proto check
        therefore runs on an instance already populated by merge_from_proto.
        """
        source_proto = text_format.Parse(
            _FULL_PROTO, model_card_pb2.ModelCard())
        py_card = model_card.ModelCard()

        merged_json = py_card.merge_from_proto(source_proto).to_json()
        self.assertJsonEqual(_FULL_JSON, merged_json)

        copied_json = py_card.copy_from_proto(source_proto).to_json()
        self.assertJsonEqual(_FULL_JSON, copied_json)

 def _read_proto_file(self, path: Text) -> ModelCard:
     """Deserialize the model card proto stored at `path` into a ModelCard.

     No existence check is made. Any error raised by `open` (for example
     FileNotFoundError) propagates to the caller.
     """
     with open(path, 'rb') as proto_file:
         payload = proto_file.read()
     parsed = model_card_pb2.ModelCard()
     parsed.ParseFromString(payload)
     return ModelCard().copy_from_proto(parsed)