def fakeUpstreamOutputs(mlmd_connection: metadata.Metadata,
                        example_gen: pipeline_pb2.PipelineNode,
                        transform: pipeline_pb2.PipelineNode):
  with mlmd_connection as m:
    if example_gen:
      # Publishes ExampleGen output.
      output_example = types.Artifact(
          example_gen.outputs.outputs['output_examples'].artifact_spec.type)
      output_example.uri = 'my_examples_uri'
      contexts = context_lib.prepare_contexts(m, example_gen.contexts)
      execution = execution_publish_utils.register_execution(
          m, example_gen.node_info.type, contexts)
      execution_publish_utils.publish_succeeded_execution(
          m, execution.id, contexts, {
              'output_examples': [output_example],
          })
    if transform:
      # Publishes Transform output.
      output_transform_graph = types.Artifact(
          transform.outputs.outputs['transform_graph'].artifact_spec.type)
      output_transform_graph.uri = 'my_transform_graph_uri'
      contexts = context_lib.prepare_contexts(m, transform.contexts)
      execution = execution_publish_utils.register_execution(
          m, transform.node_info.type, contexts)
      execution_publish_utils.publish_succeeded_execution(
          m, execution.id, contexts, {
              'transform_graph': [output_transform_graph],
          })
def generate_resolved_info(metadata_handler: metadata.Metadata,
                           node: pipeline_pb2.PipelineNode) -> ResolvedInfo:
  """Returns a `ResolvedInfo` object for executing the node.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to generate.

  Returns:
    A `ResolvedInfo` with input resolutions.
  """
  # Register node contexts.
  contexts = context_lib.prepare_contexts(
      metadata_handler=metadata_handler, node_contexts=node.contexts)

  # Resolve execution properties.
  exec_properties = inputs_utils.resolve_parameters(
      node_parameters=node.parameters)

  # Resolve inputs.
  input_artifacts = inputs_utils.resolve_input_artifacts(
      metadata_handler=metadata_handler, node_inputs=node.inputs)

  return ResolvedInfo(
      contexts=contexts,
      exec_properties=exec_properties,
      input_artifacts=input_artifacts)
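# A minimal usage sketch for generate_resolved_info above, assuming `my_node`
# is a hypothetical, fully resolved pipeline_pb2.PipelineNode. It uses the
# in-memory MLMD setup seen throughout this file.
connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.SetInParent()  # In-memory MLMD database.
with metadata.Metadata(connection_config=connection_config) as m:
  resolved = generate_resolved_info(metadata_handler=m, node=my_node)
  if resolved.input_artifacts is None:
    # Not all required inputs are ready; the caller should not execute.
    logging.info('Inputs for node %s are not ready.', my_node.node_info.id)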
def testRegisterContexts(self):
  node_contexts = pipeline_pb2.NodeContexts()
  self.load_proto_from_text(
      os.path.join(self._testdata_dir, 'node_context_spec.pbtxt'),
      node_contexts)
  with metadata.Metadata(connection_config=self._connection_config) as m:
    context_lib.prepare_contexts(
        metadata_handler=m, node_contexts=node_contexts)
    # Duplicated call should succeed.
    contexts = context_lib.prepare_contexts(
        metadata_handler=m, node_contexts=node_contexts)

    got_context_type_one = m.store.get_context_type('my_context_type_one')
    got_context_type_one.ClearField('id')
    self.assertProtoEquals(
        """
        name: 'my_context_type_one'
        """, got_context_type_one)
    got_context_type_two = m.store.get_context_type('my_context_type_two')
    got_context_type_two.ClearField('id')
    self.assertProtoEquals(
        """
        name: 'my_context_type_two'
        """, got_context_type_two)

    self.assertEqual(
        contexts[0],
        m.store.get_context_by_type_and_name('my_context_type_one',
                                             'my_context_one'))
    self.assertEqual(
        contexts[1],
        m.store.get_context_by_type_and_name('my_context_type_one',
                                             'my_context_two'))
    self.assertEqual(
        contexts[2],
        m.store.get_context_by_type_and_name('my_context_type_two',
                                             'my_context_three'))
    self.assertEqual(contexts[0].custom_properties['property_a'].int_value, 1)
    self.assertEqual(contexts[1].custom_properties['property_a'].int_value, 2)
    self.assertEqual(contexts[2].custom_properties['property_a'].int_value, 3)
    self.assertEqual(
        contexts[2].custom_properties['property_b'].string_value, '4')
def fake_execute(self, metadata_handler, pipeline_node, input_map, output_map):
  contexts = context_lib.prepare_contexts(metadata_handler,
                                          pipeline_node.contexts)
  execution = execution_publish_utils.register_execution(
      metadata_handler, pipeline_node.node_info.type, contexts, input_map)
  return execution_publish_utils.publish_succeeded_execution(
      metadata_handler, execution.id, contexts, output_map)
def fake_cached_execution(mlmd_connection, cache_context, component):
  """Writes a cached execution; MLMD must already hold a previous execution associated with cache_context."""
  with mlmd_connection as m:
    cached_outputs = cache_utils.get_cached_outputs(
        m, cache_context=cache_context)
    contexts = context_lib.prepare_contexts(m, component.contexts)
    execution = execution_publish_utils.register_execution(
        m, component.node_info.type, contexts)
    execution_publish_utils.publish_cached_execution(
        m,
        contexts=contexts,
        execution_id=execution.id,
        output_artifacts=cached_outputs)
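# A wiring sketch for the cache fake above: a cache context is derived from a
# prior successful execution, then replayed as a cached execution. The
# `my_pipeline_info` and `my_executor_spec` names are hypothetical
# placeholders; the get_cache_context call mirrors its use in
# _prepare_execution later in this file.
with mlmd_connection as m:
  cache_context = cache_utils.get_cache_context(
      metadata_handler=m,
      pipeline_node=component,
      pipeline_info=my_pipeline_info,
      executor_spec=my_executor_spec,
      input_artifacts={},
      output_artifacts={},
      parameters={})
fake_cached_execution(mlmd_connection, cache_context, component)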
def run(
    self, mlmd_connection: metadata.Metadata,
    pipeline_node: pipeline_pb2.PipelineNode,
    pipeline_info: pipeline_pb2.PipelineInfo,
    pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
) -> data_types.ExecutionInfo:
  """Runs Resolver specific logic.

  Args:
    mlmd_connection: ML metadata connection.
    pipeline_node: The specification of the node that this launcher launches.
    pipeline_info: The information of the pipeline that this node runs in.
    pipeline_runtime_spec: The runtime information of the pipeline that this
      node runs in.

  Returns:
    The execution of the run.
  """
  logging.info('Running as a resolver node.')
  with mlmd_connection as m:
    # 1. Prepares all contexts.
    contexts = context_lib.prepare_contexts(
        metadata_handler=m, node_contexts=pipeline_node.contexts)

    # 2. Resolves inputs and execution properties.
    exec_properties = inputs_utils.resolve_parameters(
        node_parameters=pipeline_node.parameters)
    input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=pipeline_node.inputs)

    # 3. Registers execution in metadata.
    execution = execution_publish_utils.register_execution(
        metadata_handler=m,
        execution_type=pipeline_node.node_info.type,
        contexts=contexts,
        exec_properties=exec_properties)

    # 4. Publishes the execution as a cached execution with the resolved
    # input artifacts as the output artifacts.
    execution_publish_utils.publish_internal_execution(
        metadata_handler=m,
        contexts=contexts,
        execution_id=execution.id,
        output_artifacts=input_artifacts)

    return data_types.ExecutionInfo(
        execution_id=execution.id,
        input_dict=input_artifacts,
        output_dict=input_artifacts,
        exec_properties=exec_properties,
        pipeline_node=pipeline_node,
        pipeline_info=pipeline_info)
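# A minimal invocation sketch for the resolver `run` method above, patterned
# after the handler usage in the testSuccess case later in this file. The
# proto variables (`my_mlmd_connection`, `my_resolver_node`, etc.) are
# hypothetical, pre-built placeholders.
handler = resolver_node_handler.ResolverNodeHandler()
execution_info = handler.run(
    mlmd_connection=my_mlmd_connection,
    pipeline_node=my_resolver_node,
    pipeline_info=my_pipeline_info,
    pipeline_runtime_spec=my_pipeline_runtime_spec)
# The resolved inputs are surfaced as both input_dict and output_dict.
resolved_artifacts = execution_info.output_dict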
def fake_trainer_output(mlmd_connection, trainer, execution=None):
  """Writes fake trainer output and execution to MLMD."""
  with mlmd_connection as m:
    output_trainer_model = types.Artifact(
        trainer.outputs.outputs['model'].artifact_spec.type)
    output_trainer_model.uri = 'my_trainer_model_uri'
    contexts = context_lib.prepare_contexts(m, trainer.contexts)
    if not execution:
      execution = execution_publish_utils.register_execution(
          m, trainer.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'model': [output_trainer_model],
        })
def fake_example_gen_run_with_handle(mlmd_handle, example_gen, span, version):
  """Writes fake example_gen output and successful execution to MLMD."""
  output_example = types.Artifact(
      example_gen.outputs.outputs['output_examples'].artifact_spec.type)
  output_example.set_int_custom_property('span', span)
  output_example.set_int_custom_property('version', version)
  output_example.uri = 'my_examples_uri'
  contexts = context_lib.prepare_contexts(mlmd_handle, example_gen.contexts)
  execution = execution_publish_utils.register_execution(
      mlmd_handle, example_gen.node_info.type, contexts)
  execution_publish_utils.publish_succeeded_execution(
      mlmd_handle, execution.id, contexts, {
          'output_examples': [output_example],
      })
def fake_transform_output(mlmd_connection, transform, execution=None):
  """Writes fake transform output and execution to MLMD."""
  with mlmd_connection as m:
    output_transform_graph = types.Artifact(
        transform.outputs.outputs['transform_graph'].artifact_spec.type)
    output_transform_graph.uri = 'my_transform_graph_uri'
    contexts = context_lib.prepare_contexts(m, transform.contexts)
    if not execution:
      execution = execution_publish_utils.register_execution(
          m, transform.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'transform_graph': [output_transform_graph],
        })
def fake_component_output_with_handle(mlmd_handle,
                                      component,
                                      execution=None,
                                      active=False):
  """Writes fake component output and execution to MLMD."""
  output_key, output_value = next(iter(component.outputs.outputs.items()))
  output = types.Artifact(output_value.artifact_spec.type)
  output.uri = str(uuid.uuid4())
  contexts = context_lib.prepare_contexts(mlmd_handle, component.contexts)
  if not execution:
    execution = execution_publish_utils.register_execution(
        mlmd_handle, component.node_info.type, contexts)
  if not active:
    execution_publish_utils.publish_succeeded_execution(
        mlmd_handle, execution.id, contexts, {output_key: [output]})
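# Usage sketch for the helper above; `handle` and `my_component` are
# hypothetical placeholders. With active=True the execution is registered but
# never published, so it remains in the RUNNING state; with the default
# active=False a successful execution is published instead.
fake_component_output_with_handle(handle, my_component, active=True)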
def generate_resolved_info(
    metadata_handler: metadata.Metadata,
    node: pipeline_pb2.PipelineNode) -> Optional[ResolvedInfo]:
  """Returns a `ResolvedInfo` object for executing the node or `None` to skip.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to generate.

  Returns:
    A `ResolvedInfo` with input resolutions or `None` if execution should be
    skipped.

  Raises:
    NotImplementedError: If `inputs_utils.resolve_input_artifacts_v2` returns
      multiple input dicts, which is currently not supported.
  """
  # Register node contexts.
  contexts = context_lib.prepare_contexts(
      metadata_handler=metadata_handler, node_contexts=node.contexts)

  # Resolve execution properties.
  exec_properties = data_types_utils.build_parsed_value_dict(
      inputs_utils.resolve_parameters_with_schema(
          node_parameters=node.parameters))

  # Resolve inputs.
  try:
    resolved_input_artifacts = inputs_utils.resolve_input_artifacts_v2(
        metadata_handler=metadata_handler, pipeline_node=node)
  except exceptions.InputResolutionError as e:
    logging.warning('Input resolution error raised for node: %s; error: %s',
                    node.node_info.id, e)
    resolved_input_artifacts = None
  else:
    if isinstance(resolved_input_artifacts, inputs_utils.Skip):
      return None
    assert isinstance(resolved_input_artifacts, inputs_utils.Trigger)
    assert resolved_input_artifacts
    # TODO(b/197741942): Support multiple dicts.
    if len(resolved_input_artifacts) > 1:
      raise NotImplementedError(
          'Handling more than one input dict is not implemented.')

  return ResolvedInfo(
      contexts=contexts,
      exec_properties=exec_properties,
      input_artifacts=resolved_input_artifacts)
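# A consumption sketch for the tri-state result of the v2 function above;
# `my_node` is a hypothetical pipeline_pb2.PipelineNode and `m` an open
# metadata handler.
resolved = generate_resolved_info(metadata_handler=m, node=my_node)
if resolved is None:
  # Input resolution returned Skip (e.g. a false conditional): no execution.
  pass
elif resolved.input_artifacts is None:
  # Input resolution raised InputResolutionError; nothing can run yet.
  pass
else:
  input_dict = resolved.input_artifacts[0]  # Exactly one dict is supported.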
def fakeExampleGenOutput(mlmd_connection: metadata.Metadata,
                         example_gen: pipeline_pb2.PipelineNode, span: int,
                         version: int):
  with mlmd_connection as m:
    output_example = types.Artifact(
        example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example.set_int_custom_property('span', span)
    output_example.set_int_custom_property('version', version)
    output_example.uri = 'my_examples_uri'
    contexts = context_lib.prepare_contexts(m, example_gen.contexts)
    execution = execution_publish_utils.register_execution(
        m, example_gen.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'output_examples': [output_example],
        })
def _generate_contexts(self, metadata_handler):
  context_spec = pipeline_pb2.NodeContexts()
  text_format.Parse(
      """
      contexts {
        type {name: 'pipeline_context'}
        name {
          field_value {string_value: 'my_pipeline'}
        }
      }
      contexts {
        type {name: 'component_context'}
        name {
          field_value {string_value: 'my_component'}
        }
      }""", context_spec)
  return context_lib.prepare_contexts(metadata_handler, context_spec)
def testResolverWithResolverPolicy(self):
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(self._testdata_dir,
                   'pipeline_for_input_resolver_test.pbtxt'), pipeline)
  my_example_gen = pipeline.nodes[0].pipeline_node
  my_transform = pipeline.nodes[2].pipeline_node

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  with metadata.Metadata(connection_config=connection_config) as m:
    # Publishes ExampleGen with two artifacts in its `output_examples`
    # channel, which will be consumed by the downstream Transform.
    output_example_1 = types.Artifact(
        my_example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example_1.uri = 'my_examples_uri_1'
    output_example_2 = types.Artifact(
        my_example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example_2.uri = 'my_examples_uri_2'
    contexts = context_lib.prepare_contexts(m, my_example_gen.contexts)
    execution = execution_publish_utils.register_execution(
        m, my_example_gen.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'output_examples': [output_example_1, output_example_2],
        })

    my_transform.inputs.resolver_config.resolver_policy = (
        pipeline_pb2.ResolverConfig.LATEST_ARTIFACT)

    # Gets inputs for transform. With the LATEST_ARTIFACT policy, only the
    # most recently published artifact in the `output_examples` channel
    # should come back.
    transform_inputs = inputs_utils.resolve_input_artifacts(
        m, my_transform.inputs)
    self.assertEqual(len(transform_inputs), 1)
    self.assertEqual(len(transform_inputs['examples']), 1)
    self.assertProtoPartiallyEquals(
        transform_inputs['examples'][0].mlmd_artifact,
        output_example_2.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def testResolveInputArtifactsOutputKeyUnset(self):
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(self._testdata_dir,
                   'pipeline_for_input_resolver_test_output_key_unset.pbtxt'),
      pipeline)
  my_trainer = pipeline.nodes[0].pipeline_node
  my_pusher = pipeline.nodes[1].pipeline_node

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  with metadata.Metadata(connection_config=connection_config) as m:
    # Publishes Trainer with one output channel. The `model` output will be
    # consumed by the Pusher in a different run.
    output_model = types.Artifact(
        my_trainer.outputs.outputs['model'].artifact_spec.type)
    output_model.uri = 'my_output_model_uri'
    contexts = context_lib.prepare_contexts(m, my_trainer.contexts)
    execution = execution_publish_utils.register_execution(
        m, my_trainer.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'model': [output_model],
        })

    # Gets inputs for pusher. Should get back the model that the Trainer
    # published in the `model` channel.
    pusher_inputs = inputs_utils.resolve_input_artifacts(m, my_pusher.inputs)
    self.assertEqual(len(pusher_inputs), 1)
    self.assertEqual(len(pusher_inputs['model']), 1)
    self.assertProtoPartiallyEquals(
        output_model.mlmd_artifact,
        pusher_inputs['model'][0].mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def _prepare_execution(self) -> _PrepareExecutionResult:
  """Prepares inputs, outputs and execution properties for actual execution."""
  # TODO(b/150979622): handle the edge case that the component gets evicted
  # between a successful publish and the stateful working dir being cleaned
  # up. Otherwise subsequent retries will keep failing because of duplicate
  # publishes.
  with self._mlmd_connection as m:
    # 1. Prepares all contexts.
    contexts = context_lib.prepare_contexts(
        metadata_handler=m, node_contexts=self._pipeline_node.contexts)

    # 2. Resolves inputs and execution properties.
    exec_properties = inputs_utils.resolve_parameters(
        node_parameters=self._pipeline_node.parameters)
    input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=self._pipeline_node.inputs)

    # 3. If not all required inputs are met, returns an ExecutionInfo with
    # is_execution_needed set to False. No publish will happen, so downstream
    # nodes won't be triggered.
    if input_artifacts is None:
      logging.info('Not all required inputs are ready, abandoning execution.')
      return _PrepareExecutionResult(
          execution_info=data_types.ExecutionInfo(),
          contexts=contexts,
          is_execution_needed=False)

    # 4. Registers execution in metadata.
    execution = execution_publish_utils.register_execution(
        metadata_handler=m,
        execution_type=self._pipeline_node.node_info.type,
        contexts=contexts,
        input_artifacts=input_artifacts,
        exec_properties=exec_properties)

    # 5. Resolves output artifacts.
    output_artifacts = self._output_resolver.generate_output_artifacts(
        execution.id)

  # If there is a custom driver, runs it.
  if self._driver_operator:
    driver_output = self._driver_operator.run_driver(
        data_types.ExecutionInfo(
            input_dict=input_artifacts,
            output_dict=output_artifacts,
            exec_properties=exec_properties,
            execution_output_uri=self._output_resolver
            .get_driver_output_uri()))
    self._update_with_driver_output(driver_output, exec_properties,
                                    output_artifacts)

  # We reconnect to MLMD here because the custom driver closes the MLMD
  # connection on returning.
  with self._mlmd_connection as m:
    # 6. Checks for a cached result.
    cache_context = cache_utils.get_cache_context(
        metadata_handler=m,
        pipeline_node=self._pipeline_node,
        pipeline_info=self._pipeline_info,
        executor_spec=self._executor_spec,
        input_artifacts=input_artifacts,
        output_artifacts=output_artifacts,
        parameters=exec_properties)
    contexts.append(cache_context)
    cached_outputs = cache_utils.get_cached_outputs(
        metadata_handler=m, cache_context=cache_context)

    # 7. Should the cache be used?
    if (self._pipeline_node.execution_options.caching_options.enable_cache and
        cached_outputs):
      # Publishes the cached result.
      execution_publish_utils.publish_cached_execution(
          metadata_handler=m,
          contexts=contexts,
          execution_id=execution.id,
          output_artifacts=cached_outputs)
      logging.info('A cached execution %d is used.', execution.id)
      return _PrepareExecutionResult(
          execution_info=data_types.ExecutionInfo(execution_id=execution.id),
          execution_metadata=execution,
          contexts=contexts,
          is_execution_needed=False)

    pipeline_run_id = (
        self._pipeline_runtime_spec.pipeline_run_id.field_value.string_value)

    # 8. Going to trigger the executor.
    logging.info('Going to run a new execution %d', execution.id)
    return _PrepareExecutionResult(
        execution_info=data_types.ExecutionInfo(
            execution_id=execution.id,
            input_dict=input_artifacts,
            output_dict=output_artifacts,
            exec_properties=exec_properties,
            execution_output_uri=self._output_resolver
            .get_executor_output_uri(execution.id),
            stateful_working_dir=(
                self._output_resolver.get_stateful_working_directory()),
            tmp_dir=self._output_resolver.make_tmp_dir(execution.id),
            pipeline_node=self._pipeline_node,
            pipeline_info=self._pipeline_info,
            pipeline_run_id=pipeline_run_id),
        execution_metadata=execution,
        contexts=contexts,
        is_execution_needed=True)
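# A minimal caller sketch for _prepare_execution above (hypothetical; not the
# launcher's actual control flow). `launcher_instance` is a placeholder.
prepare_result = launcher_instance._prepare_execution()
if not prepare_result.is_execution_needed:
  # Inputs were not ready, or cached outputs were already published;
  # nothing is left to execute.
  logging.info('Skipping executor for execution %s',
               prepare_result.execution_info.execution_id)
else:
  # A new execution was registered; execution_info carries the resolved
  # inputs, output locations and working directories for the executor.
  logging.info('Executor will run with output uri %s',
               prepare_result.execution_info.execution_output_uri)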
def testLauncher_ReEntry(self):
  # Some executors or runtime environments may reschedule the launcher job
  # before the launcher job can publish any results of the execution to MLMD.
  # The launcher should reuse the previous execution and proceed to a
  # successful execution.
  self.reloadPipelineWithNewRunId()
  LauncherTest.fakeUpstreamOutputs(self._mlmd_connection, self._example_gen,
                                   self._transform)

  def create_test_launcher(executor_operators):
    return launcher.Launcher(
        pipeline_node=self._trainer,
        mlmd_connection=self._mlmd_connection,
        pipeline_info=self._pipeline_info,
        pipeline_runtime_spec=self._pipeline_runtime_spec,
        executor_spec=self._trainer_executor_spec,
        custom_executor_operators=executor_operators)

  test_launcher = create_test_launcher(
      {_PYTHON_CLASS_EXECUTABLE_SPEC: _FakeCrashingExecutorOperator})

  # The first launch simulates the launcher being restarted by preventing the
  # publishing of any results to MLMD.
  with contextlib.ExitStack() as stack:
    stack.enter_context(
        mock.patch.object(test_launcher, '_publish_failed_execution'))
    stack.enter_context(
        mock.patch.object(test_launcher, '_clean_up_stateless_execution_info'))
    stack.enter_context(self.assertRaises(FakeError))
    test_launcher.launch()

  # Retrieves the execution from the first launch, which should be in the
  # RUNNING state.
  with self._mlmd_connection as m:
    contexts = context_lib.prepare_contexts(
        metadata_handler=m,
        node_contexts=test_launcher._pipeline_node.contexts)
    exec_properties = data_types_utils.build_parsed_value_dict(
        inputs_utils.resolve_parameters_with_schema(
            node_parameters=test_launcher._pipeline_node.parameters))
    input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=test_launcher._pipeline_node.inputs)
    first_execution = test_launcher._register_or_reuse_execution(
        metadata_handler=m,
        contexts=contexts,
        input_artifacts=input_artifacts,
        exec_properties=exec_properties)
    self.assertEqual(first_execution.last_known_state,
                     metadata_store_pb2.Execution.RUNNING)

  # Creates a second test launcher. It should reuse the previous execution.
  second_test_launcher = create_test_launcher(
      {_PYTHON_CLASS_EXECUTABLE_SPEC: _FakeExecutorOperator})
  execution_info = second_test_launcher.launch()

  with self._mlmd_connection as m:
    self.assertEqual(first_execution.id, execution_info.execution_id)
    executions = m.store.get_executions_by_id([execution_info.execution_id])
    self.assertLen(executions, 1)
    self.assertEqual(executions.pop().last_known_state,
                     metadata_store_pb2.Execution.COMPLETE)

  # Creates a third test launcher. It should not require an execution.
  third_test_launcher = create_test_launcher(
      {_PYTHON_CLASS_EXECUTABLE_SPEC: _FakeExecutorOperator})
  execution_preparation_result = third_test_launcher._prepare_execution()
  self.assertFalse(execution_preparation_result.is_execution_needed)
def test_resolver_task_scheduler(self):
  with self._mlmd_connection as m:
    # Publishes two models which will be consumed by the downstream resolver.
    output_model_1 = types.Artifact(
        self._trainer.outputs.outputs['model'].artifact_spec.type)
    output_model_1.uri = 'my_model_uri_1'
    output_model_2 = types.Artifact(
        self._trainer.outputs.outputs['model'].artifact_spec.type)
    output_model_2.uri = 'my_model_uri_2'
    contexts = context_lib.prepare_contexts(m, self._trainer.contexts)
    execution = execution_publish_utils.register_execution(
        m, self._trainer.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'model': [output_model_1, output_model_2],
        })

  task_queue = tq.TaskQueue()

  # Verifies that a resolver task is generated.
  [resolver_task] = test_utils.run_generator_and_test(
      test_case=self,
      mlmd_connection=self._mlmd_connection,
      generator_class=sptg.SyncPipelineTaskGenerator,
      pipeline=self._pipeline,
      task_queue=task_queue,
      use_task_queue=False,
      service_job_manager=None,
      num_initial_executions=1,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._resolver_node],
      ignore_update_node_state_tasks=True)

  with self._mlmd_connection as m:
    # Runs the resolver task scheduler and publishes results.
    ts_result = resolver_task_scheduler.ResolverTaskScheduler(
        mlmd_handle=m, pipeline=self._pipeline, task=resolver_task).schedule()
    self.assertEqual(status_lib.Code.OK, ts_result.status.code)
    self.assertIsInstance(ts_result.output,
                          task_scheduler.ResolverNodeOutput)
    self.assertCountEqual(['resolved_model'],
                          ts_result.output.resolved_input_artifacts.keys())
    models = ts_result.output.resolved_input_artifacts['resolved_model']
    self.assertLen(models, 1)
    self.assertEqual('my_model_uri_2', models[0].mlmd_artifact.uri)
    tm._publish_execution_results(m, resolver_task, ts_result)

  # Verifies the resolver node output is input to the downstream consumer
  # node.
  [consumer_task] = test_utils.run_generator_and_test(
      test_case=self,
      mlmd_connection=self._mlmd_connection,
      generator_class=sptg.SyncPipelineTaskGenerator,
      pipeline=self._pipeline,
      task_queue=task_queue,
      use_task_queue=False,
      service_job_manager=None,
      num_initial_executions=2,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._consumer_node],
      ignore_update_node_state_tasks=True)
  self.assertCountEqual(['resolved_model'],
                        consumer_task.input_artifacts.keys())
  input_models = consumer_task.input_artifacts['resolved_model']
  self.assertLen(input_models, 1)
  self.assertEqual('my_model_uri_2', input_models[0].mlmd_artifact.uri)
def testSuccess(self):
  with self._mlmd_connection as m:
    # Publishes two models which will be consumed by the downstream resolver.
    output_model_1 = types.Artifact(
        self._my_trainer.outputs.outputs['model'].artifact_spec.type)
    output_model_1.uri = 'my_model_uri_1'
    output_model_2 = types.Artifact(
        self._my_trainer.outputs.outputs['model'].artifact_spec.type)
    output_model_2.uri = 'my_model_uri_2'
    contexts = context_lib.prepare_contexts(m, self._my_trainer.contexts)
    execution = execution_publish_utils.register_execution(
        m, self._my_trainer.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'model': [output_model_1, output_model_2],
        })

  handler = resolver_node_handler.ResolverNodeHandler()
  execution_metadata = handler.run(
      mlmd_connection=self._mlmd_connection,
      pipeline_node=self._resolver_node,
      pipeline_info=self._pipeline_info,
      pipeline_runtime_spec=self._pipeline_runtime_spec)

  with self._mlmd_connection as m:
    # There is no way to directly verify the output artifacts of the
    # resolver, so a fake downstream component is created which listens to
    # the resolver's output, and we verify its input.
    down_stream_node = text_format.Parse(
        """
        inputs {
          inputs {
            key: "input_models"
            value {
              channels {
                producer_node_query {
                  id: "my_resolver"
                }
                context_queries {
                  type {
                    name: "pipeline"
                  }
                  name {
                    field_value {
                      string_value: "my_pipeline"
                    }
                  }
                }
                context_queries {
                  type {
                    name: "component"
                  }
                  name {
                    field_value {
                      string_value: "my_resolver"
                    }
                  }
                }
                artifact_query {
                  type {
                    name: "Model"
                  }
                }
                output_key: "models"
              }
              min_count: 1
            }
          }
        }
        upstream_nodes: "my_resolver"
        """, pipeline_pb2.PipelineNode())
    downstream_input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=down_stream_node.inputs)
    downstream_input_model = downstream_input_artifacts['input_models']
    self.assertLen(downstream_input_model, 1)
    self.assertProtoPartiallyEquals(
        """
        id: 2
        type_id: 5
        uri: "my_model_uri_2"
        state: LIVE""",
        downstream_input_model[0].mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    [execution] = m.store.get_executions_by_id([execution_metadata.id])
    self.assertProtoPartiallyEquals(
        """
        id: 2
        type_id: 6
        last_known_state: COMPLETE
        """,
        execution,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def testResolverInputsArtifacts(self):
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(self._testdata_dir,
                   'pipeline_for_input_resolver_test.pbtxt'), pipeline)
  my_example_gen = pipeline.nodes[0].pipeline_node
  another_example_gen = pipeline.nodes[1].pipeline_node
  my_transform = pipeline.nodes[2].pipeline_node
  my_trainer = pipeline.nodes[3].pipeline_node

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  with metadata.Metadata(connection_config=connection_config) as m:
    # Publishes first ExampleGen with two output channels. `output_examples`
    # will be consumed by the downstream Transform.
    output_example = types.Artifact(
        my_example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example.uri = 'my_examples_uri'
    side_examples = types.Artifact(
        my_example_gen.outputs.outputs['side_examples'].artifact_spec.type)
    side_examples.uri = 'side_examples_uri'
    contexts = context_lib.prepare_contexts(m, my_example_gen.contexts)
    execution = execution_publish_utils.register_execution(
        m, my_example_gen.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'output_examples': [output_example],
            'another_examples': [side_examples]
        })

    # Publishes second ExampleGen with one output channel that has the same
    # output key as the first ExampleGen. However, this is not consumed by
    # downstream nodes.
    another_output_example = types.Artifact(
        another_example_gen.outputs.outputs['output_examples'].artifact_spec
        .type)
    another_output_example.uri = 'another_examples_uri'
    contexts = context_lib.prepare_contexts(m, another_example_gen.contexts)
    execution = execution_publish_utils.register_execution(
        m, another_example_gen.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'output_examples': [another_output_example],
        })

    # Gets inputs for transform. Should get back what the first ExampleGen
    # published in the `output_examples` channel.
    transform_inputs = inputs_utils.resolve_input_artifacts(
        m, my_transform.inputs)
    self.assertEqual(len(transform_inputs), 1)
    self.assertEqual(len(transform_inputs['examples']), 1)
    self.assertProtoPartiallyEquals(
        transform_inputs['examples'][0].mlmd_artifact,
        output_example.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

    # Tries to resolve inputs for trainer. The trainer requires min_count to
    # be met on both input channels (from example_gen and from transform),
    # but nothing was published from transform, so it should return nothing.
    self.assertIsNone(
        inputs_utils.resolve_input_artifacts(m, my_trainer.inputs))
def run(
    self, mlmd_connection: metadata.Metadata,
    pipeline_node: pipeline_pb2.PipelineNode,
    pipeline_info: pipeline_pb2.PipelineInfo,
    pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
) -> data_types.ExecutionInfo:
  """Runs Resolver specific logic.

  Args:
    mlmd_connection: ML metadata connection.
    pipeline_node: The specification of the node that this launcher launches.
    pipeline_info: The information of the pipeline that this node runs in.
    pipeline_runtime_spec: The runtime information of the pipeline that this
      node runs in.

  Returns:
    The execution of the run.
  """
  logging.info('Running as a resolver node.')
  with mlmd_connection as m:
    # 1. Prepares all contexts.
    contexts = context_lib.prepare_contexts(
        metadata_handler=m, node_contexts=pipeline_node.contexts)

    # 2. Resolves inputs and execution properties.
    exec_properties = data_types_utils.build_parsed_value_dict(
        inputs_utils.resolve_parameters_with_schema(
            node_parameters=pipeline_node.parameters))
    try:
      resolved_inputs = inputs_utils.resolve_input_artifacts_v2(
          pipeline_node=pipeline_node, metadata_handler=m)
    except exceptions.InputResolutionError as e:
      execution = execution_publish_utils.register_execution(
          metadata_handler=m,
          execution_type=pipeline_node.node_info.type,
          contexts=contexts,
          exec_properties=exec_properties)
      execution_publish_utils.publish_failed_execution(
          metadata_handler=m,
          contexts=contexts,
          execution_id=execution.id,
          executor_output=self._build_error_output(code=e.grpc_code_value))
      return data_types.ExecutionInfo(
          execution_id=execution.id,
          exec_properties=exec_properties,
          pipeline_node=pipeline_node,
          pipeline_info=pipeline_info)

    # 2a. If Skip (i.e. inside conditional), no execution should be made.
    # TODO(b/197907821): Publish special execution for Skip?
    if isinstance(resolved_inputs, inputs_utils.Skip):
      return data_types.ExecutionInfo()

    # 3. Registers execution in metadata.
    execution = execution_publish_utils.register_execution(
        metadata_handler=m,
        execution_type=pipeline_node.node_info.type,
        contexts=contexts,
        exec_properties=exec_properties)

    # TODO(b/197741942): Support len > 1.
    if len(resolved_inputs) > 1:
      execution_publish_utils.publish_failed_execution(
          metadata_handler=m,
          contexts=contexts,
          execution_id=execution.id,
          executor_output=self._build_error_output(
              _ERROR_CODE_UNIMPLEMENTED,
              'Handling more than one input dict is not implemented yet.'))
      return data_types.ExecutionInfo(
          execution_id=execution.id,
          exec_properties=exec_properties,
          pipeline_node=pipeline_node,
          pipeline_info=pipeline_info)

    input_artifacts = resolved_inputs[0]

    # 4. Publishes the execution as a cached execution with the resolved
    # input artifacts as the output artifacts.
    execution_publish_utils.publish_internal_execution(
        metadata_handler=m,
        contexts=contexts,
        execution_id=execution.id,
        output_artifacts=input_artifacts)

    return data_types.ExecutionInfo(
        execution_id=execution.id,
        input_dict=input_artifacts,
        output_dict=input_artifacts,
        exec_properties=exec_properties,
        pipeline_node=pipeline_node,
        pipeline_info=pipeline_info)
def run(
    self, mlmd_connection: metadata.Metadata,
    pipeline_node: pipeline_pb2.PipelineNode,
    pipeline_info: pipeline_pb2.PipelineInfo,
    pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
) -> data_types.ExecutionInfo:
  """Runs Importer specific logic.

  Args:
    mlmd_connection: ML metadata connection.
    pipeline_node: The specification of the node that this launcher launches.
    pipeline_info: The information of the pipeline that this node runs in.
    pipeline_runtime_spec: The runtime information of the pipeline that this
      node runs in.

  Returns:
    The execution of the run.
  """
  logging.info('Running as an importer node.')
  with mlmd_connection as m:
    # 1. Prepares all contexts.
    contexts = context_lib.prepare_contexts(
        metadata_handler=m, node_contexts=pipeline_node.contexts)

    # 2. Resolves execution properties. Note that importers have no input.
    exec_properties = data_types_utils.build_parsed_value_dict(
        inputs_utils.resolve_parameters_with_schema(
            node_parameters=pipeline_node.parameters))

    # 3. Registers execution in metadata.
    execution = execution_publish_utils.register_execution(
        metadata_handler=m,
        execution_type=pipeline_node.node_info.type,
        contexts=contexts,
        exec_properties=exec_properties)

    # 4. Generates output artifacts to represent the imported artifacts.
    output_spec = pipeline_node.outputs.outputs[importer.IMPORT_RESULT_KEY]
    properties = self._extract_proto_map(
        output_spec.artifact_spec.additional_properties)
    custom_properties = self._extract_proto_map(
        output_spec.artifact_spec.additional_custom_properties)
    output_artifact_class = types.Artifact(output_spec.artifact_spec.type).type
    output_artifacts = importer.generate_output_dict(
        metadata_handler=m,
        uri=str(exec_properties[importer.SOURCE_URI_KEY]),
        properties=properties,
        custom_properties=custom_properties,
        reimport=bool(exec_properties[importer.REIMPORT_OPTION_KEY]),
        output_artifact_class=output_artifact_class,
        mlmd_artifact_type=output_spec.artifact_spec.type)

    result = data_types.ExecutionInfo(
        execution_id=execution.id,
        input_dict={},
        output_dict=output_artifacts,
        exec_properties=exec_properties,
        pipeline_node=pipeline_node,
        pipeline_info=pipeline_info)

    # TODO(b/182316162): consider letting the launcher level do the publish
    # for system nodes, so that the version tagging logic doesn't need to be
    # handled per system node.
    outputs_utils.tag_output_artifacts_with_version(result.output_dict)

    # 5. Publishes the output artifacts. If the artifacts were reimported,
    # the execution is published as CACHED; otherwise it is published as
    # COMPLETE.
    if _is_artifact_reimported(output_artifacts):
      execution_publish_utils.publish_cached_execution(
          metadata_handler=m,
          contexts=contexts,
          execution_id=execution.id,
          output_artifacts=output_artifacts)
    else:
      execution_publish_utils.publish_succeeded_execution(
          metadata_handler=m,
          execution_id=execution.id,
          contexts=contexts,
          output_artifacts=output_artifacts)

    return result
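# A minimal invocation sketch for the importer `run` method above; the
# `importer_node_handler.ImporterNodeHandler` class name and the proto
# variables are assumptions patterned after the resolver handler usage shown
# elsewhere in this file.
handler = importer_node_handler.ImporterNodeHandler()
execution_info = handler.run(
    mlmd_connection=my_mlmd_connection,
    pipeline_node=my_importer_node,
    pipeline_info=my_pipeline_info,
    pipeline_runtime_spec=my_pipeline_runtime_spec)
# The imported artifacts appear under the importer's result output key.
imported = execution_info.output_dict[importer.IMPORT_RESULT_KEY]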
def testRegisterContexts(self):
  node_contexts = pipeline_pb2.NodeContexts()
  self.load_proto_from_text(
      os.path.join(self._testdata_dir, 'node_context_spec.pbtxt'),
      node_contexts)
  with metadata.Metadata(connection_config=self._connection_config) as m:
    context_lib.prepare_contexts(
        metadata_handler=m, node_contexts=node_contexts)
    # Duplicated call should succeed.
    contexts = context_lib.prepare_contexts(
        metadata_handler=m, node_contexts=node_contexts)

    self.assertProtoEquals(
        """
        id: 1
        name: 'my_context_type_one'
        """, m.store.get_context_type('my_context_type_one'))
    self.assertProtoEquals(
        """
        id: 2
        name: 'my_context_type_two'
        """, m.store.get_context_type('my_context_type_two'))
    self.assertProtoEquals(
        """
        type_id: 1
        name: "my_context_one"
        custom_properties {
          key: "property_a"
          value {
            int_value: 1
          }
        }
        """, contexts[0])
    self.assertProtoEquals(
        """
        type_id: 1
        name: "my_context_two"
        custom_properties {
          key: "property_a"
          value {
            int_value: 2
          }
        }
        """, contexts[1])
    self.assertProtoEquals(
        """
        type_id: 2
        name: "my_context_three"
        custom_properties {
          key: "property_a"
          value {
            int_value: 3
          }
        }
        custom_properties {
          key: "property_b"
          value {
            string_value: '4'
          }
        }
        """, contexts[2])
    self.assertEqual(contexts[0].custom_properties['property_a'].int_value, 1)
    self.assertEqual(contexts[1].custom_properties['property_a'].int_value, 2)
    self.assertEqual(contexts[2].custom_properties['property_a'].int_value, 3)
    self.assertEqual(
        contexts[2].custom_properties['property_b'].string_value, '4')
def run(
    self, mlmd_connection: metadata.Metadata,
    pipeline_node: pipeline_pb2.PipelineNode,
    pipeline_info: pipeline_pb2.PipelineInfo,
    pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
) -> data_types.ExecutionInfo:
  """Runs Importer specific logic.

  Args:
    mlmd_connection: ML metadata connection.
    pipeline_node: The specification of the node that this launcher launches.
    pipeline_info: The information of the pipeline that this node runs in.
    pipeline_runtime_spec: The runtime information of the pipeline that this
      node runs in.

  Returns:
    The execution of the run.
  """
  logging.info('Running as an importer node.')
  with mlmd_connection as m:
    # 1. Prepares all contexts.
    contexts = context_lib.prepare_contexts(
        metadata_handler=m, node_contexts=pipeline_node.contexts)

    # 2. Resolves execution properties. Note that importers have no input.
    exec_properties = inputs_utils.resolve_parameters(
        node_parameters=pipeline_node.parameters)

    # 3. Registers execution in metadata.
    execution = execution_publish_utils.register_execution(
        metadata_handler=m,
        execution_type=pipeline_node.node_info.type,
        contexts=contexts,
        exec_properties=exec_properties)

    # 4. Generates output artifacts to represent the imported artifacts.
    output_spec = pipeline_node.outputs.outputs[importer.IMPORT_RESULT_KEY]
    properties = self._extract_proto_map(
        output_spec.artifact_spec.additional_properties)
    custom_properties = self._extract_proto_map(
        output_spec.artifact_spec.additional_custom_properties)
    output_artifact_class = types.Artifact(output_spec.artifact_spec.type).type
    output_artifacts = importer.generate_output_dict(
        metadata_handler=m,
        uri=str(exec_properties[importer.SOURCE_URI_KEY]),
        properties=properties,
        custom_properties=custom_properties,
        reimport=bool(exec_properties[importer.REIMPORT_OPTION_KEY]),
        output_artifact_class=output_artifact_class,
        mlmd_artifact_type=output_spec.artifact_spec.type)

    # 5. Publishes the output artifacts.
    execution_publish_utils.publish_succeeded_execution(
        metadata_handler=m,
        execution_id=execution.id,
        contexts=contexts,
        output_artifacts=output_artifacts)

    return data_types.ExecutionInfo(
        execution_id=execution.id,
        input_dict={},
        output_dict=output_artifacts,
        exec_properties=exec_properties,
        pipeline_node=pipeline_node,
        pipeline_info=pipeline_info)