def setUp(self):
  """Pre-registers the execution types ('Transform', 'Trainer') used by tests."""
  super().setUp()
  with self.get_metadata() as m:
    # Register both types up front so individual tests can assume they exist.
    for type_name in ('Transform', 'Trainer'):
      common_utils.register_type_if_not_exist(
          m, metadata_store_pb2.ExecutionType(name=type_name))
def testRegisterTypeReuseExisting(self, metadata_type_class):
  """Registering a same-named type with no properties reuses the original."""
  with metadata.Metadata(connection_config=self._connection_config) as m:
    original_type = _create_type(metadata_type_class)
    first_result = common_utils.register_type_if_not_exist(m, original_type)
    # Tries to register a type that shares the same name of the type
    # previously registered but with no properties. We expect the previously
    # registered type to be reused.
    bare_type = metadata_type_class()
    text_format.Parse("name: 'my_type'", bare_type)
    second_result = common_utils.register_type_if_not_exist(m, bare_type)
    self.assertProtoEquals(first_result, second_result)
def testGetArtifactsDict(self):
  """Checks get_artifacts_dict preserves the stored artifact ordering."""
  with metadata.Metadata(connection_config=self._connection_config) as m:
    # Create and shuffle a few artifacts. The shuffled order should be
    # retained in the output of `execution_lib.get_artifacts_dict`.
    input_examples = []
    for i in range(10):
      example = standard_artifacts.Examples()
      example.uri = 'example{}'.format(i)
      example.type_id = common_utils.register_type_if_not_exist(
          m, example.artifact_type).id
      input_examples.append(example)
    random.shuffle(input_examples)

    output_models = []
    for i in range(8):
      model = standard_artifacts.Model()
      model.uri = 'model{}'.format(i)
      model.type_id = common_utils.register_type_if_not_exist(
          m, model.artifact_type).id
      output_models.append(model)
    random.shuffle(output_models)

    m.store.put_artifacts([
        a.mlmd_artifact
        for a in itertools.chain(input_examples, output_models)
    ])
    execution = execution_lib.prepare_execution(
        m,
        metadata_store_pb2.ExecutionType(name='my_execution_type'),
        state=metadata_store_pb2.Execution.RUNNING)
    contexts = self._generate_contexts(m)
    execution = execution_lib.put_execution(
        m,
        execution,
        contexts,
        input_artifacts={'examples': input_examples},
        output_artifacts={'model': output_models})

    # Verify that the same artifacts are returned in the correct order.
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, metadata_store_pb2.Event.INPUT)
    self.assertCountEqual(['examples'], list(artifacts_dict.keys()))
    self.assertEqual([ex.uri for ex in input_examples],
                     [a.uri for a in artifacts_dict['examples']])
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, metadata_store_pb2.Event.OUTPUT)
    self.assertCountEqual(['model'], list(artifacts_dict.keys()))
    self.assertEqual([model.uri for model in output_models],
                     [a.uri for a in artifacts_dict['model']])
def prepare_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    state: metadata_store_pb2.Execution.State,
    exec_properties: Optional[Mapping[Text, types.Property]] = None,
) -> metadata_store_pb2.Execution:
  """Creates an execution proto based on the information provided.

  Args:
    metadata_handler: A handler to access MLMD store.
    execution_type: A metadata_pb2.ExecutionType message describing the type
      of the execution.
    state: The state of the execution.
    exec_properties: Execution properties that need to be attached.

  Returns:
    A metadata_store_pb2.Execution message.
  """
  execution = metadata_store_pb2.Execution(
      last_known_state=state,
      type_id=common_utils.register_type_if_not_exist(
          metadata_handler, execution_type).id)
  # For every execution property, put it in execution.properties if its key is
  # in execution type schema. Otherwise, put it in execution.custom_properties.
  for key, raw_value in (exec_properties or {}).items():
    if (execution_type.properties.get(key) ==
        common_utils.get_metadata_value_type(raw_value)):
      destination = execution.properties[key]
    else:
      destination = execution.custom_properties[key]
    common_utils.set_metadata_value(destination, raw_value)
  logging.debug('Prepared EXECUTION:\n %s', execution)
  return execution
