def create_artifact_with_type(
        self,
        artifact: metadata_store_pb2.Artifact,
        artifact_type: metadata_store_pb2.ArtifactType) -> int:
    """Stores an artifact together with its type.

    The type is put first (the store gets it, or creates it when it does
    not exist yet); the returned type id is then stamped onto a copy of
    `artifact`, so the caller's proto is never modified and any `type_id`
    already set on it is ignored. Finally the copy is put into the store.

    Note that this is not a transaction! The type is written by one call
    and the artifact by a second, separate call.

    Args:
      artifact: the artifact to create (no id or type_id).
      artifact_type: the type of the new artifact (no id).

    Returns:
      The id of the newly stored artifact.

    Raises:
      InvalidArgument: if the type is not the same as one with the same name
        already in the database.
    """
    registered_type_id = self.put_artifact_type(artifact_type)

    typed_artifact = metadata_store_pb2.Artifact()
    typed_artifact.CopyFrom(artifact)
    typed_artifact.type_id = registered_type_id

    stored_ids = self.put_artifacts([typed_artifact])
    # Exactly one artifact was submitted, so exactly one id is expected back.
    [artifact_id] = stored_ids
    return artifact_id
def testArtifactAndEventPairs(self):
    """Checks the (artifact, event) pair produced for one input example."""
    input_example = standard_artifacts.Examples()
    input_example.uri = 'example'
    input_example.id = 1

    # Expected protos, spelled out in text format.
    want_artifact = metadata_store_pb2.Artifact()
    text_format.Parse(
        """ id: 1 type_id: 1 uri: 'example'""", want_artifact)
    want_event = metadata_store_pb2.Event()
    text_format.Parse(
        """ path { steps { key: 'example' } steps { index: 0 } } type: INPUT""",
        want_event)

    artifacts_by_key = {'example': [input_example]}
    with metadata.Metadata(connection_config=self._connection_config) as m:
        pairs = execution_lib._create_artifact_and_event_pairs(
            m, artifacts_by_key, metadata_store_pb2.Event.INPUT)
        self.assertCountEqual([(want_artifact, want_event)], pairs)
def test_put_artifacts_get_artifacts_by_uri(self):
    """A URI lookup returns only the artifact stored under that URI."""
    store = _get_metadata_store()
    example_type = _create_example_artifact_type()
    type_id = store.put_artifact_type(example_type)

    # Two artifacts of the same type that differ only in their URI.
    wanted = metadata_store_pb2.Artifact()
    wanted.type_id = type_id
    wanted.uri = "test_uri"
    decoy = metadata_store_pb2.Artifact()
    decoy.uri = "other_uri"
    decoy.type_id = type_id

    stored_ids = store.put_artifacts([wanted, decoy])
    [wanted_id, _] = stored_ids

    matches = store.get_artifacts_by_uri(wanted.uri)
    self.assertLen(matches, 1)
    self.assertEqual(matches[0].id, wanted_id)
def test_put_artifacts_get_artifact_by_type_and_name(self):
    """An artifact is returned only when type name and name both match."""
    # Prepare test data.
    store = _get_metadata_store()
    artifact_type = _create_example_artifact_type(self._get_test_type_name())
    type_id = store.put_artifact_type(artifact_type)
    stored = metadata_store_pb2.Artifact()
    stored.type_id = type_id
    stored.name = self._get_test_type_name()
    [stored_id] = store.put_artifacts([stored])

    # Test Artifact found case.
    found = store.get_artifact_by_type_and_name(artifact_type.name,
                                                stored.name)
    self.assertEqual(found.id, stored_id)
    self.assertEqual(found.type_id, type_id)
    self.assertEqual(found.name, stored.name)

    # Test Artifact not found cases: wrong type, wrong name, both wrong.
    missing_lookups = (
        ("random_name", stored.name),
        (artifact_type.name, "random_name"),
        ("random_name", "random_name"),
    )
    for type_name, artifact_name in missing_lookups:
        lookup_result = store.get_artifact_by_type_and_name(
            type_name, artifact_name)
        self.assertIsNone(lookup_result)
