def fakeUpstreamOutputs(mlmd_connection: metadata.Metadata,
                        example_gen: pipeline_pb2.PipelineNode,
                        transform: pipeline_pb2.PipelineNode):
  """Publishes fake successful executions for upstream nodes to MLMD.

  For each node that is provided (truthy), this registers the node's contexts,
  registers an execution of the node's type and publishes that execution as
  succeeded with a single fake output artifact.

  Args:
    mlmd_connection: MLMD connection; it is entered as a context manager.
    example_gen: ExampleGen node. Its `output_examples` output is published
      with URI 'my_examples_uri'. Skipped when falsy.
    transform: Transform node. Its `transform_graph` output is published with
      URI 'my_transform_graph_uri'. Skipped when falsy.
  """
  with mlmd_connection as m:
    if example_gen:
      # Publishes ExampleGen output.
      output_example = types.Artifact(
          example_gen.outputs.outputs['output_examples'].artifact_spec.type)
      output_example.uri = 'my_examples_uri'
      contexts = context_lib.register_contexts_if_not_exists(
          m, example_gen.contexts)
      execution = execution_publish_utils.register_execution(
          m, example_gen.node_info.type, contexts)
      execution_publish_utils.publish_succeeded_execution(
          m, execution.id, contexts, {
              'output_examples': [output_example],
          })

    if transform:
      # Publishes Transform output.
      output_transform_graph = types.Artifact(
          transform.outputs.outputs['transform_graph'].artifact_spec.type)
      # The URI must be set on the transform graph artifact itself. Setting it
      # on `output_example` left the graph without a URI and raised NameError
      # whenever `example_gen` was not provided.
      output_transform_graph.uri = 'my_transform_graph_uri'
      contexts = context_lib.register_contexts_if_not_exists(
          m, transform.contexts)
      execution = execution_publish_utils.register_execution(
          m, transform.node_info.type, contexts)
      execution_publish_utils.publish_succeeded_execution(
          m, execution.id, contexts, {
              'transform_graph': [output_transform_graph],
          })
def generate_resolved_info(metadata_handler: metadata.Metadata,
                           node: pipeline_pb2.PipelineNode) -> ResolvedInfo:
  """Builds the `ResolvedInfo` used to execute `node`.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node whose contexts, parameters and inputs are resolved.

  Returns:
    A `ResolvedInfo` holding the registered contexts, the resolved execution
    properties and the resolved input artifacts.
  """
  # Keyword arguments are evaluated left to right, so the contexts are
  # registered first, then the parameters are resolved, then the inputs.
  return ResolvedInfo(
      contexts=context_lib.register_contexts_if_not_exists(
          metadata_handler=metadata_handler, node_contexts=node.contexts),
      exec_properties=inputs_utils.resolve_parameters(
          node_parameters=node.parameters),
      input_artifacts=inputs_utils.resolve_input_artifacts(
          metadata_handler=metadata_handler, node_inputs=node.inputs))
def testRegisterContexts(self):
  """Registers the contexts of a node spec twice and verifies the result."""
  spec = pipeline_pb2.NodeContexts()
  self.load_proto_from_text('node_context_spec.pbtxt', spec)

  with metadata.Metadata(connection_config=self._connection_config) as m:
    # The first call creates the contexts; the repeated call must also succeed
    # and its return value is the one that gets checked.
    context_lib.register_contexts_if_not_exists(
        metadata_handler=m, node_contexts=spec)
    contexts = context_lib.register_contexts_if_not_exists(
        metadata_handler=m, node_contexts=spec)

    # Both context types are stored.
    self.assertProtoEquals(
        """ id: 1 name: 'my_context_type_one' """,
        m.store.get_context_type('my_context_type_one'))
    self.assertProtoEquals(
        """ id: 2 name: 'my_context_type_two' """,
        m.store.get_context_type('my_context_type_two'))

    # Returned contexts match what the store holds, in spec order.
    expected_type_and_name = (
        ('my_context_type_one', 'my_context_one'),
        ('my_context_type_one', 'my_context_two'),
        ('my_context_type_two', 'my_context_three'),
    )
    for index, (type_name, context_name) in enumerate(expected_type_and_name):
      self.assertEqual(
          contexts[index],
          m.store.get_context_by_type_and_name(type_name, context_name))

    # Custom properties are carried on the returned contexts.
    for index, expected_int in enumerate((1, 2, 3)):
      self.assertEqual(
          contexts[index].custom_properties['property_a'].int_value,
          expected_int)
    self.assertEqual(
        contexts[2].custom_properties['property_b'].string_value, '4')