def testPutExecutionGraph(self):
  """Verifies artifact/event/context linkage created by put_execution."""
  with metadata.Metadata(connection_config=self._connection_config) as m:
    # Prepares an input artifact. The artifact should be registered in MLMD
    # before the put_execution call.
    input_example = standard_artifacts.Examples()
    input_example.uri = 'example'
    input_example.type_id = common_utils.register_type_if_not_exist(
        m, input_example.artifact_type).id
    [input_example.id] = m.store.put_artifacts([input_example.mlmd_artifact])
    # Prepares an output artifact.
    output_model = standard_artifacts.Model()
    output_model.uri = 'model'
    execution = execution_lib.prepare_execution(
        m,
        metadata_store_pb2.ExecutionType(name='my_execution_type'),
        exec_properties={'p1': 1, 'p2': '2'},
        state=metadata_store_pb2.Execution.COMPLETE)
    contexts = self._generate_contexts(m)
    execution = execution_lib.put_execution(
        m,
        execution,
        contexts,
        input_artifacts={'example': [input_example]},
        output_artifacts={'model': [output_model]})

    self.assertProtoPartiallyEquals(
        output_model.mlmd_artifact,
        m.store.get_artifacts_by_id([output_model.id])[0],
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    # Verifies edges between artifacts and execution.
    [input_event] = m.store.get_events_by_artifact_ids([input_example.id])
    self.assertEqual(input_event.execution_id, execution.id)
    self.assertEqual(input_event.type, metadata_store_pb2.Event.INPUT)
    [output_event] = m.store.get_events_by_artifact_ids([output_model.id])
    self.assertEqual(output_event.execution_id, execution.id)
    self.assertEqual(output_event.type, metadata_store_pb2.Event.OUTPUT)
    # Verifies edges connecting contexts and {artifacts, execution}: every
    # context should be linked to both artifacts and the execution.
    context_ids = [context.id for context in contexts]
    for linked_contexts in (
        m.store.get_contexts_by_artifact(input_example.id),
        m.store.get_contexts_by_artifact(output_model.id),
        m.store.get_contexts_by_execution(execution.id)):
      self.assertCountEqual([c.id for c in linked_contexts], context_ids)
def testRegisterTypeModifiedKey(self, metadata_type_class):
  """Re-registering a type with a conflicting property type must fail."""
  with metadata.Metadata(connection_config=self._connection_config) as m:
    common_utils.register_type_if_not_exist(
        m, _create_type(metadata_type_class))
    # Tries to register a type that shares the same name of the type
    # previously registered but with conflicting property types. We expect
    # this to fail.
    conflicting_type = metadata_type_class()
    text_format.Parse(
        """
        name: 'my_type'
        properties {
          key: 'p1'
          value: STRING  # This is different from the original registered type
        }
        """, conflicting_type)
    with self.assertRaisesRegex(RuntimeError, 'Missing or modified key'):
      common_utils.register_type_if_not_exist(m, conflicting_type)
def testGetArtifactIdsForExecutionIdGroupedByEventType(self):
  """Checks artifact ids are grouped under their INPUT/OUTPUT event types."""
  with metadata.Metadata(connection_config=self._connection_config) as m:
    # Register an input and output artifacts in MLMD.
    input_example = standard_artifacts.Examples()
    input_example.uri = 'example'
    input_example.type_id = common_utils.register_type_if_not_exist(
        m, input_example.artifact_type).id
    output_model = standard_artifacts.Model()
    output_model.uri = 'model'
    output_model.type_id = common_utils.register_type_if_not_exist(
        m, output_model.artifact_type).id
    [input_example.id, output_model.id] = m.store.put_artifacts(
        [input_example.mlmd_artifact, output_model.mlmd_artifact])

    execution = execution_lib.prepare_execution(
        m,
        metadata_store_pb2.ExecutionType(name='my_execution_type'),
        exec_properties={'p1': 1, 'p2': '2'},
        state=metadata_store_pb2.Execution.COMPLETE)
    contexts = self._generate_contexts(m)
    execution = execution_lib.put_execution(
        m,
        execution,
        contexts,
        input_artifacts={'example': [input_example]},
        output_artifacts={'model': [output_model]})

    grouped_ids = (
        execution_lib.get_artifact_ids_by_event_type_for_execution_id(
            m, execution.id))
    self.assertDictEqual(
        {
            metadata_store_pb2.Event.INPUT: {input_example.id},
            metadata_store_pb2.Event.OUTPUT: {output_model.id},
        }, grouped_ids)
def testRegisterType(self, metadata_type_class):
  """A freshly registered type carries its new id and declared properties."""
  with metadata.Metadata(connection_config=self._connection_config) as m:
    registered = common_utils.register_type_if_not_exist(
        m, _create_type(metadata_type_class))
    self.assertProtoEquals(
        """
        id: 1
        name: 'my_type'
        properties {
          key: 'p1'
          value: INT
        }
        properties {
          key: 'p2'
          value: STRING
        }
        """, registered)
def prepare_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    state: metadata_store_pb2.Execution.State,
    exec_properties: Optional[Mapping[str, types.ExecPropertyTypes]] = None,
) -> metadata_store_pb2.Execution:
  """Creates an execution proto based on the information provided.

  Args:
    metadata_handler: A handler to access MLMD store.
    execution_type: A metadata_pb2.ExecutionType message describing the type
      of the execution.
    state: The state of the execution.
    exec_properties: Execution properties that need to be attached.

  Returns:
    A metadata_store_pb2.Execution message.
  """
  execution = metadata_store_pb2.Execution(
      last_known_state=state,
      type_id=common_utils.register_type_if_not_exist(
          metadata_handler, execution_type).id)
  # For every execution property, put it in execution.properties if its key is
  # in execution type schema. Otherwise, put it in execution.custom_properties.
  for key, raw_value in (exec_properties or {}).items():
    value = data_types_utils.set_parameter_value(
        pipeline_pb2.Value(), raw_value)
    if value.HasField('schema'):
      # Stores schema in custom_properties for non-primitive types to allow
      # parsing in later stages.
      data_types_utils.set_metadata_value(
          execution.custom_properties[get_schema_key(key)],
          proto_utils.proto_to_json(value.schema))
    if (execution_type.properties.get(key) ==
        data_types_utils.get_metadata_value_type(raw_value)):
      execution.properties[key].CopyFrom(value.field_value)
    else:
      execution.custom_properties[key].CopyFrom(value.field_value)
  logging.debug('Prepared EXECUTION:\n %s', execution)
  return execution