def deserialize_artifact(
        artifact_type: metadata_store_pb2.ArtifactType,
        artifact: Optional[metadata_store_pb2.Artifact] = None) -> Artifact:
    """Reconstruct Artifact object from MLMD proto descriptors.

    Internal method, no backwards compatibility guarantees.

    Args:
      artifact_type: A metadata_store_pb2.ArtifactType proto object describing
        the type of the artifact.
      artifact: A metadata_store_pb2.Artifact proto object describing the
        contents of the artifact. If not provided (None), an Artifact of the
        desired type with empty contents is created.

    Returns:
      Artifact subclass object for the given MLMD proto descriptors.

    Raises:
      ValueError: if `artifact_type` is not an ArtifactType proto, or if
        `artifact` is given and is not an Artifact proto.
    """
    # Validate inputs.
    if not isinstance(artifact_type, metadata_store_pb2.ArtifactType):
        raise ValueError(
            ('Expected metadata_store_pb2.ArtifactType for artifact_type, got %s '
             'instead') % (artifact_type,))
    # Compare against None explicitly: a truthiness test would let falsy
    # values of the wrong type ('', 0, {}) slip past validation and be
    # silently treated as "no artifact supplied".
    if artifact is not None and not isinstance(
            artifact, metadata_store_pb2.Artifact):
        raise ValueError(
            ('Expected metadata_store_pb2.Artifact for artifact, got %s '
             'instead') % (artifact,))

    # Get the artifact's class and construct the Artifact object.
    artifact_cls = get_artifact_type_class(artifact_type)
    result = artifact_cls()
    result.artifact_type.CopyFrom(artifact_type)
    if artifact is None:
        artifact = metadata_store_pb2.Artifact()
    result.set_mlmd_artifact(artifact)
    return result
def from_json_dict(cls, dict_data: Dict[Text, Any]) -> Any:
    """Rebuilds an artifact object from its JSON-dict form.

    Args:
      dict_data: mapping that holds the keys '__artifact_class_module__',
        '__artifact_class_name__', 'artifact' and 'artifact_type'. The last
        two are JSON-compatible dicts parsed into the Artifact and
        ArtifactType protos.

    Returns:
      An instance of the class named in `dict_data` when that class can be
      imported and is not the base Artifact class; otherwise a generic
      Artifact built from the parsed artifact type. In both cases the parsed
      protos are attached to the result.
    """
    module_name = dict_data['__artifact_class_module__']
    class_name = dict_data['__artifact_class_name__']
    artifact = metadata_store_pb2.Artifact()
    artifact_type = metadata_store_pb2.ArtifactType()
    json_format.Parse(json.dumps(dict_data['artifact']), artifact)
    json_format.Parse(json.dumps(dict_data['artifact_type']), artifact_type)

    # First, try to resolve the specific class used for the artifact; if this
    # is not possible, use a generic artifact.Artifact object.
    result = None
    try:
        artifact_cls = getattr(importlib.import_module(module_name), class_name)
        # If the artifact type is the base Artifact class, do not construct the
        # object here since that constructor requires the mlmd_artifact_type
        # argument.
        if artifact_cls != Artifact:
            result = artifact_cls()
    except (AttributeError, ImportError, ValueError):
        # Arguments are passed separately so formatting only happens when the
        # message is actually emitted.
        absl.logging.warning(
            'Could not load artifact class %s.%s; using fallback deserialization '
            'for the relevant artifact. This behavior may not be supported in '
            'the future; please make sure that any artifact classes can be '
            'imported within your container or environment.',
            module_name, class_name)

    # `None` is the "nothing constructed" sentinel; test identity rather than
    # truthiness so a successfully constructed object is never discarded.
    if result is None:
        result = Artifact(mlmd_artifact_type=artifact_type)
    result.set_mlmd_artifact_type(artifact_type)
    result.set_mlmd_artifact(artifact)
    return result
def __init__(self, type_name: Text, split: Text = ''):
    """Builds the ArtifactType and Artifact protos held by this instance.

    The artifact type declares the properties 'type_name', 'name', 'state',
    'span' and 'split'. The artifact itself starts with only 'type_name' and
    'split' populated; its URI is left unset here.

    Args:
      type_name: Name of underlying ArtifactType.
      split: Which split this instance of artifact maps to.
    """
    declared_properties = (
        ('type_name', metadata_store_pb2.STRING),
        # This is a temporary solution due to b/123435989.
        ('name', metadata_store_pb2.STRING),
        # This indicates the state of an artifact. A state can be any of the
        # followings: PENDING, PUBLISHED, MISSING, DELETING, DELETED
        # TODO(ruoyu): Maybe switch to artifact top-level state if it's
        # supported.
        ('state', metadata_store_pb2.STRING),
        # Span number of an artifact. For the same artifact type produced by
        # the same executor, this number should always increase.
        ('span', metadata_store_pb2.INT),
        # Comma separated splits recognized. Empty string means artifact has
        # no split.
        ('split', metadata_store_pb2.STRING),
    )
    type_proto = metadata_store_pb2.ArtifactType()
    type_proto.name = type_name
    for property_name, property_kind in declared_properties:
        type_proto.properties[property_name] = property_kind
    self.artifact_type = type_proto

    artifact_proto = metadata_store_pb2.Artifact()
    artifact_proto.properties['type_name'].string_value = type_name
    artifact_proto.properties['split'].string_value = split
    self.artifact = artifact_proto
