示例#1
0
 def test_copy_from_proto_with_invalid_proto(self):
     """copy_from_proto raises TypeError naming the expected and provided proto types."""
     import re  # Used to match the error message literally.

     owner = model_card.Owner()
     wrong_proto = model_card_pb2.Version()
     # assertRaisesRegex interprets its argument as a regular expression, so the
     # unescaped dots in this literal message would match any character. Escape
     # it so the assertion pins the exact text.
     expected_message = (
         "<class 'model_card_toolkit.proto.model_card_pb2.Owner'> is expected. "
         "However <class 'model_card_toolkit.proto.model_card_pb2.Version'> is "
         "provided.")
     with self.assertRaisesRegex(TypeError, re.escape(expected_message)):
         owner.copy_from_proto(wrong_proto)
示例#2
0
    def test_merge_from_proto_sucess(self):
        """merge_from_proto folds a proto's contents into an existing object.

        Scalar fields from the proto are combined with fields already set, and
        repeated owners from the proto are appended after the existing ones.
        """
        # Scalar fields: the proto supplies `contact`, the object already has `name`.
        merged_owner = model_card.Owner(name="my_name1")
        incoming_owner = model_card_pb2.Owner(contact="my_contact1")
        merged_owner.merge_from_proto(incoming_owner)
        expected_owner = model_card.Owner(name="my_name1", contact="my_contact1")
        self.assertEqual(merged_owner, expected_owner)

        # Nested messages: the proto's owner is added alongside the existing one.
        merged_details = model_card.ModelDetails(
            owners=[model_card.Owner(name="my_name1")])
        incoming_details = model_card_pb2.ModelDetails(
            owners=[model_card_pb2.Owner(name="my_name2", contact="my_contact2")])
        merged_details.merge_from_proto(incoming_details)
        expected_owners = [
            model_card.Owner(name="my_name1"),
            model_card.Owner(name="my_name2", contact="my_contact2"),
        ]
        self.assertEqual(merged_details,
                         model_card.ModelDetails(owners=expected_owners))
示例#3
0
    def test_to_proto_sucess(self):
        """to_proto mirrors whatever fields have been set on the model_card object."""
        # Scalar fields: each assignment should show up in the next conversion.
        owner = model_card.Owner()
        actual = owner.to_proto()
        self.assertEqual(actual, model_card_pb2.Owner())

        owner.name = "my_name"
        actual = owner.to_proto()
        self.assertEqual(actual, model_card_pb2.Owner(name="my_name"))

        owner.contact = "my_contact"
        actual = owner.to_proto()
        self.assertEqual(
            actual, model_card_pb2.Owner(name="my_name", contact="my_contact"))

        # Nested messages: repeated owners are converted; the expected proto
        # also carries an empty `version` message.
        details = model_card.ModelDetails(
            owners=[model_card.Owner(name="my_name", contact="my_contact")])
        actual_details = details.to_proto()
        expected_details = model_card_pb2.ModelDetails(
            owners=[model_card_pb2.Owner(name="my_name", contact="my_contact")],
            version=model_card_pb2.Version())
        self.assertEqual(actual_details, expected_details)
示例#4
0
 def test_to_proto_with_invalid_field(self):
     """to_proto raises ValueError for an attribute the proto does not define."""
     owner = model_card.Owner()
     # Attach an attribute with no counterpart field in the Owner proto.
     setattr(owner, "wrong_field", "wrong")
     expected_pattern = "has no such field named 'wrong_field'."
     with self.assertRaisesRegex(ValueError, expected_pattern):
         owner.to_proto()
示例#5
0
 def test_merge_from_proto_with_invalid_proto(self):
     """merge_from_proto raises TypeError when handed a proto of a different type."""
     target = model_card.Owner()
     mismatched = model_card_pb2.Version()
     # The error text should mention the expected (Owner) and received (Version) types.
     pattern = ".*expected .*Owner got .*Version.*"
     with self.assertRaisesRegex(TypeError, pattern):
         target.merge_from_proto(mismatched)