def _create_artifact_and_event_pairs(
    metadata_handler: metadata.Metadata,
    artifact_dict: MutableMapping[Text, Sequence[types.Artifact]],
    event_type: metadata_store_pb2.Event.Type,
) -> List[Tuple[metadata_store_pb2.Artifact, metadata_store_pb2.Event]]:
  """Creates a list of [Artifact, Event] tuples.

  The result of this function will be used in a MLMD put_execution() call.

  Args:
    metadata_handler: A handler to access MLMD store.
    artifact_dict: The source of artifacts to work on. For each artifact in
      the dict, creates a tuple for that. Note that all artifacts of the same
      key in the artifact_dict are expected to share the same artifact type.
    event_type: The event type of the event to be attached to the artifact

  Returns:
    A list of [Artifact, Event] tuples
  """
  result = []
  for key, artifact_list in artifact_dict.items():
    artifact_type = None
    for index, artifact in enumerate(artifact_list):
      # TODO(b/153904840): If artifact id is present, skip putting the artifact
      # into the pair when MLMD API is ready.
      event = event_lib.generate_event(
          event_type=event_type, key=key, index=index)
      # Registers the type only for the first artifact of each list and reuses
      # the result for the rest, as artifacts in the same list share the same
      # artifact type. (Previously the type was re-registered for every
      # artifact, costing one redundant MLMD round-trip per artifact.)
      if artifact_type is None:
        artifact_type = common_utils.register_type_if_not_exist(
            metadata_handler, artifact.artifact_type)
      else:
        assert artifact_type.name == artifact.artifact_type.name, (
            'Artifacts under the same key should share the same artifact type.'
        )
      artifact.set_mlmd_artifact_type(artifact_type)
      result.append((artifact.mlmd_artifact, event))
  return result