def test_put_duplicated_attributions_and_empty_associations(self):
    """The same attribution put twice yields a single artifact-context link."""
    store = _get_metadata_store()

    # One context.
    context_type_id = store.put_context_type(_create_example_context_type())
    context = metadata_store_pb2.Context()
    context.type_id = context_type_id
    context.name = "context"
    [context.id] = store.put_contexts([context])

    # One artifact.
    artifact_type_id = store.put_artifact_type(_create_example_artifact_type())
    artifact = metadata_store_pb2.Artifact()
    artifact.type_id = artifact_type_id
    artifact.uri = "testuri"
    [artifact.id] = store.put_artifacts([artifact])

    # Submit the identical attribution twice, with no associations at all.
    link = metadata_store_pb2.Attribution()
    link.artifact_id = artifact.id
    link.context_id = context.id
    store.put_attributions_and_associations([link, link], [])

    contexts_of_artifact = store.get_contexts_by_artifact(artifact.id)
    self.assertLen(contexts_of_artifact, 1)
    self.assertEqual(contexts_of_artifact[0].id, context.id)
    self.assertEqual(contexts_of_artifact[0].name, context.name)

    artifacts_of_context = store.get_artifacts_by_context(context.id)
    self.assertLen(artifacts_of_context, 1)
    self.assertEqual(artifacts_of_context[0].uri, artifact.uri)

    self.assertEmpty(store.get_executions_by_context(context.id))
def test_put_events_get_events(self):
    """A stored event can be read back by artifact id and by execution id."""
    store = _get_metadata_store()

    # An execution of a freshly registered execution type.
    exec_type = metadata_store_pb2.ExecutionType()
    exec_type.name = "execution_type"
    exec_type_id = store.put_execution_type(exec_type)
    run = metadata_store_pb2.Execution()
    run.type_id = exec_type_id
    [execution_id] = store.put_executions([run])

    # An artifact of a freshly registered artifact type.
    art_type = metadata_store_pb2.ArtifactType()
    art_type.name = "artifact_type"
    art_type_id = store.put_artifact_type(art_type)
    art = metadata_store_pb2.Artifact()
    art.type_id = art_type_id
    [artifact_id] = store.put_artifacts([art])

    # The event that ties the two together.
    declared_output = metadata_store_pb2.Event()
    declared_output.type = metadata_store_pb2.Event.DECLARED_OUTPUT
    declared_output.artifact_id = artifact_id
    declared_output.execution_id = execution_id
    store.put_events([declared_output])

    [by_artifact] = store.get_events_by_artifact_ids([artifact_id])
    self.assertEqual(by_artifact.artifact_id, artifact_id)
    self.assertEqual(by_artifact.execution_id, execution_id)
    self.assertEqual(by_artifact.type,
                     metadata_store_pb2.Event.DECLARED_OUTPUT)

    [by_execution] = store.get_events_by_execution_ids([execution_id])
    self.assertEqual(by_execution.artifact_id, artifact_id)
    self.assertEqual(by_execution.execution_id, execution_id)
    self.assertEqual(by_execution.type,
                     metadata_store_pb2.Event.DECLARED_OUTPUT)
def _write_tfdv(self,
                tfdv_path: str,
                train_dataset_name: str,
                train_features: List[str],
                eval_dataset_name: str,
                eval_features: List[str],
                store: Optional[mlmd.MetadataStore] = None):
    """Writes train and eval feature statistics files under `tfdv_path`.

    Each split gets a serialized DatasetFeatureStatisticsList written to
    `<tfdv_path>/Split-train/FeatureStats.pb` and
    `<tfdv_path>/Split-eval/FeatureStats.pb` respectively. Every feature
    receives the same three-bucket rank histogram ('a', 'b', 'c').

    Args:
      tfdv_path: directory under which the split directories are created.
      train_dataset_name: name stored on the train statistics.
      train_features: feature names for the train split; each is stored in
        the feature's `name` field.
      eval_dataset_name: name stored on the eval statistics.
      eval_features: feature names for the eval split; each is appended as a
        step of the feature's `path` field.
      store: if given, an artifact with uri `tfdv_path` and the
        ExampleStatistics type is also put into this store.
    """
    a_bucket = statistics_pb2.RankHistogram.Bucket(
        low_rank=0, high_rank=0, label='a', sample_count=4.0)
    b_bucket = statistics_pb2.RankHistogram.Bucket(
        low_rank=1, high_rank=1, label='b', sample_count=3.0)
    c_bucket = statistics_pb2.RankHistogram.Bucket(
        low_rank=2, high_rank=2, label='c', sample_count=2.0)
    buckets = [a_bucket, b_bucket, c_bucket]

    train_stats = statistics_pb2.DatasetFeatureStatistics()
    train_stats.name = train_dataset_name
    for feature in train_features:
        # Fill the entry that was just added. Indexing features[0] here would
        # keep overwriting the first entry and leave the others empty.
        feature_stats = train_stats.features.add()
        feature_stats.name = feature
        feature_stats.string_stats.rank_histogram.buckets.extend(buckets)
    train_stats_list = statistics_pb2.DatasetFeatureStatisticsList(
        datasets=[train_stats])
    train_stats_file = os.path.join(tfdv_path, 'Split-train', 'FeatureStats.pb')
    os.makedirs(os.path.dirname(train_stats_file), exist_ok=True)
    with open(train_stats_file, mode='wb') as f:
        f.write(train_stats_list.SerializeToString())

    eval_stats = statistics_pb2.DatasetFeatureStatistics()
    eval_stats.name = eval_dataset_name
    for feature in eval_features:
        # Same as above, but eval features are identified by path, not name.
        feature_stats = eval_stats.features.add()
        feature_stats.path.step.append(feature)
        feature_stats.string_stats.rank_histogram.buckets.extend(buckets)
    eval_stats_list = statistics_pb2.DatasetFeatureStatisticsList(
        datasets=[eval_stats])
    eval_stats_file = os.path.join(tfdv_path, 'Split-eval', 'FeatureStats.pb')
    os.makedirs(os.path.dirname(eval_stats_file), exist_ok=True)
    with open(eval_stats_file, mode='wb') as f:
        f.write(eval_stats_list.SerializeToString())

    if store:
        stats_type = metadata_store_pb2.ArtifactType()
        stats_type.name = standard_artifacts.ExampleStatistics.TYPE_NAME
        stats_type_id = store.put_artifact_type(stats_type)

        artifact = metadata_store_pb2.Artifact()
        artifact.uri = tfdv_path
        artifact.type_id = stats_type_id
        store.put_artifacts([artifact])