def fake_trainer_output(mlmd_connection, trainer, execution=None):
  """Publishes a fake successful trainer execution with one model to MLMD.

  Args:
    mlmd_connection: MLMD connection; it is entered as a context manager.
    trainer: Trainer node providing the `model` output spec, the contexts and
      the execution type.
    execution: Execution to publish. When falsy, a new execution is registered.
  """
  with mlmd_connection as handler:
    model = types.Artifact(
        trainer.outputs.outputs['model'].artifact_spec.type)
    model.uri = 'my_trainer_model_uri'

    node_contexts = context_lib.register_contexts_if_not_exists(
        handler, trainer.contexts)
    # Reuse the supplied execution; register a new one only when none is given.
    execution = execution or execution_publish_utils.register_execution(
        handler, trainer.node_info.type, node_contexts)

    outputs = {'model': [model]}
    execution_publish_utils.publish_succeeded_execution(
        handler, execution.id, node_contexts, outputs)
def fake_transform_output(mlmd_connection, transform, execution=None):
  """Publishes a fake successful transform execution with its graph to MLMD.

  Args:
    mlmd_connection: MLMD connection; it is entered as a context manager.
    transform: Transform node providing the `transform_graph` output spec, the
      contexts and the execution type.
    execution: Execution to publish. When falsy, a new execution is registered.
  """
  with mlmd_connection as handler:
    graph = types.Artifact(
        transform.outputs.outputs['transform_graph'].artifact_spec.type)
    graph.uri = 'my_transform_graph_uri'

    node_contexts = context_lib.register_contexts_if_not_exists(
        handler, transform.contexts)
    # Reuse the supplied execution; register a new one only when none is given.
    execution = execution or execution_publish_utils.register_execution(
        handler, transform.node_info.type, node_contexts)

    outputs = {'transform_graph': [graph]}
    execution_publish_utils.publish_succeeded_execution(
        handler, execution.id, node_contexts, outputs)