def _generate_context_proto(
    metadata_handler: metadata.Metadata,
    context_spec: pipeline_pb2.ContextSpec) -> metadata_store_pb2.Context:
  """Generates metadata_pb2.Context based on the ContextSpec message.

  Args:
    metadata_handler: A handler to access MLMD store.
    context_spec: A pipeline_pb2.ContextSpec message that instructs registering
      of a context.

  Returns:
    A metadata_store_pb2.Context message.

  Raises:
    RuntimeError: When actual property type does not match provided metadata
      type schema.
  """
  context_type = common_utils.register_type_if_not_exist(
      metadata_handler, context_spec.type)
  context_name = common_utils.get_value(context_spec.name)
  assert isinstance(context_name, Text), 'context name should be string.'
  context = metadata_store_pb2.Context(
      type_id=context_type.id, name=context_name)
  for key, raw_value in context_spec.properties.items():
    # Keys absent from the type schema go to custom_properties.
    if key not in context_type.properties:
      common_utils.set_metadata_value(context.custom_properties[key], raw_value)
      continue
    actual_property_type = common_utils.get_metadata_value_type(raw_value)
    if context_type.properties.get(key) != actual_property_type:
      raise RuntimeError(
          'Property type %s different from provided metadata type property type %s for key %s'
          % (actual_property_type, context_type.properties.get(key), key))
    common_utils.set_metadata_value(context.properties[key], raw_value)
  return context
def testGetArtifactsDict(self):
  """Checks get_artifacts_dict preserves per-key artifact ordering."""
  with metadata.Metadata(connection_config=self._connection_config) as m:
    # Create and shuffle a few artifacts. The shuffled order should be
    # retained in the output of `execution_lib.get_artifacts_dict`.
    input_artifact_keys = ('input1', 'input2', 'input3')
    input_artifacts_dict = collections.OrderedDict()
    for input_key in input_artifact_keys:
      examples = []
      for i in range(10):
        example = standard_artifacts.Examples()
        example.uri = f'{input_key}/example{i}'
        example.type_id = common_utils.register_type_if_not_exist(
            m, example.artifact_type).id
        examples.append(example)
      random.shuffle(examples)
      input_artifacts_dict[input_key] = examples

    output_models = []
    for i in range(8):
      model = standard_artifacts.Model()
      model.uri = f'model{i}'
      model.type_id = common_utils.register_type_if_not_exist(
          m, model.artifact_type).id
      output_models.append(model)
    random.shuffle(output_models)
    output_artifacts_dict = {'model': output_models}

    # Store input artifacts only. Outputs will be saved in put_execution().
    input_mlmd_artifacts = [
        a.mlmd_artifact
        for a in itertools.chain(*input_artifacts_dict.values())
    ]
    for mlmd_artifact, artifact_id in zip(
        input_mlmd_artifacts, m.store.put_artifacts(input_mlmd_artifacts)):
      mlmd_artifact.id = artifact_id

    execution = execution_lib.prepare_execution(
        m,
        metadata_store_pb2.ExecutionType(name='my_execution_type'),
        state=metadata_store_pb2.Execution.RUNNING)
    contexts = self._generate_contexts(m)
    # Change the order of the OrderedDict to shuffle the order of input keys.
    input_artifacts_dict.move_to_end('input1')
    execution = execution_lib.put_execution(
        m,
        execution,
        contexts,
        input_artifacts=input_artifacts_dict,
        output_artifacts=output_artifacts_dict)

    # Verify that the same artifacts are returned in the correct order.
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, [metadata_store_pb2.Event.INPUT])
    self.assertEqual(set(input_artifact_keys), set(artifacts_dict.keys()))
    for key, artifacts in artifacts_dict.items():
      self.assertEqual([ex.uri for ex in input_artifacts_dict[key]],
                       [a.uri for a in artifacts], f'for key={key}')
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, [metadata_store_pb2.Event.OUTPUT])
    self.assertEqual({'model'}, set(artifacts_dict.keys()))
    self.assertEqual([model.uri for model in output_models],
                     [a.uri for a in artifacts_dict['model']])