def _create_tfx_artifact(type_name, split):
    """Create an Artifact for TfxType.__init__().

    The returned proto has its 'type_name' and 'split' string properties set
    from the arguments and nothing else populated.
    """
    tfx_artifact = metadata_store_pb2.Artifact()
    # TODO(martinz): consider whether type_name needs to be hard-coded into
    # the artifact.
    string_properties = (('type_name', type_name), ('split', split))
    for key, text in string_properties:
        tfx_artifact.properties[key].string_value = text
    return tfx_artifact
def test_put_artifacts_get_artifacts_by_type(self):
    """A lookup by type name returns only artifacts of that type."""
    store = _get_metadata_store()
    first_type = _create_example_artifact_type()
    first_type_id = store.put_artifact_type(first_type)
    second_type = _create_example_artifact_type_2()
    second_type_id = store.put_artifact_type(second_type)

    # One artifact per type; only the second type is queried below.
    of_first_type = metadata_store_pb2.Artifact()
    of_first_type.type_id = first_type_id
    of_first_type.properties["foo"].int_value = 3
    of_first_type.properties["bar"].string_value = "Hello"
    of_second_type = metadata_store_pb2.Artifact()
    of_second_type.type_id = second_type_id

    stored_ids = store.put_artifacts([of_first_type, of_second_type])
    [_, second_id] = stored_ids

    matches = store.get_artifacts_by_type(second_type.name)
    self.assertLen(matches, 1)
    self.assertEqual(matches[0].id, second_id)
def test_put_artifacts_get_artifacts_by_id_with_set(self):
    """get_artifacts_by_id accepts the ids as a set."""
    store = _get_metadata_store()
    example_type = _create_example_artifact_type(self._get_test_type_name())
    type_id = store.put_artifact_type(example_type)

    stored = metadata_store_pb2.Artifact()
    stored.type_id = type_id
    [stored_id] = store.put_artifacts([stored])

    # The ids are passed as a set literal rather than a list.
    fetched_artifacts = store.get_artifacts_by_id({stored_id})
    [fetched] = fetched_artifacts
    self.assertEqual(fetched.type_id, stored.type_id)
