예제 #1
0
 def testArtifactTypeRoundTrip(self):
     """Artifact classes map back to themselves from their MLMD artifact type.

     For each class, its MLMD artifact type is built and passed to
     artifact_utils.get_artifact_type_class, which must return that exact
     class object.
     """
     for expected_cls in (standard_artifacts.Examples, _MyArtifact):
         round_tripped = artifact_utils.get_artifact_type_class(
             expected_cls._get_artifact_type())
         self.assertIs(expected_cls, round_tripped)
예제 #2
0
 def testArtifactTypeRoundTrip(self):
     """Artifact classes round-trip through their MLMD artifact type.

     Also checks that the artifact type's ``id`` field is ignored when an MLMD
     artifact type is mapped back to its artifact class.
     """
     import copy  # Local import: the file's import block is not visible here.

     mlmd_artifact_type = standard_artifacts.Examples._get_artifact_type()
     self.assertIs(
         standard_artifacts.Examples,
         artifact_utils.get_artifact_type_class(mlmd_artifact_type))
     # Mutate a deep copy, not the object _get_artifact_type() returns.
     # NOTE(review): that method may hand back a cached, class-level proto;
     # confirm. Setting id on it in place would leak id=123 into every later
     # use of _MyArtifact's type, and could make the check below pass
     # trivially if the lookup compared ids. Copying is harmless either way.
     mlmd_artifact_type = copy.deepcopy(_MyArtifact._get_artifact_type())
     # Test that the ID is ignored for type comparison purposes during
     # deserialization.
     mlmd_artifact_type.id = 123
     self.assertIs(
         _MyArtifact,
         artifact_utils.get_artifact_type_class(mlmd_artifact_type))
예제 #3
0
    def testArtifactTypeRoundTripUnknownArtifactClass(self, mock_warning):
        """An unrecognized type name yields a synthesized Artifact subclass.

        A copy of the Examples MLMD type is first checked to resolve to
        Examples. It is then renamed to 'UnknownTypeName'. Resolving it again
        must call mock_warning exactly once. It must also return a new
        Artifact subclass, not Examples, whose TYPE_NAME is the new name and
        whose MLMD type equals the renamed type.

        Args:
          mock_warning: Mock injected by a patch decorator that is not visible
            here. It is expected to be called once for the unknown type.
        """
        examples_cls = standard_artifacts.Examples
        # Work on a copy so that renaming it below cannot alter the object
        # returned by Examples._get_artifact_type().
        mlmd_type = copy.deepcopy(examples_cls._get_artifact_type())
        self.assertIs(examples_cls,
                      artifact_utils.get_artifact_type_class(mlmd_type))

        mlmd_type.name = 'UnknownTypeName'
        synthesized_cls = artifact_utils.get_artifact_type_class(mlmd_type)
        mock_warning.assert_called_once()

        self.assertIsNot(examples_cls, synthesized_cls)
        self.assertTrue(issubclass(synthesized_cls, artifact.Artifact))
        self.assertEqual('UnknownTypeName', synthesized_cls.TYPE_NAME)
        self.assertEqual(mlmd_type, synthesized_cls._get_artifact_type())
예제 #4
0
파일: channel.py 프로젝트: yifanmai/tfx
 def from_json_dict(cls, dict_data: Dict[Text, Any]) -> Any:
     """Rebuilds a Channel from its JSON-compatible dict form.

     Args:
       dict_data: Mapping with these keys:
         - 'type': the JSON-compatible form of an MLMD ArtifactType.
         - 'artifacts': an iterable of artifact JSON dicts, each passed to
           Artifact.from_json_dict.
         - 'producer_component_id' and 'output_key' (optional): each
           defaults to None when absent.

     Returns:
       A Channel whose type is the artifact class resolved from 'type' and
       whose artifacts are deserialized from 'artifacts'. The result is
       always constructed via Channel rather than cls. NOTE(review): a
       subclass calling this gets a plain Channel.

     Raises:
       KeyError: If 'type' or 'artifacts' is missing from dict_data.
       json_format.ParseError: If 'type' is not a valid ArtifactType.
     """
     artifact_type = metadata_store_pb2.ArtifactType()
     # ParseDict fills the proto straight from the dict. This avoids the
     # dict -> JSON string -> dict round trip that json_format.Parse needs.
     json_format.ParseDict(dict_data['type'], artifact_type)
     type_cls = artifact_utils.get_artifact_type_class(artifact_type)
     artifacts = [Artifact.from_json_dict(a) for a in dict_data['artifacts']]
     return Channel(
         type=type_cls,
         artifacts=artifacts,
         producer_component_id=dict_data.get('producer_component_id'),
         output_key=dict_data.get('output_key'))