def fake_execute(self, metadata_handler, pipeline_node, input_map, output_map):
  """Registers an execution for `pipeline_node` and publishes it as succeeded.

  The node's contexts are prepared first. The execution is registered with
  `input_map` as its inputs and then published with `output_map` as outputs.

  Args:
    metadata_handler: MLMD handle used for all reads and writes.
    pipeline_node: Node whose contexts and `node_info.type` are used.
    input_map: Input artifacts recorded on the registered execution.
    output_map: Output artifacts published with the execution.

  Returns:
    Whatever `execution_publish_utils.publish_succeeded_execution` returns.
  """
  node_contexts = context_lib.prepare_contexts(
      metadata_handler, pipeline_node.contexts)
  registered = execution_publish_utils.register_execution(
      metadata_handler, pipeline_node.node_info.type, node_contexts,
      input_map)
  published = execution_publish_utils.publish_succeeded_execution(
      metadata_handler, registered.id, node_contexts, output_map)
  return published
def testGetCachedOutputArtifactsForNodesWithNoOuput(self):
  """Cached outputs are None before any success, then an empty dict."""
  with metadata.Metadata(connection_config=self._connection_config) as m:
    cache_context = context_lib.register_context_if_not_exists(
        m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')

    # No successful execution is associated with this context yet, so there
    # is no cached output at all.
    self.assertIsNone(cache_utils.get_cached_outputs(m, cache_context))

    execution_type = metadata_store_pb2.ExecutionType(name='my_type')
    execution_one = execution_publish_utils.register_execution(
        m, execution_type, [cache_context])
    execution_publish_utils.publish_succeeded_execution(
        m, execution_one.id, [cache_context])

    # A successful execution with no outputs is now associated with the
    # context, so the cached output exists but is an empty dict.
    cached_output = cache_utils.get_cached_outputs(m, cache_context)
    self.assertIsNotNone(cached_output)
    self.assertEmpty(cached_output)
def fake_trainer_output(mlmd_connection, trainer, execution=None,
                        active=False):
  """Writes fake trainer output and execution to MLMD.

  Args:
    mlmd_connection: MLMD connection; entered as a context manager here.
    trainer: Trainer pipeline node. Its `model` output spec supplies the
      artifact type, and its contexts and `node_info.type` are used to
      register the execution.
    execution: An already-registered execution to reuse. When None, a new
      execution is registered under the trainer's contexts.
    active: If True, the execution is left unpublished and no model artifact
      is published.

  Returns:
    The execution that was passed in or registered here, so callers can
    refer to it later (e.g. to finish an execution left active). It is not
    re-read from MLMD after publishing. Callers that ignore the return value
    are unaffected.
  """
  with mlmd_connection as m:
    output_trainer_model = types.Artifact(
        trainer.outputs.outputs['model'].artifact_spec.type)
    output_trainer_model.uri = 'my_trainer_model_uri'
    contexts = context_lib.prepare_contexts(m, trainer.contexts)
    # Compare against None explicitly: only a missing execution should cause
    # a new one to be registered.
    if execution is None:
      execution = execution_publish_utils.register_execution(
          m, trainer.node_info.type, contexts)
    if not active:
      execution_publish_utils.publish_succeeded_execution(
          m, execution.id, contexts, {
              'model': [output_trainer_model],
          })
  return execution
def fakeExampleGenOutput(mlmd_connection: metadata.Metadata,
                         example_gen: pipeline_pb2.PipelineNode, span: int,
                         version: int):
  """Publishes one fake `output_examples` artifact for `example_gen` to MLMD.

  The artifact gets uri 'my_examples_uri' and int custom properties 'span'
  and 'version'. A new execution is registered under the node's contexts and
  published as succeeded with that artifact.
  """
  with mlmd_connection as m:
    examples_output_spec = example_gen.outputs.outputs['output_examples']
    output_example = types.Artifact(examples_output_spec.artifact_spec.type)
    for property_name, property_value in (('span', span),
                                          ('version', version)):
      output_example.set_int_custom_property(property_name, property_value)
    output_example.uri = 'my_examples_uri'

    node_contexts = context_lib.register_contexts_if_not_exists(
        m, example_gen.contexts)
    registered = execution_publish_utils.register_execution(
        m, example_gen.node_info.type, node_contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, registered.id, node_contexts,
        {'output_examples': [output_example]})
def testPublishSuccessExecutionFailChangedUriDir(self):
  """Publishing fails when a lone executor artifact moves to a new URI."""
  system_example = standard_artifacts.Examples()
  system_example.uri = '/my/original_uri'
  output_dict = {'examples': [system_example]}
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    registered = execution_publish_utils.register_execution(
        m, self._execution_type, contexts)

    # The executor reports a single artifact whose URI is not the system
    # generated one.
    executor_output = execution_result_pb2.ExecutorOutput()
    moved_example = executor_output.output_artifacts['examples'].artifacts.add()
    moved_example.uri = '/my/new_uri/1'

    expected_error = (
        'When there is one artifact to publish, the URI of it should be '
        'identical to the URI of system generated artifact.')
    with self.assertRaisesRegex(RuntimeError, expected_error):
      execution_publish_utils.publish_succeeded_execution(
          m, registered.id, contexts, output_dict, executor_output)
def fake_component_output_with_handle(mlmd_handle,
                                      component,
                                      execution=None,
                                      active=False,
                                      exec_properties=None):
  """Writes fake component output and execution to MLMD.

  Only the first entry yielded by iterating `component.outputs.outputs`
  receives an artifact, with a random UUID string as its uri. If the
  component has no outputs, `next` raises StopIteration.

  Args:
    mlmd_handle: MLMD handle used for all reads and writes.
    component: Pipeline node whose outputs, contexts and `node_info.type`
      are used.
    execution: An already-registered execution to reuse. When None, a new
      execution is registered with `exec_properties`.
    active: If True, the execution is left unpublished and no artifact is
      published.
    exec_properties: Execution properties passed to `register_execution`
      when a new execution is registered.

  Returns:
    The execution that was passed in or registered here, so callers can
    refer to it later. It is not re-read from MLMD after publishing. Callers
    that ignore the return value are unaffected.
  """
  output_key, output_value = next(iter(component.outputs.outputs.items()))
  output = types.Artifact(output_value.artifact_spec.type)
  output.uri = str(uuid.uuid4())
  contexts = context_lib.prepare_contexts(mlmd_handle, component.contexts)
  # Compare against None explicitly: only a missing execution should cause a
  # new one to be registered.
  if execution is None:
    execution = execution_publish_utils.register_execution(
        mlmd_handle,
        component.node_info.type,
        contexts,
        exec_properties=exec_properties)
  if not active:
    execution_publish_utils.publish_succeeded_execution(
        mlmd_handle, execution.id, contexts, {output_key: [output]})
  return execution
def testPublishSuccessExecutionFailTooManyLayerOfSubDir(self):
  """Publishing fails when the executor artifact is nested two levels deep."""
  system_example = standard_artifacts.Examples()
  system_example.uri = '/my/original_uri'
  output_dict = {'examples': [system_example]}
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    registered = execution_publish_utils.register_execution(
        m, self._execution_type, contexts)

    # '/1/1' is a grand-child of the system URI, not a direct sub-dir.
    executor_output = execution_result_pb2.ExecutorOutput()
    nested_example = executor_output.output_artifacts[
        'examples'].artifacts.add()
    nested_example.uri = '/my/original_uri/1/1'

    expected_error = (
        'The URI of executor generated artifacts should either be identical '
        'to the URI of system generated artifact or be a direct sub-dir of '
        'it.')
    with self.assertRaisesRegex(RuntimeError, expected_error):
      execution_publish_utils.publish_succeeded_execution(
          m, registered.id, contexts, output_dict, executor_output)
def fake_execute_node(mlmd_connection, task, artifact_custom_properties=None):
  """Simulates node execution given ExecNodeTask.

  If the task's node declares outputs, one artifact is created for the first
  entry yielded by iterating `node.outputs.outputs`, with a random UUID uri
  and the given custom properties. The task's execution is then published as
  succeeded. Nodes without outputs are published with `None` as the outputs.

  Args:
    mlmd_connection: MLMD connection; entered as a context manager here.
    task: The ExecNodeTask whose node, execution id and contexts are used.
    artifact_custom_properties: Optional mapping of custom property name to
      value. `int` values (which, via isinstance, also includes `bool`) use
      the int setter and `str` values use the string setter.

  Raises:
    ValueError: If a custom property value is neither `int` nor `str`.
  """
  node = task.get_pipeline_node()
  with mlmd_connection as m:
    output_artifacts = None
    if node.HasField('outputs'):
      output_key, output_value = next(iter(node.outputs.outputs.items()))
      artifact = types.Artifact(output_value.artifact_spec.type)
      for prop_name, prop_value in (artifact_custom_properties or {}).items():
        if isinstance(prop_value, int):
          artifact.set_int_custom_property(prop_name, prop_value)
        elif isinstance(prop_value, str):
          artifact.set_string_custom_property(prop_name, prop_value)
        else:
          raise ValueError(f'unsupported type: {type(prop_value)}')
      artifact.uri = str(uuid.uuid4())
      output_artifacts = {output_key: [artifact]}
    execution_publish_utils.publish_succeeded_execution(
        m, task.execution_id, task.contexts, output_artifacts)
def _publish_execution_results(mlmd_handle: metadata.Metadata,
                               task: task_lib.ExecNodeTask,
                               result: ts.TaskSchedulerResult) -> None:
  """Publishes execution results to MLMD.

  If the scheduler status, or the execution result embedded in the executor
  output, is not OK, the execution is marked CANCELED (for CANCELLED) or
  FAILED. Otherwise the execution is published as succeeded, using
  `result.output_artifacts` when it is set. If it is not set,
  `task.output_artifacts` is used, plus `result.executor_output` when present.
  """

  def _mark_unsuccessful(status: status_lib.Status) -> None:
    # Records a non-OK `status` on the task's execution in MLMD.
    assert status.code != status_lib.Code.OK
    execution_id = task.execution.id
    if status.code == status_lib.Code.CANCELLED:
      logging.info('Cancelling execution (id: %s); task id: %s; status: %s',
                   execution_id, task.task_id, status)
      new_state = metadata_store_pb2.Execution.CANCELED
    else:
      logging.info(
          'Aborting execution (id: %s) due to error (code: %s); task id: %s',
          execution_id, status.code, task.task_id)
      new_state = metadata_store_pb2.Execution.FAILED
    _update_execution_state_in_mlmd(mlmd_handle, task.execution, new_state,
                                    status.message)

  if result.status.code != status_lib.Code.OK:
    _mark_unsuccessful(result.status)
    return

  if result.output_artifacts is not None:
    # Scheduler-provided artifacts take precedence over everything else.
    publish_kwargs = {'output_artifacts': result.output_artifacts}
  else:
    publish_kwargs = {'output_artifacts': task.output_artifacts}
    executor_output = result.executor_output
    if executor_output is not None:
      execution_result = executor_output.execution_result
      if execution_result.code != status_lib.Code.OK:
        _mark_unsuccessful(
            status_lib.Status(
                code=execution_result.code,
                message=execution_result.result_message))
        return
      publish_kwargs['executor_output'] = executor_output

  execution_publish_utils.publish_succeeded_execution(
      mlmd_handle, task.execution.id, task.contexts, **publish_kwargs)
def testResolveInputArtifactsOutputKeyUnset(self):
  """Resolves Pusher inputs from the 'output_key_unset' test pipeline."""
  pipeline = pipeline_pb2.Pipeline()
  pbtxt_path = os.path.join(
      self._testdata_dir,
      'pipeline_for_input_resolver_test_output_key_unset.pbtxt')
  self.load_proto_from_text(pbtxt_path, pipeline)
  my_trainer = pipeline.nodes[0].pipeline_node
  my_pusher = pipeline.nodes[1].pipeline_node

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  with metadata.Metadata(connection_config=connection_config) as m:
    # Publish a Trainer execution with a single `model` output; the Pusher
    # consumes it in a different run.
    model_type = my_trainer.outputs.outputs['model'].artifact_spec.type
    output_model = types.Artifact(model_type)
    output_model.uri = 'my_output_model_uri'
    trainer_contexts = context_lib.register_contexts_if_not_exists(
        m, my_trainer.contexts)
    trainer_execution = execution_publish_utils.register_execution(
        m, my_trainer.node_info.type, trainer_contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, trainer_execution.id, trainer_contexts, {'model': [output_model]})

    # The Pusher should resolve exactly the model published above.
    pusher_inputs = inputs_utils.resolve_input_artifacts(m, my_pusher.inputs)
    self.assertEqual(len(pusher_inputs), 1)
    self.assertEqual(len(pusher_inputs['model']), 1)
    self.assertProtoPartiallyEquals(
        output_model.mlmd_artifact,
        pusher_inputs['model'][0].mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def testPublishSuccessExecutionFailInvalidUri(self, invalid_uri):
  """Publishing fails when one of several executor artifacts has a bad URI."""
  system_example = standard_artifacts.Examples()
  system_example.uri = '/my/original_uri'
  output_dict = {'examples': [system_example]}
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    registered = execution_publish_utils.register_execution(
        m, self._execution_type, contexts)

    # One valid sub-directory artifact followed by the invalid candidate.
    executor_output = execution_result_pb2.ExecutorOutput()
    executor_artifacts = executor_output.output_artifacts[
        'examples'].artifacts
    for artifact_uri in ('/my/original_uri/0', invalid_uri):
      executor_artifacts.add().uri = artifact_uri

    expected_error = (
        'When there are multiple artifacts to publish, their URIs should be '
        'direct sub-directories of the URI of the system generated artifact.')
    with self.assertRaisesRegex(RuntimeError, expected_error):
      execution_publish_utils.publish_succeeded_execution(
          m, registered.id, contexts, output_dict, executor_output)
def _publish_execution_results(mlmd_handle: metadata.Metadata,
                               task: task_lib.ExecNodeTask,
                               result: ts.TaskSchedulerResult) -> None:
  """Publishes execution results to MLMD.

  A non-OK scheduler status, or a non-OK execution result inside a (truthy)
  executor output, marks the execution CANCELED (for CANCELLED) or FAILED.
  Otherwise the execution is published as succeeded with the task's output
  artifacts and the result's executor output.
  """

  def _mark_unsuccessful(status: status_lib.Status) -> None:
    # Logs the error and records the terminal state on the execution.
    assert status.code != status_lib.Code.OK
    was_cancelled = status.code == status_lib.Code.CANCELLED
    execution_state = (
        metadata_store_pb2.Execution.CANCELED
        if was_cancelled else metadata_store_pb2.Execution.FAILED)
    state_msg = 'cancelled' if was_cancelled else 'failed'
    logging.info(
        'Got error (status: %s) for task id: %s; marking execution (id: %s) '
        'as %s.', status, task.task_id, task.execution.id, state_msg)
    # TODO(goutham): Also record error code and error message as custom
    # property of the execution.
    _update_execution_state_in_mlmd(mlmd_handle, task.execution,
                                    execution_state)

  if result.status.code != status_lib.Code.OK:
    _mark_unsuccessful(result.status)
    return

  executor_output = result.executor_output
  if (executor_output and
      executor_output.execution_result.code != status_lib.Code.OK):
    execution_result = executor_output.execution_result
    _mark_unsuccessful(
        status_lib.Status(
            code=execution_result.code,
            message=execution_result.result_message))
    return

  execution_publish_utils.publish_succeeded_execution(
      mlmd_handle, task.execution.id, task.contexts, task.output_artifacts,
      executor_output)
def _publish_execution_results(mlmd_handle: metadata.Metadata,
                               task: task_lib.ExecNodeTask,
                               result: ts.TaskSchedulerResult) -> None:
  """Publishes execution results to MLMD.

  On a non-OK scheduler status, or a non-OK execution result inside an
  executor output, output and task dirs are removed and the execution is
  marked CANCELED (for CANCELLED) or FAILED. Otherwise the result is
  dispatched on `result.output`'s type:
    * ExecutorNodeOutput: published as succeeded with the task's output
      artifacts and the executor output.
    * ImporterNodeOutput: published as succeeded with the importer's output
      artifacts.
    * ResolverNodeOutput: published as an internal execution whose outputs
      are the resolved input artifacts.
  `pipeline_state.record_state_change_time()` is called on every path that
  reaches MLMD.

  Raises:
    TypeError: If `result.output` is none of the handled output types.
  """

  def _update_state(status: status_lib.Status) -> None:
    # Cleans up the task's dirs and records a non-OK terminal state.
    assert status.code != status_lib.Code.OK
    _remove_output_dirs(task, result)
    _remove_task_dirs(task)
    if status.code == status_lib.Code.CANCELLED:
      logging.info(
          'Cancelling execution (id: %s); task id: %s; status: %s',
          task.execution_id, task.task_id, status)
      execution_state = metadata_store_pb2.Execution.CANCELED
    else:
      logging.info(
          'Aborting execution (id: %s) due to error (code: %s); task id: %s',
          task.execution_id, status.code, task.task_id)
      execution_state = metadata_store_pb2.Execution.FAILED
    _update_execution_state_in_mlmd(mlmd_handle, task.execution_id,
                                    execution_state, status.message)
    pipeline_state.record_state_change_time()

  if result.status.code != status_lib.Code.OK:
    _update_state(result.status)
    return

  # TODO(b/182316162): Unify publisher handling so that post-execution artifact
  # logic is more cleanly handled.
  outputs_utils.tag_output_artifacts_with_version(task.output_artifacts)
  if isinstance(result.output, ts.ExecutorNodeOutput):
    executor_output = result.output.executor_output
    if executor_output is not None:
      # The executor itself may report failure even when scheduling was OK.
      if executor_output.execution_result.code != status_lib.Code.OK:
        _update_state(
            status_lib.Status(
                code=executor_output.execution_result.code,
                message=executor_output.execution_result.result_message
            ))
        return
      # TODO(b/182316162): Unify publisher handling so that post-execution
      # artifact logic is more cleanly handled.
      outputs_utils.tag_executor_output_with_version(executor_output)
    # NOTE(review): nesting reconstructed from collapsed source; this cleanup
    # is assumed to run whether or not an executor output exists — confirm.
    _remove_task_dirs(task)
    execution_publish_utils.publish_succeeded_execution(
        mlmd_handle,
        execution_id=task.execution_id,
        contexts=task.contexts,
        output_artifacts=task.output_artifacts,
        executor_output=executor_output)
  elif isinstance(result.output, ts.ImporterNodeOutput):
    output_artifacts = result.output.output_artifacts
    # TODO(b/182316162): Unify publisher handling so that post-execution artifact
    # logic is more cleanly handled.
    outputs_utils.tag_output_artifacts_with_version(output_artifacts)
    _remove_task_dirs(task)
    execution_publish_utils.publish_succeeded_execution(
        mlmd_handle,
        execution_id=task.execution_id,
        contexts=task.contexts,
        output_artifacts=output_artifacts)
  elif isinstance(result.output, ts.ResolverNodeOutput):
    # The resolver's resolved inputs are recorded as the execution's outputs.
    resolved_input_artifacts = result.output.resolved_input_artifacts
    execution_publish_utils.publish_internal_execution(
        mlmd_handle,
        execution_id=task.execution_id,
        contexts=task.contexts,
        output_artifacts=resolved_input_artifacts)
  else:
    raise TypeError(f'Unable to process task scheduler result: {result}')
  pipeline_state.record_state_change_time()
def testSuccess(self):
  """Resolver output reaches a downstream node as only `my_model_uri_2`."""
  with self._mlmd_connection as m:
    # Publish two Trainer models for the downstream resolver to consume.
    model_type = self._my_trainer.outputs.outputs['model'].artifact_spec.type
    trainer_models = []
    for model_uri in ('my_model_uri_1', 'my_model_uri_2'):
      model = types.Artifact(model_type)
      model.uri = model_uri
      trainer_models.append(model)
    trainer_contexts = context_lib.prepare_contexts(
        m, self._my_trainer.contexts)
    trainer_execution = execution_publish_utils.register_execution(
        m, self._my_trainer.node_info.type, trainer_contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, trainer_execution.id, trainer_contexts, {'model': trainer_models})

  resolver_handler = resolver_node_handler.ResolverNodeHandler()
  resolver_execution = resolver_handler.run(
      mlmd_connection=self._mlmd_connection,
      pipeline_node=self._resolver_node,
      pipeline_info=self._pipeline_info,
      pipeline_runtime_spec=self._pipeline_runtime_spec)

  with self._mlmd_connection as m:
    # The resolver's output artifacts cannot be read directly, so build a
    # fake downstream node that listens to the resolver and check what it
    # resolves as input.
    down_stream_node = text_format.Parse(
        """ inputs { inputs { key: "input_models" value { channels { producer_node_query { id: "my_resolver" } context_queries { type { name: "pipeline" } name { field_value { string_value: "my_pipeline" } } } context_queries { type { name: "component" } name { field_value { string_value: "my_resolver" } } } artifact_query { type { name: "Model" } } output_key: "models" } min_count: 1 } } } upstream_nodes: "my_resolver" """,
        pipeline_pb2.PipelineNode())
    downstream_input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=m, node_inputs=down_stream_node.inputs)
    downstream_input_model = downstream_input_artifacts['input_models']
    self.assertLen(downstream_input_model, 1)
    self.assertProtoPartiallyEquals(
        """ id: 2 type_id: 5 uri: "my_model_uri_2" state: LIVE""",
        downstream_input_model[0].mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

    # The resolver's own execution must be recorded as COMPLETE.
    [execution] = m.store.get_executions_by_id([resolver_execution.id])
    self.assertProtoPartiallyEquals(
        """ id: 2 type_id: 6 last_known_state: COMPLETE """,
        execution,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def testResolverInputsArtifacts(self):
  """Resolves Transform and Trainer inputs against published ExampleGens."""
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(self._testdata_dir,
                   'pipeline_for_input_resolver_test.pbtxt'), pipeline)
  my_example_gen = pipeline.nodes[0].pipeline_node
  another_example_gen = pipeline.nodes[1].pipeline_node
  my_transform = pipeline.nodes[2].pipeline_node
  my_trainer = pipeline.nodes[3].pipeline_node

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  with metadata.Metadata(connection_config=connection_config) as m:

    def _publish_succeeded(node, outputs):
      # Registers a new execution for `node` and publishes `outputs` from it.
      node_contexts = context_lib.register_contexts_if_not_exists(
          m, node.contexts)
      node_execution = execution_publish_utils.register_execution(
          m, node.node_info.type, node_contexts)
      execution_publish_utils.publish_succeeded_execution(
          m, node_execution.id, node_contexts, outputs)

    # The first ExampleGen publishes two channels; `output_examples` is what
    # the downstream Transform consumes.
    example_gen_outputs = my_example_gen.outputs.outputs
    output_example = types.Artifact(
        example_gen_outputs['output_examples'].artifact_spec.type)
    output_example.uri = 'my_examples_uri'
    side_examples = types.Artifact(
        example_gen_outputs['side_examples'].artifact_spec.type)
    side_examples.uri = 'side_examples_uri'
    # NOTE(review): this artifact's type comes from the 'side_examples' spec
    # but it is published under the 'another_examples' key — confirm intended.
    _publish_succeeded(my_example_gen, {
        'output_examples': [output_example],
        'another_examples': [side_examples]
    })

    # The second ExampleGen reuses the `output_examples` key, but no
    # downstream node consumes it.
    another_output_example = types.Artifact(
        another_example_gen.outputs.outputs['output_examples'].artifact_spec
        .type)
    another_output_example.uri = 'another_examples_uri'
    _publish_succeeded(another_example_gen,
                       {'output_examples': [another_output_example]})

    # Transform should get back only what the first ExampleGen published in
    # its `output_examples` channel.
    transform_inputs = inputs_utils.resolve_input_artifacts(
        m, my_transform.inputs)
    self.assertEqual(len(transform_inputs), 1)
    self.assertEqual(len(transform_inputs['examples']), 1)
    self.assertProtoPartiallyEquals(
        transform_inputs['examples'][0].mlmd_artifact,
        output_example.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

    # Trainer requires min_count on both channels (ExampleGen and Transform).
    # Nothing was published from Transform, so resolution returns None.
    self.assertIsNone(
        inputs_utils.resolve_input_artifacts(m, my_trainer.inputs))
def run(
    self, mlmd_connection: metadata.Metadata,
    pipeline_node: pipeline_pb2.PipelineNode,
    pipeline_info: pipeline_pb2.PipelineInfo,
    pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
) -> metadata_store_pb2.Execution:
  """Runs Importer specific logic.

  Args:
    mlmd_connection: ML metadata connection.
    pipeline_node: The specification of the node that this launcher launches.
    pipeline_info: The information of the pipeline that this node runs in.
    pipeline_runtime_spec: The runtime information of the pipeline that this
      node runs in.

  Returns:
    The execution of the run.
  """
  logging.info('Running as an importer node.')
  with mlmd_connection as m:
    # 1. Prepare all contexts.
    contexts = context_lib.register_contexts_if_not_exists(
        metadata_handler=m, node_contexts=pipeline_node.contexts)

    # 2. Resolve execution properties; importers have no inputs.
    exec_properties = inputs_utils.resolve_parameters(
        node_parameters=pipeline_node.parameters)

    # 3. Register the execution in MLMD.
    execution = execution_publish_utils.register_execution(
        metadata_handler=m,
        execution_type=pipeline_node.node_info.type,
        contexts=contexts,
        exec_properties=exec_properties)

    # 4. Generate output artifacts that represent the imported artifacts.
    artifact_spec = pipeline_node.outputs.outputs[
        importer_node.IMPORT_RESULT_KEY].artifact_spec
    properties = self._extract_proto_map(artifact_spec.additional_properties)
    custom_properties = self._extract_proto_map(
        artifact_spec.additional_custom_properties)
    output_artifact_class = types.Artifact(artifact_spec.type).type
    output_artifacts = importer_node.generate_output_dict(
        metadata_handler=m,
        uri=str(exec_properties[importer_node.SOURCE_URI_KEY]),
        properties=properties,
        custom_properties=custom_properties,
        reimport=bool(exec_properties[importer_node.REIMPORT_OPTION_KEY]),
        output_artifact_class=output_artifact_class,
        mlmd_artifact_type=artifact_spec.type)

    # 5. Publish the output artifacts.
    execution_publish_utils.publish_succeeded_execution(
        metadata_handler=m,
        execution_id=execution.id,
        contexts=contexts,
        output_artifacts=output_artifacts)
  return execution
def testPublishSuccessfulExecution(self):
  """Publishes one executor artifact and checks execution, artifact, edges."""
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    registered = execution_publish_utils.register_execution(
        m, self._execution_type, contexts)
    output_key = 'examples'
    output_example = standard_artifacts.Examples()

    # The executor reports one artifact carrying a custom property.
    executor_output = execution_result_pb2.ExecutorOutput()
    reported_artifact = executor_output.output_artifacts[
        output_key].artifacts.add()
    text_format.Parse(
        """ uri: 'examples_uri' custom_properties { key: 'prop' value {int_value: 1} } """,
        reported_artifact)
    execution_publish_utils.publish_succeeded_execution(
        m, registered.id, contexts, {output_key: [output_example]},
        executor_output)

    [execution] = m.store.get_executions()
    self.assertProtoPartiallyEquals(
        """ id: 1 type_id: 3 last_known_state: COMPLETE """,
        execution,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

    [artifact] = m.store.get_artifacts()
    self.assertProtoPartiallyEquals(
        """ id: 1 type_id: 4 state: LIVE uri: 'examples_uri' custom_properties { key: 'prop' value {int_value: 1} }""",
        artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

    [event] = m.store.get_events_by_execution_ids([execution.id])
    self.assertProtoPartiallyEquals(
        """ artifact_id: 1 execution_id: 1 path { steps { key: 'examples' } steps { index: 0 } } type: OUTPUT """,
        event,
        ignored_fields=['milliseconds_since_epoch'])

    # Both the execution and the published artifact must be attributed to
    # every context.
    expected_context_ids = [c.id for c in contexts]
    self.assertCountEqual(
        expected_context_ids,
        [c.id for c in m.store.get_contexts_by_execution(execution.id)])
    self.assertCountEqual(
        expected_context_ids,
        [c.id for c in m.store.get_contexts_by_artifact(output_example.id)])
def testPublishSuccessExecutionExecutorEditedOutputDict(self):
  """Executor output with two artifacts replaces a one-artifact output dict.

  The system-provided output dict has one artifact while the executor output
  has two sub-directory artifacts; both executor artifacts are expected to
  be published.
  """
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    execution_id = execution_publish_utils.register_execution(
        m, self._execution_type, contexts).id
    system_example = standard_artifacts.Examples()
    system_example.uri = '/original_path'

    executor_output = execution_result_pb2.ExecutorOutput()
    output_key = 'examples'
    for artifact_text in (
        """ uri: '/original_path/subdir_1' custom_properties { key: 'prop' value {int_value: 1} } """,
        """ uri: '/original_path/subdir_2' custom_properties { key: 'prop' value {int_value: 2} } """,
    ):
      text_format.Parse(
          artifact_text,
          executor_output.output_artifacts[output_key].artifacts.add())

    output_dict = execution_publish_utils.publish_succeeded_execution(
        m, execution_id, contexts, {output_key: [system_example]},
        executor_output)

    [execution] = m.store.get_executions()
    self.assertProtoPartiallyEquals(
        """ id: 1 type_id: 3 last_known_state: COMPLETE """,
        execution,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])

    artifacts = m.store.get_artifacts()
    self.assertLen(artifacts, 2)
    expected_artifacts = (
        """ id: 1 type_id: 4 state: LIVE uri: '/original_path/subdir_1' custom_properties { key: 'prop' value {int_value: 1} }""",
        """ id: 2 type_id: 4 state: LIVE uri: '/original_path/subdir_2' custom_properties { key: 'prop' value {int_value: 2} }""",
    )
    for expected_text, actual_artifact in zip(expected_artifacts, artifacts):
      self.assertProtoPartiallyEquals(
          expected_text,
          actual_artifact,
          ignored_fields=[
              'create_time_since_epoch', 'last_update_time_since_epoch'
          ])

    events = m.store.get_events_by_execution_ids([execution.id])
    self.assertLen(events, 2)
    expected_events = (
        """ artifact_id: 1 execution_id: 1 path { steps { key: 'examples' } steps { index: 0 } } type: OUTPUT """,
        """ artifact_id: 2 execution_id: 1 path { steps { key: 'examples' } steps { index: 1 } } type: OUTPUT """,
    )
    for expected_text, actual_event in zip(expected_events, events):
      self.assertProtoPartiallyEquals(
          expected_text,
          actual_event,
          ignored_fields=['milliseconds_since_epoch'])

    # The execution and every published artifact must be attributed to all
    # contexts.
    self.assertCountEqual(
        [c.id for c in contexts],
        [c.id for c in m.store.get_contexts_by_execution(execution.id)])
    for published_list in output_dict.values():
      for published in published_list:
        self.assertCountEqual(
            [c.id for c in contexts],
            [c.id for c in m.store.get_contexts_by_artifact(published.id)])
def test_resolver_task_scheduler(self):
  """Resolver task resolves `my_model_uri_2` and feeds it to the consumer."""
  with self._mlmd_connection as m:
    # Publish two Trainer models for the downstream resolver to consume.
    model_type = self._trainer.outputs.outputs['model'].artifact_spec.type
    trainer_models = []
    for model_uri in ('my_model_uri_1', 'my_model_uri_2'):
      model = types.Artifact(model_type)
      model.uri = model_uri
      trainer_models.append(model)
    trainer_contexts = context_lib.prepare_contexts(m, self._trainer.contexts)
    trainer_execution = execution_publish_utils.register_execution(
        m, self._trainer.node_info.type, trainer_contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, trainer_execution.id, trainer_contexts, {'model': trainer_models})

  task_queue = tq.TaskQueue()
  # Arguments shared by both generator passes below.
  common_generator_args = dict(
      test_case=self,
      mlmd_connection=self._mlmd_connection,
      generator_class=sptg.SyncPipelineTaskGenerator,
      pipeline=self._pipeline,
      task_queue=task_queue,
      use_task_queue=False,
      service_job_manager=None,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      ignore_update_node_state_tasks=True)

  # The first pass should generate exactly the resolver task.
  [resolver_task] = test_utils.run_generator_and_test(
      num_initial_executions=1,
      expected_exec_nodes=[self._resolver_node],
      **common_generator_args)

  with self._mlmd_connection as m:
    # Run the resolver task scheduler, check its result, then publish it.
    scheduler = resolver_task_scheduler.ResolverTaskScheduler(
        mlmd_handle=m, pipeline=self._pipeline, task=resolver_task)
    ts_result = scheduler.schedule()
    self.assertEqual(status_lib.Code.OK, ts_result.status.code)
    self.assertIsInstance(ts_result.output, task_scheduler.ResolverNodeOutput)
    resolved = ts_result.output.resolved_input_artifacts
    self.assertCountEqual(['resolved_model'], resolved.keys())
    models = resolved['resolved_model']
    self.assertLen(models, 1)
    self.assertEqual('my_model_uri_2', models[0].mlmd_artifact.uri)
    tm._publish_execution_results(m, resolver_task, ts_result)

  # The resolver's output should now be the consumer node's input.
  [consumer_task] = test_utils.run_generator_and_test(
      num_initial_executions=2,
      expected_exec_nodes=[self._consumer_node],
      **common_generator_args)
  self.assertCountEqual(['resolved_model'],
                        consumer_task.input_artifacts.keys())
  input_models = consumer_task.input_artifacts['resolved_model']
  self.assertLen(input_models, 1)
  self.assertEqual('my_model_uri_2', input_models[0].mlmd_artifact.uri)
def testGetCachedOutputArtifacts(self):
  """Cached outputs come from the most recent execution under the context."""

  def _artifact(artifact_cls, uri):
    # Builds an artifact of `artifact_cls` with the given uri.
    created = artifact_cls()
    created.uri = uri
    return created

  # Outputs of the first execution under the shared cache key.
  output_model_one = _artifact(standard_artifacts.Model, 'model_one')
  output_model_two = _artifact(standard_artifacts.Model, 'model_two')
  output_example_one = _artifact(standard_artifacts.Examples, 'example_one')
  # Outputs of the second execution under the same cache key.
  output_model_three = _artifact(standard_artifacts.Model, 'model_three')
  output_model_four = _artifact(standard_artifacts.Model, 'model_four')
  output_example_two = _artifact(standard_artifacts.Examples, 'example_two')

  output_models_key = 'output_models'
  output_examples_key = 'output_examples'
  with metadata.Metadata(connection_config=self._connection_config) as m:
    cache_context = context_lib.register_context_if_not_exists(
        m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
    per_execution_outputs = (
        ([output_model_one, output_model_two], [output_example_one]),
        ([output_model_three, output_model_four], [output_example_two]),
    )
    for models, examples in per_execution_outputs:
      cached_execution = execution_publish_utils.register_execution(
          m, metadata_store_pb2.ExecutionType(name='my_type'),
          [cache_context])
      execution_publish_utils.publish_succeeded_execution(
          m,
          cached_execution.id, [cache_context],
          output_artifacts={
              output_models_key: models,
              output_examples_key: examples
          })

    # The cached output should be the artifacts produced by the most recent
    # execution under the given cache context.
    cached_output = cache_utils.get_cached_outputs(m, cache_context)
    self.assertLen(cached_output, 2)
    self.assertLen(cached_output[output_models_key], 2)
    self.assertLen(cached_output[output_examples_key], 1)
    expected_pairs = (
        (cached_output[output_models_key][0], output_model_three),
        (cached_output[output_models_key][1], output_model_four),
        (cached_output[output_examples_key][0], output_example_two),
    )
    for cached_artifact, expected_artifact in expected_pairs:
      self.assertProtoPartiallyEquals(
          cached_artifact.mlmd_artifact,
          expected_artifact.mlmd_artifact,
          ignored_fields=[
              'create_time_since_epoch', 'last_update_time_since_epoch'
          ])
def run(
    self, mlmd_connection: metadata.Metadata,
    pipeline_node: pipeline_pb2.PipelineNode,
    pipeline_info: pipeline_pb2.PipelineInfo,
    pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
) -> data_types.ExecutionInfo:
  """Runs Importer specific logic.

  Args:
    mlmd_connection: ML metadata connection.
    pipeline_node: The specification of the node that this launcher launches.
    pipeline_info: The information of the pipeline that this node runs in.
    pipeline_runtime_spec: The runtime information of the pipeline that this
      node runs in.

  Returns:
    An ExecutionInfo with the execution id, an empty input dict, the
    generated output dict, the resolved exec properties, and the node and
    pipeline info.
  """
  logging.info('Running as an importer node.')
  with mlmd_connection as m:
    # 1. Prepare all contexts.
    contexts = context_lib.prepare_contexts(
        metadata_handler=m, node_contexts=pipeline_node.contexts)

    # 2. Resolve execution properties; importers have no inputs.
    resolved_parameters = inputs_utils.resolve_parameters_with_schema(
        node_parameters=pipeline_node.parameters)
    exec_properties = data_types_utils.build_parsed_value_dict(
        resolved_parameters)

    # 3. Register the execution in MLMD.
    execution = execution_publish_utils.register_execution(
        metadata_handler=m,
        execution_type=pipeline_node.node_info.type,
        contexts=contexts,
        exec_properties=exec_properties)

    # 4. Generate output artifacts that represent the imported artifacts.
    artifact_spec = pipeline_node.outputs.outputs[
        importer.IMPORT_RESULT_KEY].artifact_spec
    properties = self._extract_proto_map(artifact_spec.additional_properties)
    custom_properties = self._extract_proto_map(
        artifact_spec.additional_custom_properties)
    output_artifact_class = types.Artifact(artifact_spec.type).type
    output_artifacts = importer.generate_output_dict(
        metadata_handler=m,
        uri=str(exec_properties[importer.SOURCE_URI_KEY]),
        properties=properties,
        custom_properties=custom_properties,
        reimport=bool(exec_properties[importer.REIMPORT_OPTION_KEY]),
        output_artifact_class=output_artifact_class,
        mlmd_artifact_type=artifact_spec.type)

    result = data_types.ExecutionInfo(
        execution_id=execution.id,
        input_dict={},
        output_dict=output_artifacts,
        exec_properties=exec_properties,
        pipeline_node=pipeline_node,
        pipeline_info=pipeline_info)

    # TODO(b/182316162): consider letting the launcher level do the publish
    # for system nodes, so that the version tagging logic doesn't need to be
    # handled per system node.
    outputs_utils.tag_output_artifacts_with_version(result.output_dict)

    # 5. Publish the output artifacts. If artifacts are reimported, the
    # execution is published as CACHED; otherwise it is published as
    # COMPLETE.
    if _is_artifact_reimported(output_artifacts):
      publish = execution_publish_utils.publish_cached_execution
    else:
      publish = execution_publish_utils.publish_succeeded_execution
    publish(
        metadata_handler=m,
        execution_id=execution.id,
        contexts=contexts,
        output_artifacts=output_artifacts)
  return result