def testResolveInputArtifacts(self):
    """Resolves the input artifact for a cache miss and then a cache hit."""
    # Create input splits.
    split_specs = (('split1', 'testing', 1), ('split2', 'testing2', 3))
    for split_dir, content, mtime in split_specs:
        split_file = os.path.join(self._input_base_path, split_dir, 'data')
        io_utils.write_string_file(split_file, content)
        os.utime(split_file, (0, mtime))

    # Mock artifact.
    matching_fingerprint = 'split:s1,num_files:1,total_bytes:7,xor_checksum:1,sum_checksum:1\nsplit:s2,num_files:1,total_bytes:8,xor_checksum:3,sum_checksum:3'
    artifacts = []
    for artifact_id in [4, 3, 2, 1]:
        mock_artifact = metadata_store_pb2.Artifact()
        mock_artifact.id = artifact_id
        mock_artifact.uri = self._input_base_path
        mock_artifact.custom_properties['span'].string_value = '0'
        # Only odd ids will be matched
        is_odd = artifact_id % 2 == 1
        mock_artifact.custom_properties['input_fingerprint'].string_value = (
            matching_fingerprint if is_odd else 'not_match')
        artifacts.append(mock_artifact)

    # Create exec properties.
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='s1', pattern='split1/*'),
        example_gen_pb2.Input.Split(name='s2', pattern='split2/*')
    ])
    exec_properties = {
        'input_config':
            json_format.MessageToJson(
                input_config, preserving_proto_field_name=True),
    }

    def resolve_single_input_base():
        # Runs the driver and returns the one resolved 'input_base' artifact.
        resolved = self._example_gen_driver.resolve_input_artifacts(
            self._input_channels, exec_properties, None, None)
        self.assertEqual(1, len(resolved))
        self.assertEqual(1, len(resolved['input_base']))
        return resolved['input_base'][0]

    # Cache not hit.
    self._mock_metadata.get_artifacts_by_uri.return_value = [artifacts[0]]
    self._mock_metadata.publish_artifacts.return_value = [artifacts[3]]
    updated_input_base = resolve_single_input_base()
    self.assertEqual(1, updated_input_base.id)
    self.assertEqual(self._input_base_path, updated_input_base.uri)

    # Cache hit.
    self._mock_metadata.get_artifacts_by_uri.return_value = artifacts
    self._mock_metadata.publish_artifacts.return_value = []
    updated_input_base = resolve_single_input_base()
    self.assertEqual(3, updated_input_base.id)
    self.assertEqual(self._input_base_path, updated_input_base.uri)
def from_json_dict(cls, dict_data: Dict[Text, Any]) -> Any:
    """Rebuilds an Artifact from the 'artifact'/'artifact_type' dict entries.

    Both entries are dumped back to JSON text and parsed into their protos.
    The result is constructed from the parsed type's name and then has both
    protos attached.
    """
    artifact_proto = metadata_store_pb2.Artifact()
    artifact_json = json.dumps(dict_data['artifact'])
    json_format.Parse(artifact_json, artifact_proto)

    type_proto = metadata_store_pb2.ArtifactType()
    type_json = json.dumps(dict_data['artifact_type'])
    json_format.Parse(type_json, type_proto)

    rebuilt = Artifact(type_proto.name)
    rebuilt.set_artifact_type(type_proto)
    rebuilt.set_artifact(artifact_proto)
    return rebuilt
def parse_from_json_dict(cls, d):
    """Creates an instance of TfxType from a json deserialized dict.

    `d['artifact']` and `d['artifact_type']` are re-serialized to JSON text
    and parsed into the corresponding protos, which are then attached to a
    new TfxType.
    """
    artifact_proto = metadata_store_pb2.Artifact()
    artifact_json = json.dumps(d['artifact'])
    json_format.Parse(artifact_json, artifact_proto)

    type_proto = metadata_store_pb2.ArtifactType()
    type_json = json.dumps(d['artifact_type'])
    json_format.Parse(type_json, type_proto)

    # The constructor arguments are the type name and the artifact's uri;
    # both protos are then set explicitly on the new object.
    rebuilt = TfxType(type_proto.name, artifact_proto.uri)
    rebuilt.set_artifact_type(type_proto)
    rebuilt.set_artifact(artifact_proto)
    return rebuilt
def test_create_artifact_with_type_get_artifacts_by_id(self):
    """Properties survive create_artifact_with_type and a read back by id."""
    store = _get_metadata_store()
    example_type = _create_example_artifact_type()

    # No type_id is set here; the type proto is passed alongside instead.
    new_artifact = metadata_store_pb2.Artifact()
    new_artifact.properties["foo"].int_value = 3
    new_artifact.properties["bar"].string_value = "Hello"
    new_id = store.create_artifact_with_type(new_artifact, example_type)

    fetched_artifacts = store.get_artifacts_by_id([new_id])
    [fetched] = fetched_artifacts
    self.assertEqual(fetched.properties["bar"].string_value, "Hello")
    self.assertEqual(fetched.properties["foo"].int_value, 3)
def _create_artifact_with_type(self,
                               uri: str,
                               type_name: str,
                               property_types: dict = None,
                               properties: dict = None,
                               custom_properties: dict = None):
    """Stores an artifact of the named type and returns it with its id set.

    Args:
      uri: uri of the new artifact.
      type_name: name handed to `_get_or_create_artifact_type`.
      property_types: handed to `_get_or_create_artifact_type` as its
        `properties` argument.
      properties: `properties` of the new artifact proto.
      custom_properties: `custom_properties` of the new artifact proto.

    Returns:
      The Artifact proto, with `id` set to the first id returned by the
      store.
    """
    resolved_type = self._get_or_create_artifact_type(
        type_name=type_name, properties=property_types)
    new_artifact = metadata_store_pb2.Artifact(
        uri=uri,
        type_id=resolved_type.id,
        properties=properties,
        custom_properties=custom_properties)
    stored_ids = self.store.put_artifacts([new_artifact])
    new_artifact.id = stored_ids[0]
    return new_artifact