def fake_example_gen_run(mlmd_connection, example_gen, span, version):
  """Publishes a fake successful example_gen execution and output to MLMD.

  Args:
    mlmd_connection: MLMD connection; it is entered as a context manager.
    example_gen: ExampleGen node providing the `output_examples` output spec,
      the contexts and the execution type.
    span: Value stored in the artifact's `span` int custom property.
    version: Value stored in the artifact's `version` int custom property.
  """
  with mlmd_connection as handler:
    examples = types.Artifact(
        example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    examples.set_int_custom_property('span', span)
    examples.set_int_custom_property('version', version)
    examples.uri = 'my_examples_uri'

    node_contexts = context_lib.register_contexts_if_not_exists(
        handler, example_gen.contexts)
    registered = execution_publish_utils.register_execution(
        handler, example_gen.node_info.type, node_contexts)

    outputs = {'output_examples': [examples]}
    execution_publish_utils.publish_succeeded_execution(
        handler, registered.id, node_contexts, outputs)
def testResolverWithResolverPolicy(self):
  """Resolves Transform inputs with the LATEST_ARTIFACT resolver policy.

  Two example artifacts are published on the same ExampleGen output key. The
  test expects exactly one resolved `examples` artifact, equal to the second
  published one apart from its timestamps.
  """
  pipeline = pipeline_pb2.Pipeline()
  pbtxt_path = os.path.join(self._testdata_dir,
                            'pipeline_for_input_resolver_test.pbtxt')
  self.load_proto_from_text(pbtxt_path, pipeline)
  example_gen_node = pipeline.nodes[0].pipeline_node
  transform_node = pipeline.nodes[2].pipeline_node

  sqlite_config = metadata_store_pb2.ConnectionConfig()
  sqlite_config.sqlite.SetInParent()

  with metadata.Metadata(connection_config=sqlite_config) as m:
    # Publish one ExampleGen execution carrying two artifacts under the
    # `output_examples` key, which the downstream Transform consumes.
    examples_type = (
        example_gen_node.outputs.outputs['output_examples'].artifact_spec.type)
    first_examples = types.Artifact(examples_type)
    first_examples.uri = 'my_examples_uri_1'
    second_examples = types.Artifact(examples_type)
    second_examples.uri = 'my_examples_uri_2'

    node_contexts = context_lib.register_contexts_if_not_exists(
        m, example_gen_node.contexts)
    registered = execution_publish_utils.register_execution(
        m, example_gen_node.node_info.type, node_contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, registered.id, node_contexts,
        {'output_examples': [first_examples, second_examples]})

    transform_node.inputs.resolver_config.resolver_policy = (
        pipeline_pb2.ResolverConfig.LATEST_ARTIFACT)

    # Resolve Transform inputs: only the second published artifact is
    # expected to come back.
    resolved = inputs_utils.resolve_input_artifacts(m, transform_node.inputs)
    self.assertEqual(len(resolved), 1)
    self.assertEqual(len(resolved['examples']), 1)
    self.assertProtoPartiallyEquals(
        resolved['examples'][0].mlmd_artifact,
        second_examples.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def fakeExampleGenOutput(mlmd_connection: metadata.Metadata,
                         example_gen: pipeline_pb2.PipelineNode, span: int,
                         version: int):
  """Publishes a fake successful ExampleGen execution and output to MLMD.

  Args:
    mlmd_connection: MLMD connection; it is entered as a context manager.
    example_gen: ExampleGen node providing the `output_examples` output spec,
      the contexts and the execution type.
    span: Value stored in the artifact's `span` int custom property.
    version: Value stored in the artifact's `version` int custom property.
  """
  with mlmd_connection as handler:
    examples = types.Artifact(
        example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    examples.set_int_custom_property('span', span)
    examples.set_int_custom_property('version', version)
    examples.uri = 'my_examples_uri'

    node_contexts = context_lib.register_contexts_if_not_exists(
        handler, example_gen.contexts)
    registered = execution_publish_utils.register_execution(
        handler, example_gen.node_info.type, node_contexts)

    outputs = {'output_examples': [examples]}
    execution_publish_utils.publish_succeeded_execution(
        handler, registered.id, node_contexts, outputs)
def _generate_contexts(self, metadata_handler):
  """Registers a fixed pipeline context and component context.

  Args:
    metadata_handler: Handler passed through to
      `context_lib.register_contexts_if_not_exists`.

  Returns:
    Whatever `context_lib.register_contexts_if_not_exists` returns for the two
    contexts declared below.
  """
  # Adjacent literals are concatenated into a single text-proto string.
  spec_text = (
      " contexts { type {name: 'pipeline_context'} "
      "name { field_value {string_value: 'my_pipeline'} } } "
      "contexts { type {name: 'component_context'} "
      "name { field_value {string_value: 'my_component'} } }")
  # `text_format.Parse` fills in and returns the message it is given.
  node_contexts = text_format.Parse(spec_text, pipeline_pb2.NodeContexts())
  return context_lib.register_contexts_if_not_exists(metadata_handler,
                                                     node_contexts)
def testResolveInputArtifactsOutputKeyUnset(self):
  """Resolves Pusher inputs from the output-key-unset test pipeline.

  A single Trainer model is published; the Pusher is expected to resolve
  exactly that model under its `model` input.
  """
  pipeline = pipeline_pb2.Pipeline()
  pbtxt_path = os.path.join(
      self._testdata_dir,
      'pipeline_for_input_resolver_test_output_key_unset.pbtxt')
  self.load_proto_from_text(pbtxt_path, pipeline)
  trainer_node = pipeline.nodes[0].pipeline_node
  pusher_node = pipeline.nodes[1].pipeline_node

  sqlite_config = metadata_store_pb2.ConnectionConfig()
  sqlite_config.sqlite.SetInParent()

  with metadata.Metadata(connection_config=sqlite_config) as m:
    # Publish one Trainer execution with a single artifact under `model`.
    model = types.Artifact(
        trainer_node.outputs.outputs['model'].artifact_spec.type)
    model.uri = 'my_output_model_uri'
    node_contexts = context_lib.register_contexts_if_not_exists(
        m, trainer_node.contexts)
    registered = execution_publish_utils.register_execution(
        m, trainer_node.node_info.type, node_contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, registered.id, node_contexts, {'model': [model]})

    # Resolve Pusher inputs: the published model must come back, ignoring
    # the timestamp fields.
    resolved = inputs_utils.resolve_input_artifacts(m, pusher_node.inputs)
    self.assertEqual(len(resolved), 1)
    self.assertEqual(len(resolved['model']), 1)
    self.assertProtoPartiallyEquals(
        model.mlmd_artifact,
        resolved['model'][0].mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def _prepare_execution(self) -> _PrepareExecutionResult:
  """Prepares inputs, outputs and execution properties for actual execution.

  Registers the node's contexts and execution in MLMD, resolves inputs and
  generates output artifacts, runs the custom driver when one is configured,
  and finally checks for cached outputs.

  Returns:
    A `_PrepareExecutionResult`. `is_execution_needed` is False when input
    resolution returns None, or when caching is enabled and cached outputs
    are found (in which case the cached execution is published here).
    Otherwise it is True and `execution_info` carries the inputs, outputs,
    execution properties and working locations for the executor.
  """
  # TODO(b/150979622): handle the edge case that the component gets evicted
  # between successful publish and stateful working dir being cleaned up.
  # Otherwise following retries will keep failing because of duplicate
  # publishes.
  with self._mlmd_connection as m:
    # 1. Prepares all contexts.
    contexts = context_lib.register_contexts_if_not_exists(
        metadata_handler=m, node_contexts=self._pipeline_node.contexts)

    # 2. Resolves inputs and execution properties.
    exec_properties = inputs_utils.resolve_parameters(
        node_parameters=self._pipeline_node.parameters)
    input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=self._pipeline_node.inputs)
    # 3. If not all required inputs are met, return ExecutionInfo with
    # is_execution_needed being false. No publish will happen so downstream
    # nodes won't be triggered.
    if input_artifacts is None:
      return _PrepareExecutionResult(
          execution_info=data_types.ExecutionInfo(),
          contexts=contexts,
          is_execution_needed=False)

    # 4. Registers execution in metadata.
    execution = execution_publish_utils.register_execution(
        metadata_handler=m,
        execution_type=self._pipeline_node.node_info.type,
        contexts=contexts,
        input_artifacts=input_artifacts,
        exec_properties=exec_properties)

    # 5. Resolves outputs.
    output_artifacts = self._output_resolver.generate_output_artifacts(
        execution.id)

  # If there is a custom driver, runs it. Its output is applied to
  # `exec_properties` and `output_artifacts` via `_update_with_driver_output`.
  if self._driver_operator:
    driver_output = self._driver_operator.run_driver(
        data_types.ExecutionInfo(
            input_dict=input_artifacts,
            output_dict=output_artifacts,
            exec_properties=exec_properties,
            execution_output_uri=self._output_resolver.get_driver_output_uri()))
    self._update_with_driver_output(driver_output, exec_properties,
                                    output_artifacts)

  # We reconnect to MLMD here because the custom driver closes MLMD connection
  # on returning.
  with self._mlmd_connection as m:
    # 6. Checks cached result. The cache context is appended to `contexts`, so
    # it is part of the contexts returned on every path below.
    cache_context = cache_utils.get_cache_context(
        metadata_handler=m,
        pipeline_node=self._pipeline_node,
        pipeline_info=self._pipeline_info,
        input_artifacts=input_artifacts,
        output_artifacts=output_artifacts,
        parameters=exec_properties)
    contexts.append(cache_context)
    cached_outputs = cache_utils.get_cached_outputs(
        metadata_handler=m, cache_context=cache_context)

    # 7. Should cache be used? Only when caching is enabled for the node and
    # cached outputs were actually found.
    if (self._pipeline_node.execution_options.caching_options.enable_cache and
        cached_outputs):
      # Publishes cache result.
      execution_publish_utils.publish_cached_execution(
          metadata_handler=m,
          contexts=contexts,
          execution_id=execution.id,
          output_artifacts=cached_outputs)
      return _PrepareExecutionResult(
          execution_info=data_types.ExecutionInfo(execution_id=execution.id),
          execution_metadata=execution,
          contexts=contexts,
          is_execution_needed=False)

    pipeline_run_id = (
        self._pipeline_runtime_spec.pipeline_run_id.field_value.string_value)

    # 8. Going to trigger executor.
    return _PrepareExecutionResult(
        execution_info=data_types.ExecutionInfo(
            execution_id=execution.id,
            input_dict=input_artifacts,
            output_dict=output_artifacts,
            exec_properties=exec_properties,
            execution_output_uri=self._output_resolver.get_executor_output_uri(
                execution.id),
            stateful_working_dir=(
                self._output_resolver.get_stateful_working_directory()),
            tmp_dir=self._output_resolver.make_tmp_dir(execution.id),
            pipeline_node=self._pipeline_node,
            pipeline_info=self._pipeline_info,
            pipeline_run_id=pipeline_run_id),
        execution_metadata=execution,
        contexts=contexts,
        is_execution_needed=True)
def testResolverInputsArtifacts(self):
  """Resolves Transform and Trainer inputs after two ExampleGen publishes.

  Transform is expected to resolve only the first ExampleGen's
  `output_examples` artifact. Trainer input resolution is expected to return
  None because nothing was published for Transform.
  """
  pipeline = pipeline_pb2.Pipeline()
  pbtxt_path = os.path.join(self._testdata_dir,
                            'pipeline_for_input_resolver_test.pbtxt')
  self.load_proto_from_text(pbtxt_path, pipeline)
  first_gen = pipeline.nodes[0].pipeline_node
  second_gen = pipeline.nodes[1].pipeline_node
  transform_node = pipeline.nodes[2].pipeline_node
  trainer_node = pipeline.nodes[3].pipeline_node

  sqlite_config = metadata_store_pb2.ConnectionConfig()
  sqlite_config.sqlite.SetInParent()

  with metadata.Metadata(connection_config=sqlite_config) as m:

    def publish_outputs(node, outputs):
      """Registers contexts and an execution for `node`, then publishes."""
      node_contexts = context_lib.register_contexts_if_not_exists(
          m, node.contexts)
      registered = execution_publish_utils.register_execution(
          m, node.node_info.type, node_contexts)
      execution_publish_utils.publish_succeeded_execution(
          m, registered.id, node_contexts, outputs)

    # First ExampleGen: two artifacts are published. `output_examples` is the
    # one consumed by the downstream Transform.
    main_examples = types.Artifact(
        first_gen.outputs.outputs['output_examples'].artifact_spec.type)
    main_examples.uri = 'my_examples_uri'
    side_examples = types.Artifact(
        first_gen.outputs.outputs['side_examples'].artifact_spec.type)
    side_examples.uri = 'side_examples_uri'
    publish_outputs(first_gen, {
        'output_examples': [main_examples],
        'another_examples': [side_examples]
    })

    # Second ExampleGen: publishes under the same output key as the first,
    # but this output is not consumed by downstream nodes.
    other_examples = types.Artifact(
        second_gen.outputs.outputs['output_examples'].artifact_spec.type)
    other_examples.uri = 'another_examples_uri'
    publish_outputs(second_gen, {'output_examples': [other_examples]})

    # Transform should get back exactly what the first ExampleGen published
    # in the `output_examples` channel.
    resolved = inputs_utils.resolve_input_artifacts(m, transform_node.inputs)
    self.assertEqual(len(resolved), 1)
    self.assertEqual(len(resolved['examples']), 1)
    self.assertProtoPartiallyEquals(
        resolved['examples'][0].mlmd_artifact,
        main_examples.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

    # Trainer requires min_count on both input channels (from example_gen and
    # from transform). Nothing was published from transform, so resolution
    # should return nothing.
    self.assertIsNone(
        inputs_utils.resolve_input_artifacts(m, trainer_node.inputs))
def run(
    self, mlmd_connection: metadata.Metadata,
    pipeline_node: pipeline_pb2.PipelineNode,
    pipeline_info: pipeline_pb2.PipelineInfo,
    pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
) -> metadata_store_pb2.Execution:
  """Runs Importer specific logic.

  Registers the node's contexts and an execution, generates the output
  artifacts that represent the imported artifacts, and publishes the execution
  as succeeded.

  Args:
    mlmd_connection: ML metadata connection.
    pipeline_node: The specification of the node that this launcher launches.
    pipeline_info: The information of the pipeline that this node runs in.
      Not read by this method.
    pipeline_runtime_spec: The runtime information of the pipeline that this
      node runs in. Not read by this method.

  Returns:
    The execution of the run.
  """
  logging.info('Running as an importer node.')
  with mlmd_connection as handler:
    # Step 1: make sure every context of the node exists.
    node_contexts = context_lib.register_contexts_if_not_exists(
        metadata_handler=handler, node_contexts=pipeline_node.contexts)

    # Step 2: resolve execution properties. Importers have no inputs.
    exec_properties = inputs_utils.resolve_parameters(
        node_parameters=pipeline_node.parameters)

    # Step 3: register the execution in metadata.
    execution = execution_publish_utils.register_execution(
        metadata_handler=handler,
        execution_type=pipeline_node.node_info.type,
        contexts=node_contexts,
        exec_properties=exec_properties)

    # Step 4: generate output artifacts that represent the imported artifacts.
    artifact_spec = pipeline_node.outputs.outputs[
        importer_node.IMPORT_RESULT_KEY].artifact_spec
    artifact_properties = self._extract_proto_map(
        artifact_spec.additional_properties)
    artifact_custom_properties = self._extract_proto_map(
        artifact_spec.additional_custom_properties)
    artifact_class = types.Artifact(artifact_spec.type).type
    imported = importer_node.generate_output_dict(
        metadata_handler=handler,
        uri=str(exec_properties[importer_node.SOURCE_URI_KEY]),
        properties=artifact_properties,
        custom_properties=artifact_custom_properties,
        reimport=bool(exec_properties[importer_node.REIMPORT_OPTION_KEY]),
        output_artifact_class=artifact_class,
        mlmd_artifact_type=artifact_spec.type)

    # Step 5: publish the output artifacts.
    execution_publish_utils.publish_succeeded_execution(
        metadata_handler=handler,
        execution_id=execution.id,
        contexts=node_contexts,
        output_artifacts=imported)

    return execution
def testSuccess(self):
  """Runs the resolver node handler and verifies its output and execution.

  Two models are published by the trainer. After the resolver runs, a fake
  downstream node is expected to resolve exactly one model,
  'my_model_uri_2', and the resolver's execution is expected to be COMPLETE.
  """
  with self._mlmd_connection as m:
    # Publish two models for the downstream resolver to consume.
    model_type = self._my_trainer.outputs.outputs['model'].artifact_spec.type
    first_model = types.Artifact(model_type)
    first_model.uri = 'my_model_uri_1'
    second_model = types.Artifact(model_type)
    second_model.uri = 'my_model_uri_2'

    trainer_contexts = context_lib.register_contexts_if_not_exists(
        m, self._my_trainer.contexts)
    trainer_execution = execution_publish_utils.register_execution(
        m, self._my_trainer.node_info.type, trainer_contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, trainer_execution.id, trainer_contexts,
        {'model': [first_model, second_model]})

  resolver_execution = resolver_node_handler.ResolverNodeHandler().run(
      mlmd_connection=self._mlmd_connection,
      pipeline_node=self._resolver_node,
      pipeline_info=self._pipeline_info,
      pipeline_runtime_spec=self._pipeline_runtime_spec)

  with self._mlmd_connection as m:
    # The resolver's output artifact cannot be verified directly, so a fake
    # downstream node listens to the resolver's output and its resolved
    # input is checked instead. Adjacent literals form one text-proto string.
    downstream_spec = (
        ' inputs { inputs { key: "input_models" value { channels { '
        'producer_node_query { id: "my_resolver" } '
        'context_queries { type { name: "pipeline" } '
        'name { field_value { string_value: "my_pipeline" } } } '
        'context_queries { type { name: "component" } '
        'name { field_value { string_value: "my_resolver" } } } '
        'artifact_query { type { name: "Model" } } '
        'output_key: "models" } min_count: 1 } } } '
        'upstream_nodes: "my_resolver" ')
    downstream_node = text_format.Parse(downstream_spec,
                                        pipeline_pb2.PipelineNode())

    resolved_inputs = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=downstream_node.inputs)
    resolved_models = resolved_inputs['input_models']
    self.assertLen(resolved_models, 1)
    self.assertProtoPartiallyEquals(
        """ id: 2 type_id: 5 uri: "my_model_uri_2" state: LIVE""",
        resolved_models[0].mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

    # The execution recorded for the resolver run must be complete.
    [stored_execution] = m.store.get_executions_by_id([resolver_execution.id])
    self.assertProtoPartiallyEquals(
        """ id: 2 type_id: 6 last_known_state: COMPLETE """,
        stored_execution,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])