def test_update_artifact_get_artifact(self):
    """Putting an artifact that carries an existing id updates it in place."""
    store = _get_metadata_store()
    example_type = _create_example_artifact_type()
    type_id = store.put_artifact_type(example_type)

    initial = metadata_store_pb2.Artifact()
    initial.type_id = type_id
    initial.properties["bar"].string_value = "Hello"
    [initial_id] = store.put_artifacts([initial])

    # Same artifact, now carrying its id and changed properties.
    revised = metadata_store_pb2.Artifact()
    revised.CopyFrom(initial)
    revised.id = initial_id
    revised.properties["foo"].int_value = initial_id
    revised.properties["bar"].string_value = "Goodbye"
    [revised_id] = store.put_artifacts([revised])
    self.assertEqual(initial_id, revised_id)

    fetched_artifacts = store.get_artifacts_by_id([initial_id])
    [fetched] = fetched_artifacts
    self.assertEqual(fetched.properties["bar"].string_value, "Goodbye")
    self.assertEqual(fetched.properties["foo"].int_value, initial_id)
def test_put_and_use_attributions_and_associations(self):
    """Attribution and association links can be queried in both directions."""
    store = _get_metadata_store()

    # One context.
    context_type = _create_example_context_type(self._get_test_type_name())
    context_type_id = store.put_context_type(context_type)
    context = metadata_store_pb2.Context()
    context.type_id = context_type_id
    context.name = self._get_test_type_name()
    [context.id] = store.put_contexts([context])

    # One execution.
    execution_type = _create_example_execution_type(
        self._get_test_type_name())
    execution_type_id = store.put_execution_type(execution_type)
    execution = metadata_store_pb2.Execution()
    execution.type_id = execution_type_id
    execution.properties["foo"].int_value = 3
    [execution.id] = store.put_executions([execution])

    # One artifact.
    artifact_type = _create_example_artifact_type(self._get_test_type_name())
    artifact_type_id = store.put_artifact_type(artifact_type)
    artifact = metadata_store_pb2.Artifact()
    artifact.type_id = artifact_type_id
    artifact.uri = "testuri"
    [artifact.id] = store.put_artifacts([artifact])

    # insert attribution and association and test querying the relationship
    artifact_link = metadata_store_pb2.Attribution()
    artifact_link.artifact_id = artifact.id
    artifact_link.context_id = context.id
    execution_link = metadata_store_pb2.Association()
    execution_link.execution_id = execution.id
    execution_link.context_id = context.id
    store.put_attributions_and_associations([artifact_link], [execution_link])

    # test querying the relationship
    contexts_of_artifact = store.get_contexts_by_artifact(artifact.id)
    self.assertLen(contexts_of_artifact, 1)
    self.assertEqual(contexts_of_artifact[0].id, context.id)
    self.assertEqual(contexts_of_artifact[0].name, context.name)

    artifacts_of_context = store.get_artifacts_by_context(context.id)
    self.assertLen(artifacts_of_context, 1)
    self.assertEqual(artifacts_of_context[0].uri, artifact.uri)

    executions_of_context = store.get_executions_by_context(context.id)
    self.assertLen(executions_of_context, 1)
    self.assertEqual(executions_of_context[0].properties["foo"],
                     execution.properties["foo"])

    contexts_of_execution = store.get_contexts_by_execution(execution.id)
    self.assertLen(contexts_of_execution, 1)
    self.assertEqual(contexts_of_execution[0].id, context.id)
    self.assertEqual(contexts_of_execution[0].name, context.name)
def test_log_invalid_artifacts_should_fail(self):
    """log_input/log_output raise ValueError for the two artifacts below."""
    store = metadata.Store(grpc_host=GRPC_HOST, grpc_port=GRPC_PORT)
    workspace = metadata.Workspace(store=store,
                                   name="ws_1",
                                   description="a workspace for testing",
                                   labels={"n1": "v1"})
    execution = metadata.Execution(name="test execution", workspace=workspace)

    # Input artifact whose workspace custom property is set to "ws1".
    workspace_tagged = ArtifactFixture(
        mlpb.Artifact(
            uri="gs://uri",
            custom_properties={
                metadata._WORKSPACE_PROPERTY_NAME:
                    mlpb.Value(string_value="ws1"),
            }))
    with self.assertRaises(ValueError):
        execution.log_input(workspace_tagged)

    # Output artifact whose run custom property is set to "run1".
    run_tagged = ArtifactFixture(
        mlpb.Artifact(
            uri="gs://uri",
            custom_properties={
                metadata._RUN_PROPERTY_NAME: mlpb.Value(string_value="run1"),
            }))
    with self.assertRaises(ValueError):
        execution.log_output(run_tagged)
def test_put_invalid_artifact(self):
    """A property entry with no value set is rejected by put_artifacts."""
    store = _get_metadata_store()
    example_type = _create_example_artifact_type(self._get_test_type_name())
    type_id = store.put_artifact_type(example_type)

    invalid = metadata_store_pb2.Artifact()
    invalid.type_id = type_id
    invalid.uri = "testuri"
    # Create the Value message for "foo" but don't populate its value.
    invalid.properties["foo"]  # pylint: disable=pointless-statement

    expected_message = "Found unmatched property type: foo"
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                expected_message):
        store.put_artifacts([invalid])
def test_put_artifacts_get_artifacts_by_id(self):
    """Properties of a stored artifact are returned by a lookup by id."""
    store = _get_metadata_store()
    example_type = _create_example_artifact_type(self._get_test_type_name())
    type_id = store.put_artifact_type(example_type)

    stored = metadata_store_pb2.Artifact()
    stored.type_id = type_id
    stored.properties["foo"].int_value = 3
    stored.properties["bar"].string_value = "Hello"
    [stored_id] = store.put_artifacts([stored])

    fetched_artifacts = store.get_artifacts_by_id([stored_id])
    [fetched] = fetched_artifacts
    self.assertEqual(fetched.properties["bar"].string_value, "Hello")
    self.assertEqual(fetched.properties["foo"].int_value, 3)
def test_put_artifacts_get_artifacts(self):
    """get_artifacts returns both stored artifacts, in either order."""
    store = _get_metadata_store()
    example_type = _create_example_artifact_type()
    type_id = store.put_artifact_type(example_type)

    with_properties = metadata_store_pb2.Artifact()
    with_properties.type_id = type_id
    with_properties.properties["foo"].int_value = 3
    with_properties.properties["bar"].string_value = "Hello"
    bare = metadata_store_pb2.Artifact()
    bare.type_id = type_id

    stored_ids = store.put_artifacts([with_properties, bare])
    [id_0, id_1] = stored_ids

    everything = store.get_artifacts()
    # The result order is not assumed; normalize so that the artifact with
    # properties comes first.
    if everything[0].id == id_0:
        ordered = everything
    else:
        ordered = list(reversed(everything))
    [result_0, result_1] = ordered

    self.assertEqual(result_0.id, id_0)
    self.assertEqual(result_0.properties["bar"].string_value, "Hello")
    self.assertEqual(result_0.properties["foo"].int_value, 3)
    self.assertEqual(result_1.id, id_1)
def mlpb_artifact(type_id, uri, workspace, name=None, version=None):
    """Builds an mlpb.Artifact tagged with a workspace custom property.

    Args:
      type_id: type id of the artifact.
      uri: uri of the artifact.
      workspace: string stored under the workspace custom property.
      name: if truthy, stored as the "name" string property.
      version: if truthy, stored as the "version" string property.

    Returns:
      The constructed mlpb.Artifact.
    """
    optional_fields = (("name", name), ("version", version))
    # Falsy values (None, "") are left out entirely.
    properties = {
        key: mlpb.Value(string_value=text)
        for key, text in optional_fields
        if text
    }
    custom_properties = {
        metadata._WORKSPACE_PROPERTY_NAME: mlpb.Value(string_value=workspace),
    }
    return mlpb.Artifact(uri=uri,
                         type_id=type_id,
                         properties=properties,
                         custom_properties=custom_properties)
def test_publish_execution(self):
    """put_execution stores two artifacts but only the one supplied event."""
    store = _get_metadata_store()

    exec_type = metadata_store_pb2.ExecutionType()
    exec_type.name = "execution_type"
    exec_type_id = store.put_execution_type(exec_type)
    run = metadata_store_pb2.Execution()
    run.type_id = exec_type_id

    art_type = metadata_store_pb2.ArtifactType()
    art_type.name = "artifact_type"
    art_type_id = store.put_artifact_type(art_type)
    input_artifact = metadata_store_pb2.Artifact()
    input_artifact.type_id = art_type_id
    output_artifact = metadata_store_pb2.Artifact()
    output_artifact.type_id = art_type_id
    output_event = metadata_store_pb2.Event()
    output_event.type = metadata_store_pb2.Event.DECLARED_INPUT

    # The first artifact is passed without an event, the second with one.
    artifact_and_events = [[input_artifact], [output_artifact, output_event]]
    execution_id, artifact_ids = store.put_execution(run, artifact_and_events)
    self.assertLen(artifact_ids, 2)

    stored_events = store.get_events_by_execution_ids([execution_id])
    self.assertLen(stored_events, 1)
def from_json_dict(cls, dict_data: Dict[Text, Any]) -> Any:
    """Rebuilds an artifact of the class recorded in `dict_data`.

    The class is looked up by importing '__artifact_class_module__' and
    reading the attribute '__artifact_class_name__' from it. It is
    instantiated without arguments, and the protos parsed from the
    'artifact' and 'artifact_type' entries are then attached.
    """
    source_module = importlib.import_module(
        dict_data['__artifact_class_module__'])
    artifact_cls = getattr(source_module, dict_data['__artifact_class_name__'])

    artifact_proto = metadata_store_pb2.Artifact()
    json_format.Parse(json.dumps(dict_data['artifact']), artifact_proto)
    type_proto = metadata_store_pb2.ArtifactType()
    json_format.Parse(json.dumps(dict_data['artifact_type']), type_proto)

    rebuilt = artifact_cls()
    rebuilt.set_mlmd_artifact_type(type_proto)
    rebuilt.set_mlmd_artifact(artifact_proto)
    return rebuilt
def test_put_events_with_paths(self):
    """Event path steps are preserved when events are read back."""
    store = _get_metadata_store()

    exec_type = metadata_store_pb2.ExecutionType()
    exec_type.name = "execution_type"
    exec_type_id = store.put_execution_type(exec_type)
    run = metadata_store_pb2.Execution()
    run.type_id = exec_type_id
    [execution_id] = store.put_executions([run])

    art_type = metadata_store_pb2.ArtifactType()
    art_type.name = "artifact_type"
    art_type_id = store.put_artifact_type(art_type)
    first_artifact = metadata_store_pb2.Artifact()
    first_artifact.type_id = art_type_id
    second_artifact = metadata_store_pb2.Artifact()
    second_artifact.type_id = art_type_id
    stored_ids = store.put_artifacts([first_artifact, second_artifact])
    [first_id, second_id] = stored_ids

    # One DECLARED_INPUT event per artifact, each with a single keyed step.
    first_event = metadata_store_pb2.Event()
    first_event.type = metadata_store_pb2.Event.DECLARED_INPUT
    first_event.artifact_id = first_id
    first_event.execution_id = execution_id
    first_event.path.steps.add().key = "ggg"

    second_event = metadata_store_pb2.Event()
    second_event.type = metadata_store_pb2.Event.DECLARED_INPUT
    second_event.artifact_id = second_id
    second_event.execution_id = execution_id
    second_event.path.steps.add().key = "fff"

    store.put_events([first_event, second_event])

    stored_events = store.get_events_by_execution_ids([execution_id])
    [first_result, second_result] = stored_events
    self.assertLen(first_result.path.steps, 1)
    self.assertEqual(first_result.path.steps[0].key, "ggg")
    self.assertLen(second_result.path.steps, 1)
    self.assertEqual(second_result.path.steps[0].key, "fff")
def test_put_artifacts_get_artifacts(self):
    """Two newly put artifacts show up in get_artifacts with their data.

    The store may already contain artifacts, so the test counts them
    beforehand and only inspects the two it added.
    """
    store = _get_metadata_store()
    example_type = _create_example_artifact_type(self._get_test_type_name())
    type_id = store.put_artifact_type(example_type)

    with_properties = metadata_store_pb2.Artifact()
    with_properties.type_id = type_id
    with_properties.properties["foo"].int_value = 3
    with_properties.properties["bar"].string_value = "Hello"
    bare = metadata_store_pb2.Artifact()
    bare.type_id = type_id

    # A NotFoundError from the initial listing is counted as zero artifacts.
    try:
        count_before = len(store.get_artifacts())
    except errors.NotFoundError:
        count_before = 0

    stored_ids = store.put_artifacts([with_properties, bare])
    [id_0, id_1] = stored_ids

    everything = store.get_artifacts()
    count_after = len(everything)

    # Keep only the two artifacts added above and put them in a known order.
    ours = [a for a in everything if a.id in (id_0, id_1)]
    if ours[0].id == id_0:
        [result_0, result_1] = ours
    else:
        [result_1, result_0] = ours

    self.assertEqual(count_before + 2, count_after)
    self.assertEqual(result_0.id, id_0)
    self.assertEqual(result_0.properties["bar"].string_value, "Hello")
    self.assertEqual(result_0.properties["foo"].int_value, 3)
    self.assertEqual(result_1.id, id_1)
def put_artifact(self, properties: Dict[str, str]) -> int:
    """Puts one artifact carrying the given custom properties.

    The artifact always uses the uri 'test/path/' and the type id held in
    `self.artifact_type_id`.

    Args:
      properties: A dictionary of custom properties and values to be added to
        the artifact; every value is stored as a string value.

    Returns:
      The first id returned by the store for the artifact that was put.
    """
    fake_artifact = metadata_store_pb2.Artifact()
    fake_artifact.uri = 'test/path/'
    fake_artifact.type_id = self.artifact_type_id
    for key in properties:
        fake_artifact.custom_properties[key].string_value = properties[key]
    stored_ids = self.store.put_artifacts([fake_artifact])
    return stored_ids[0]