Example #1
0
 def fake_execute(self, metadata_handler, pipeline_node, input_map,
                  output_map):
     """Registers an execution for `pipeline_node` and publishes it as succeeded.

     Args:
       metadata_handler: MLMD handle used for every read and write.
       pipeline_node: Node whose contexts and node type are recorded.
       input_map: Input artifacts attached when the execution is registered.
       output_map: Output artifacts published with the succeeded execution.

     Returns:
       Whatever `execution_publish_utils.publish_succeeded_execution` returns.
     """
     node_contexts = context_lib.prepare_contexts(
         metadata_handler, pipeline_node.contexts)
     registered = execution_publish_utils.register_execution(
         metadata_handler, pipeline_node.node_info.type, node_contexts,
         input_map)
     published = execution_publish_utils.publish_succeeded_execution(
         metadata_handler, registered.id, node_contexts, output_map)
     return published
Example #2
0
 def testGetCachedOutputArtifactsForNodesWithNoOuput(self):
     """Cached outputs are None before any success, then empty afterwards."""
     with metadata.Metadata(connection_config=self._connection_config) as m:
         key_context = context_lib.register_context_if_not_exists(
             m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')

         # No successful execution is associated with this context yet, so
         # there is no cached output at all.
         outputs_before_publish = cache_utils.get_cached_outputs(
             m, key_context)
         self.assertIsNone(outputs_before_publish)

         first_execution = execution_publish_utils.register_execution(
             m, metadata_store_pb2.ExecutionType(name='my_type'),
             [key_context])
         execution_publish_utils.publish_succeeded_execution(
             m, first_execution.id, [key_context])

         # A successful execution (published without outputs) is now
         # associated with the context, so the cached output exists but is
         # empty.
         outputs_after_publish = cache_utils.get_cached_outputs(
             m, key_context)
         self.assertIsNotNone(outputs_after_publish)
         self.assertEmpty(outputs_after_publish)
Example #3
0
def fake_trainer_output(mlmd_connection,
                        trainer,
                        execution=None,
                        active=False):
    """Writes fake trainer output and execution to MLMD.

    Args:
      mlmd_connection: MLMD connection, entered as a context manager here.
      trainer: Trainer pipeline node; its 'model' output spec supplies the
        artifact type.
      execution: Existing execution to reuse. When falsy, a new execution is
        registered under the trainer's contexts.
      active: When True, nothing is published, so the fake model artifact is
        never written. Otherwise the execution is published as succeeded
        with that single model.
    """
    with mlmd_connection as m:
        model_type = trainer.outputs.outputs['model'].artifact_spec.type
        fake_model = types.Artifact(model_type)
        fake_model.uri = 'my_trainer_model_uri'
        trainer_contexts = context_lib.prepare_contexts(m, trainer.contexts)
        if not execution:
            execution = execution_publish_utils.register_execution(
                m, trainer.node_info.type, trainer_contexts)
        if active:
            return
        execution_publish_utils.publish_succeeded_execution(
            m, execution.id, trainer_contexts, {'model': [fake_model]})
Example #4
0
 def fakeExampleGenOutput(mlmd_connection: metadata.Metadata,
                          example_gen: pipeline_pb2.PipelineNode, span: int,
                          version: int):
     """Publishes one fake `output_examples` artifact for `example_gen`.

     Args:
       mlmd_connection: MLMD connection, entered as a context manager here.
       example_gen: ExampleGen node providing the artifact type and contexts.
       span: Stored in the artifact's 'span' int custom property.
       version: Stored in the artifact's 'version' int custom property.
     """
     with mlmd_connection as m:
         examples_spec = example_gen.outputs.outputs['output_examples']
         examples = types.Artifact(examples_spec.artifact_spec.type)
         examples.set_int_custom_property('span', span)
         examples.set_int_custom_property('version', version)
         examples.uri = 'my_examples_uri'
         node_contexts = context_lib.register_contexts_if_not_exists(
             m, example_gen.contexts)
         new_execution = execution_publish_utils.register_execution(
             m, example_gen.node_info.type, node_contexts)
         execution_publish_utils.publish_succeeded_execution(
             m, new_execution.id, node_contexts,
             {'output_examples': [examples]})
Example #5
0
    def testPublishSuccessExecutionFailChangedUriDir(self):
        """A lone executor artifact must keep the system-generated URI."""
        system_examples = standard_artifacts.Examples()
        system_examples.uri = '/my/original_uri'
        output_dict = {'examples': [system_examples]}
        expected_error = (
            'When there is one artifact to publish, the URI of it should be '
            'identical to the URI of system generated artifact.')
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            registered = execution_publish_utils.register_execution(
                m, self._execution_type, contexts)
            execution_id = registered.id

            # The executor reports a single artifact under a different dir.
            executor_output = execution_result_pb2.ExecutorOutput()
            examples_channel = executor_output.output_artifacts['examples']
            reported_example = examples_channel.artifacts.add()
            reported_example.uri = '/my/new_uri/1'

            with self.assertRaisesRegex(RuntimeError, expected_error):
                execution_publish_utils.publish_succeeded_execution(
                    m, execution_id, contexts, output_dict, executor_output)
Example #6
0
def fake_component_output_with_handle(mlmd_handle,
                                      component,
                                      execution=None,
                                      active=False,
                                      exec_properties=None):
    """Writes fake component output and execution to MLMD.

    Only the first entry yielded when iterating the component's output map
    gets an artifact. That artifact's URI is a random UUID string.

    Args:
      mlmd_handle: MLMD handle, used directly rather than entered with `with`.
      component: Pipeline node supplying contexts, node type and outputs.
      execution: Existing execution to reuse. When falsy, a new one is
        registered with `exec_properties`.
      active: When True, publishing is skipped.
      exec_properties: Execution properties for a newly registered execution.
    """
    output_key, output_spec = next(iter(component.outputs.outputs.items()))
    fake_artifact = types.Artifact(output_spec.artifact_spec.type)
    fake_artifact.uri = str(uuid.uuid4())
    component_contexts = context_lib.prepare_contexts(
        mlmd_handle, component.contexts)
    if not execution:
        execution = execution_publish_utils.register_execution(
            mlmd_handle,
            component.node_info.type,
            component_contexts,
            exec_properties=exec_properties)
    if active:
        return
    execution_publish_utils.publish_succeeded_execution(
        mlmd_handle, execution.id, component_contexts,
        {output_key: [fake_artifact]})
Example #7
0
    def testPublishSuccessExecutionFailTooManyLayerOfSubDir(self):
        """An executor URI nested two levels below the system URI is rejected."""
        system_examples = standard_artifacts.Examples()
        system_examples.uri = '/my/original_uri'
        output_dict = {'examples': [system_examples]}
        expected_error = (
            'The URI of executor generated artifacts should either be identical '
            'to the URI of system generated artifact or be a direct sub-dir of '
            'it.')
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            registered = execution_publish_utils.register_execution(
                m, self._execution_type, contexts)
            execution_id = registered.id

            # '/1/1' is two directory levels below the original URI.
            executor_output = execution_result_pb2.ExecutorOutput()
            examples_channel = executor_output.output_artifacts['examples']
            reported_example = examples_channel.artifacts.add()
            reported_example.uri = '/my/original_uri/1/1'

            with self.assertRaisesRegex(RuntimeError, expected_error):
                execution_publish_utils.publish_succeeded_execution(
                    m, execution_id, contexts, output_dict, executor_output)
Example #8
0
def fake_execute_node(mlmd_connection, task, artifact_custom_properties=None):
    """Simulates node execution given ExecNodeTask.

    If the node declares outputs, one artifact is built for the first output
    key yielded by the output map. It gets the given custom properties and a
    random UUID URI. The task's execution is then published as succeeded,
    with no output artifacts when the node has no `outputs` field.

    Args:
      mlmd_connection: MLMD connection, entered as a context manager here.
      task: ExecNodeTask providing the node, execution id and contexts.
      artifact_custom_properties: Optional mapping of property name to value.
        int values (including bool, a subclass of int) become int custom
        properties and str values become string custom properties.

    Raises:
      ValueError: If a custom property value is neither int nor str.
    """
    node = task.get_pipeline_node()
    with mlmd_connection as m:
        output_artifacts = None
        if node.HasField('outputs'):
            output_key, output_spec = next(iter(node.outputs.outputs.items()))
            artifact = types.Artifact(output_spec.artifact_spec.type)
            extra_props = artifact_custom_properties or {}
            for prop_name, prop_value in extra_props.items():
                if isinstance(prop_value, int):
                    artifact.set_int_custom_property(prop_name, prop_value)
                elif isinstance(prop_value, str):
                    artifact.set_string_custom_property(prop_name, prop_value)
                else:
                    raise ValueError(f'unsupported type: {type(prop_value)}')
            artifact.uri = str(uuid.uuid4())
            output_artifacts = {output_key: [artifact]}
        execution_publish_utils.publish_succeeded_execution(
            m, task.execution_id, task.contexts, output_artifacts)
Example #9
0
def _publish_execution_results(mlmd_handle: metadata.Metadata,
                               task: task_lib.ExecNodeTask,
                               result: ts.TaskSchedulerResult) -> None:
  """Publishes execution results to MLMD.

  A non-OK scheduler status, or a non-OK executor `execution_result`, moves
  the execution to CANCELED (for CANCELLED codes) or FAILED, together with
  the status message. Otherwise the execution is published as succeeded.
  Output artifacts are chosen in this order:
    1. `result.output_artifacts`, when set.
    2. The task's output artifacts plus `result.executor_output`, when set.
    3. The task's output artifacts alone.
  """

  def _update_state(status: status_lib.Status) -> None:
    # Only failure statuses reach here; OK results take the publish path.
    assert status.code != status_lib.Code.OK
    if status.code != status_lib.Code.CANCELLED:
      logging.info(
          'Aborting execution (id: %s) due to error (code: %s); task id: %s',
          task.execution.id, status.code, task.task_id)
      new_state = metadata_store_pb2.Execution.FAILED
    else:
      logging.info('Cancelling execution (id: %s); task id: %s; status: %s',
                   task.execution.id, task.task_id, status)
      new_state = metadata_store_pb2.Execution.CANCELED
    _update_execution_state_in_mlmd(mlmd_handle, task.execution, new_state,
                                    status.message)

  if result.status.code != status_lib.Code.OK:
    _update_state(result.status)
    return

  publish_kwargs = {'output_artifacts': task.output_artifacts}
  if result.output_artifacts is not None:
    publish_kwargs['output_artifacts'] = result.output_artifacts
  elif result.executor_output is not None:
    exec_result = result.executor_output.execution_result
    if exec_result.code != status_lib.Code.OK:
      _update_state(
          status_lib.Status(code=exec_result.code,
                            message=exec_result.result_message))
      return
    publish_kwargs['executor_output'] = result.executor_output

  execution_publish_utils.publish_succeeded_execution(mlmd_handle,
                                                      task.execution.id,
                                                      task.contexts,
                                                      **publish_kwargs)
Example #10
0
    def testResolveInputArtifactsOutputKeyUnset(self):
        """Pusher resolves the Trainer's published model (output key unset)."""
        pipeline = pipeline_pb2.Pipeline()
        testdata_path = os.path.join(
            self._testdata_dir,
            'pipeline_for_input_resolver_test_output_key_unset.pbtxt')
        self.load_proto_from_text(testdata_path, pipeline)
        trainer_node = pipeline.nodes[0].pipeline_node
        pusher_node = pipeline.nodes[1].pipeline_node

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # Publish a Trainer execution with a single `model` output, which
            # the Pusher below is expected to consume.
            trained_model = types.Artifact(
                trainer_node.outputs.outputs['model'].artifact_spec.type)
            trained_model.uri = 'my_output_model_uri'
            trainer_contexts = context_lib.register_contexts_if_not_exists(
                m, trainer_node.contexts)
            trainer_execution = execution_publish_utils.register_execution(
                m, trainer_node.node_info.type, trainer_contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, trainer_execution.id, trainer_contexts,
                {'model': [trained_model]})

            # The Pusher should resolve exactly the model published above.
            resolved = inputs_utils.resolve_input_artifacts(
                m, pusher_node.inputs)
            self.assertEqual(len(resolved), 1)
            resolved_models = resolved['model']
            self.assertEqual(len(resolved_models), 1)
            self.assertProtoPartiallyEquals(
                trained_model.mlmd_artifact,
                resolved_models[0].mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
Example #11
0
    def testPublishSuccessExecutionFailInvalidUri(self, invalid_uri):
        """A second executor artifact at `invalid_uri` is rejected.

        The first executor artifact is a direct sub-directory of the system
        URI. The second uses the parameterized `invalid_uri`.
        """
        system_examples = standard_artifacts.Examples()
        system_examples.uri = '/my/original_uri'
        output_dict = {'examples': [system_examples]}
        expected_error = (
            'When there are multiple artifacts to publish, their URIs should be '
            'direct sub-directories of the URI of the system generated artifact.'
        )
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            registered = execution_publish_utils.register_execution(
                m, self._execution_type, contexts)
            execution_id = registered.id

            executor_output = execution_result_pb2.ExecutorOutput()
            examples_channel = executor_output.output_artifacts['examples']
            valid_artifact = examples_channel.artifacts.add()
            valid_artifact.uri = '/my/original_uri/0'
            offending_artifact = examples_channel.artifacts.add()
            offending_artifact.uri = invalid_uri

            with self.assertRaisesRegex(RuntimeError, expected_error):
                execution_publish_utils.publish_succeeded_execution(
                    m, execution_id, contexts, output_dict, executor_output)
Example #12
0
def _publish_execution_results(mlmd_handle: metadata.Metadata,
                               task: task_lib.ExecNodeTask,
                               result: ts.TaskSchedulerResult) -> None:
  """Publishes execution results to MLMD.

  A non-OK scheduler status, or a non-OK executor `execution_result`, marks
  the execution CANCELED (for CANCELLED codes) or FAILED. Only the state is
  recorded; the status message is not stored. Otherwise the execution is
  published as succeeded with the task's output artifacts and the executor
  output.
  """

  def _update_state(status: status_lib.Status) -> None:
    # Only failure statuses reach here.
    assert status.code != status_lib.Code.OK
    cancelled = status.code == status_lib.Code.CANCELLED
    execution_state = (metadata_store_pb2.Execution.CANCELED if cancelled
                       else metadata_store_pb2.Execution.FAILED)
    state_msg = 'cancelled' if cancelled else 'failed'
    logging.info(
        'Got error (status: %s) for task id: %s; marking execution (id: %s) '
        'as %s.', status, task.task_id, task.execution.id, state_msg)
    # TODO(goutham): Also record error code and error message as custom property
    # of the execution.
    _update_execution_state_in_mlmd(mlmd_handle, task.execution,
                                    execution_state)

  if result.status.code != status_lib.Code.OK:
    _update_state(result.status)
    return

  if result.executor_output:
    exec_result = result.executor_output.execution_result
    if exec_result.code != status_lib.Code.OK:
      _update_state(status_lib.Status(
          code=exec_result.code, message=exec_result.result_message))
      return

  execution_publish_utils.publish_succeeded_execution(mlmd_handle,
                                                      task.execution.id,
                                                      task.contexts,
                                                      task.output_artifacts,
                                                      result.executor_output)
Example #13
0
def _publish_execution_results(mlmd_handle: metadata.Metadata,
                               task: task_lib.ExecNodeTask,
                               result: ts.TaskSchedulerResult) -> None:
    """Publishes execution results to MLMD.

    Behavior by case:
      * Non-OK `result.status`: `_update_state` calls `_remove_output_dirs`
        and `_remove_task_dirs`, then marks the execution CANCELED (for
        CANCELLED codes) or FAILED together with the status message.
      * `ts.ExecutorNodeOutput`: a non-OK executor `execution_result` goes
        through `_update_state` the same way. Otherwise the executor output
        (if any) is version-tagged, `_remove_task_dirs` is called and the
        execution is published as succeeded.
      * `ts.ImporterNodeOutput`: the importer's output artifacts are
        version-tagged, `_remove_task_dirs` is called and the execution is
        published as succeeded.
      * `ts.ResolverNodeOutput`: the resolved artifacts are published via
        `publish_internal_execution`. No `_remove_task_dirs` call happens on
        this path.
      * Any other output type raises TypeError.

    Every path that does not raise ends with
    `pipeline_state.record_state_change_time()`.

    Args:
      mlmd_handle: MLMD handle used for the state update or publish.
      task: The node task whose execution is being finalized.
      result: Task scheduler result carrying the status and a typed output.

    Raises:
      TypeError: If `result.output` is not one of the handled output types.
    """
    def _update_state(status: status_lib.Status) -> None:
        # Terminal failure handling; only ever called with non-OK statuses.
        assert status.code != status_lib.Code.OK
        # Directory cleanup runs before the MLMD state change. The helpers
        # are not visible here; presumably they delete output and task dirs.
        _remove_output_dirs(task, result)
        _remove_task_dirs(task)
        if status.code == status_lib.Code.CANCELLED:
            logging.info(
                'Cancelling execution (id: %s); task id: %s; status: %s',
                task.execution_id, task.task_id, status)
            execution_state = metadata_store_pb2.Execution.CANCELED
        else:
            logging.info(
                'Aborting execution (id: %s) due to error (code: %s); task id: %s',
                task.execution_id, status.code, task.task_id)
            execution_state = metadata_store_pb2.Execution.FAILED
        _update_execution_state_in_mlmd(mlmd_handle, task.execution_id,
                                        execution_state, status.message)
        pipeline_state.record_state_change_time()

    if result.status.code != status_lib.Code.OK:
        _update_state(result.status)
        return

    # TODO(b/182316162): Unify publisher handing so that post-execution artifact
    # logic is more cleanly handled.
    # Runs for every output type below, including the executor-failure path.
    outputs_utils.tag_output_artifacts_with_version(task.output_artifacts)
    if isinstance(result.output, ts.ExecutorNodeOutput):
        executor_output = result.output.executor_output
        # executor_output may be None; the execution is then published with
        # executor_output=None and only the task's output artifacts.
        if executor_output is not None:
            if executor_output.execution_result.code != status_lib.Code.OK:
                _update_state(
                    status_lib.Status(
                        code=executor_output.execution_result.code,
                        message=executor_output.execution_result.result_message
                    ))
                return
            # TODO(b/182316162): Unify publisher handing so that post-execution
            # artifact logic is more cleanly handled.
            outputs_utils.tag_executor_output_with_version(executor_output)
        _remove_task_dirs(task)
        execution_publish_utils.publish_succeeded_execution(
            mlmd_handle,
            execution_id=task.execution_id,
            contexts=task.contexts,
            output_artifacts=task.output_artifacts,
            executor_output=executor_output)
    elif isinstance(result.output, ts.ImporterNodeOutput):
        # Importer outputs come from the scheduler result, not from the task.
        output_artifacts = result.output.output_artifacts
        # TODO(b/182316162): Unify publisher handing so that post-execution artifact
        # logic is more cleanly handled.
        outputs_utils.tag_output_artifacts_with_version(output_artifacts)
        _remove_task_dirs(task)
        execution_publish_utils.publish_succeeded_execution(
            mlmd_handle,
            execution_id=task.execution_id,
            contexts=task.contexts,
            output_artifacts=output_artifacts)
    elif isinstance(result.output, ts.ResolverNodeOutput):
        # Resolved input artifacts are recorded as this execution's outputs
        # via publish_internal_execution. Unlike the branches above, no
        # version tagging of these artifacts or _remove_task_dirs call here.
        resolved_input_artifacts = result.output.resolved_input_artifacts
        execution_publish_utils.publish_internal_execution(
            mlmd_handle,
            execution_id=task.execution_id,
            contexts=task.contexts,
            output_artifacts=resolved_input_artifacts)
    else:
        raise TypeError(f'Unable to process task scheduler result: {result}')

    pipeline_state.record_state_change_time()
Example #14
0
    def testSuccess(self):
        """Of two published trainer models, only `my_model_uri_2` resolves.

        The resolver execution itself must be stored as COMPLETE.
        """
        trainer_node = self._my_trainer
        with self._mlmd_connection as m:
            # Publish two models for the downstream resolver to consume.
            published_models = []
            for model_uri in ('my_model_uri_1', 'my_model_uri_2'):
                model = types.Artifact(
                    trainer_node.outputs.outputs['model'].artifact_spec.type)
                model.uri = model_uri
                published_models.append(model)

            trainer_contexts = context_lib.prepare_contexts(
                m, trainer_node.contexts)
            trainer_execution = execution_publish_utils.register_execution(
                m, trainer_node.node_info.type, trainer_contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, trainer_execution.id, trainer_contexts,
                {'model': published_models})

        resolver_execution = resolver_node_handler.ResolverNodeHandler().run(
            mlmd_connection=self._mlmd_connection,
            pipeline_node=self._resolver_node,
            pipeline_info=self._pipeline_info,
            pipeline_runtime_spec=self._pipeline_runtime_spec)

        with self._mlmd_connection as m:
            # The resolver's output cannot be read directly. Instead, build a
            # fake downstream node that listens to the resolver's `models`
            # output and check the inputs it resolves.
            consumer_node = text_format.Parse(
                """
        inputs {
          inputs {
            key: "input_models"
            value {
              channels {
                producer_node_query {
                  id: "my_resolver"
                }
                context_queries {
                  type {
                    name: "pipeline"
                  }
                  name {
                    field_value {
                      string_value: "my_pipeline"
                    }
                  }
                }
                context_queries {
                  type {
                    name: "component"
                  }
                  name {
                    field_value {
                      string_value: "my_resolver"
                    }
                  }
                }
                artifact_query {
                  type {
                    name: "Model"
                  }
                }
                output_key: "models"
              }
              min_count: 1
            }
          }
        }
        upstream_nodes: "my_resolver"
        """, pipeline_pb2.PipelineNode())
            consumer_inputs = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=consumer_node.inputs)
            resolved_models = consumer_inputs['input_models']
            self.assertLen(resolved_models, 1)
            self.assertProtoPartiallyEquals(
                """
          id: 2
          type_id: 5
          uri: "my_model_uri_2"
          state: LIVE""",
                resolved_models[0].mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])

            [resolver_run] = m.store.get_executions_by_id(
                [resolver_execution.id])
            self.assertProtoPartiallyEquals(
                """
          id: 2
          type_id: 6
          last_known_state: COMPLETE
          """,
                resolver_run,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
Example #15
0
    def testResolverInputsArtifacts(self):
        """Transform resolves only the first ExampleGen's `output_examples`.

        A second ExampleGen publishing under the same output key must not be
        picked up. The Trainer resolves to None because nothing was published
        for its Transform input.
        """
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(self._testdata_dir,
                         'pipeline_for_input_resolver_test.pbtxt'), pipeline)
        first_gen = pipeline.nodes[0].pipeline_node
        second_gen = pipeline.nodes[1].pipeline_node
        transform_node = pipeline.nodes[2].pipeline_node
        trainer_node = pipeline.nodes[3].pipeline_node

        def _publish_outputs(m, node, output_dict):
            # Registers contexts and an execution for `node`, then publishes
            # it as succeeded with `output_dict`.
            node_contexts = context_lib.register_contexts_if_not_exists(
                m, node.contexts)
            node_execution = execution_publish_utils.register_execution(
                m, node.node_info.type, node_contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, node_execution.id, node_contexts, output_dict)

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # First ExampleGen publishes two channels; `output_examples` is
            # the one the downstream Transform consumes.
            consumed_examples = types.Artifact(
                first_gen.outputs.outputs['output_examples'].artifact_spec
                .type)
            consumed_examples.uri = 'my_examples_uri'
            side_examples = types.Artifact(
                first_gen.outputs.outputs['side_examples'].artifact_spec.type)
            side_examples.uri = 'side_examples_uri'
            _publish_outputs(m, first_gen, {
                'output_examples': [consumed_examples],
                'another_examples': [side_examples]
            })

            # Second ExampleGen reuses the `output_examples` key, but no
            # downstream node consumes it.
            ignored_examples = types.Artifact(
                second_gen.outputs.outputs['output_examples'].artifact_spec
                .type)
            ignored_examples.uri = 'another_examples_uri'
            _publish_outputs(m, second_gen, {
                'output_examples': [ignored_examples],
            })

            # Transform should get back exactly the first ExampleGen's
            # `output_examples` artifact.
            transform_inputs = inputs_utils.resolve_input_artifacts(
                m, transform_node.inputs)
            self.assertEqual(len(transform_inputs), 1)
            self.assertEqual(len(transform_inputs['examples']), 1)
            self.assertProtoPartiallyEquals(
                transform_inputs['examples'][0].mlmd_artifact,
                consumed_examples.mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])

            # Trainer needs min_count on both the ExampleGen and the
            # Transform channels. Transform published nothing, so resolution
            # yields None.
            self.assertIsNone(
                inputs_utils.resolve_input_artifacts(m, trainer_node.inputs))
Example #16
0
    def run(
        self, mlmd_connection: metadata.Metadata,
        pipeline_node: pipeline_pb2.PipelineNode,
        pipeline_info: pipeline_pb2.PipelineInfo,
        pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
    ) -> metadata_store_pb2.Execution:
        """Runs Importer specific logic.

        The method:
          1. Registers the node's contexts.
          2. Resolves its execution properties.
          3. Registers an execution.
          4. Builds output artifacts for the configured source URI via
             `importer_node.generate_output_dict`.
          5. Publishes the execution as succeeded.

        Args:
          mlmd_connection: ML metadata connection, entered as a context
            manager for the duration of the run.
          pipeline_node: The specification of the node that this launcher
            launches.
          pipeline_info: The information of the pipeline that this node runs
            in. Not referenced by this method.
          pipeline_runtime_spec: The runtime information of the pipeline that
            this node runs in. Not referenced by this method.

        Returns:
          The execution of the run.
        """
        logging.info('Running as an importer node.')
        with mlmd_connection as m:
            # 1. Prepare all contexts.
            node_contexts = context_lib.register_contexts_if_not_exists(
                metadata_handler=m, node_contexts=pipeline_node.contexts)

            # 2. Resolve execution properties; importers have no inputs.
            exec_properties = inputs_utils.resolve_parameters(
                node_parameters=pipeline_node.parameters)

            # 3. Register the execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=pipeline_node.node_info.type,
                contexts=node_contexts,
                exec_properties=exec_properties)

            # 4. Generate output artifacts representing the imported ones.
            result_spec = pipeline_node.outputs.outputs[
                importer_node.IMPORT_RESULT_KEY].artifact_spec
            properties = self._extract_proto_map(
                result_spec.additional_properties)
            custom_properties = self._extract_proto_map(
                result_spec.additional_custom_properties)
            artifact_class = types.Artifact(result_spec.type).type
            output_artifacts = importer_node.generate_output_dict(
                metadata_handler=m,
                uri=str(exec_properties[importer_node.SOURCE_URI_KEY]),
                properties=properties,
                custom_properties=custom_properties,
                reimport=bool(
                    exec_properties[importer_node.REIMPORT_OPTION_KEY]),
                output_artifact_class=artifact_class,
                mlmd_artifact_type=result_spec.type)

            # 5. Publish the output artifacts.
            execution_publish_utils.publish_succeeded_execution(
                metadata_handler=m,
                execution_id=execution.id,
                contexts=node_contexts,
                output_artifacts=output_artifacts)

            return execution
Example #17
0
 def testPublishSuccessfulExecution(self):
   """Executor output is merged into the published execution's MLMD graph.

   The stored artifact carries the executor's URI and custom property. An
   OUTPUT event links it to the COMPLETE execution. Both are attributed to
   every context from `_generate_contexts`.
   """
   with metadata.Metadata(connection_config=self._connection_config) as m:
     contexts = self._generate_contexts(m)
     registered = execution_publish_utils.register_execution(
         m, self._execution_type, contexts)
     execution_id = registered.id
     output_key = 'examples'
     system_example = standard_artifacts.Examples()
     executor_output = execution_result_pb2.ExecutorOutput()
     executor_example = executor_output.output_artifacts[
         output_key].artifacts.add()
     text_format.Parse(
         """
         uri: 'examples_uri'
         custom_properties {
           key: 'prop'
           value {int_value: 1}
         }
         """, executor_example)
     execution_publish_utils.publish_succeeded_execution(
         m, execution_id, contexts, {output_key: [system_example]},
         executor_output)

     # Exactly one execution exists and it is COMPLETE.
     [stored_execution] = m.store.get_executions()
     self.assertProtoPartiallyEquals(
         """
         id: 1
         type_id: 3
         last_known_state: COMPLETE
         """,
         stored_execution,
         ignored_fields=[
             'create_time_since_epoch', 'last_update_time_since_epoch'
         ])

     # Exactly one artifact, carrying the executor-reported URI and property.
     [stored_artifact] = m.store.get_artifacts()
     self.assertProtoPartiallyEquals(
         """
         id: 1
         type_id: 4
         state: LIVE
         uri: 'examples_uri'
         custom_properties {
           key: 'prop'
           value {int_value: 1}
         }""",
         stored_artifact,
         ignored_fields=[
             'create_time_since_epoch', 'last_update_time_since_epoch'
         ])

     # A single OUTPUT event at path examples[0] links the two.
     [output_event] = m.store.get_events_by_execution_ids(
         [stored_execution.id])
     self.assertProtoPartiallyEquals(
         """
         artifact_id: 1
         execution_id: 1
         path {
           steps {
             key: 'examples'
           }
           steps {
             index: 0
           }
         }
         type: OUTPUT
         """,
         output_event,
         ignored_fields=['milliseconds_since_epoch'])

     # The context-execution and context-artifact edges are set up.
     expected_context_ids = [c.id for c in contexts]
     execution_context_ids = [
         c.id for c in m.store.get_contexts_by_execution(stored_execution.id)
     ]
     self.assertCountEqual(expected_context_ids, execution_context_ids)
     expected_context_ids = [c.id for c in contexts]
     artifact_context_ids = [
         c.id for c in m.store.get_contexts_by_artifact(system_example.id)
     ]
     self.assertCountEqual(expected_context_ids, artifact_context_ids)
Example #18
0
    def testPublishSuccessExecutionExecutorEditedOutputDict(self):
        """Executor output overrides the system-provided output_dict.

        The system-provided output_dict holds one artifact under 'examples',
        while the ExecutorOutput lists two artifacts under that key. Both
        executor artifacts are expected to be published, each with an OUTPUT
        event, and to be attributed to every context of the execution.
        """
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            execution_id = execution_publish_utils.register_execution(
                m, self._execution_type, contexts).id

            output_example = standard_artifacts.Examples()
            output_example.uri = '/original_path'

            executor_output = execution_result_pb2.ExecutorOutput()
            output_key = 'examples'
            text_format.Parse(
                """
          uri: '/original_path/subdir_1'
          custom_properties {
            key: 'prop'
            value {int_value: 1}
          }
          """, executor_output.output_artifacts[output_key].artifacts.add())
            text_format.Parse(
                """
          uri: '/original_path/subdir_2'
          custom_properties {
            key: 'prop'
            value {int_value: 2}
          }
          """, executor_output.output_artifacts[output_key].artifacts.add())

            output_dict = execution_publish_utils.publish_succeeded_execution(
                m, execution_id, contexts, {output_key: [output_example]},
                executor_output)
            [execution] = m.store.get_executions()
            self.assertProtoPartiallyEquals("""
          id: 1
          type_id: 3
          last_known_state: COMPLETE
          """,
                                            execution,
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])
            # Sort by id so the positional checks below do not depend on the
            # order in which the store happens to return rows.
            artifacts = sorted(m.store.get_artifacts(), key=lambda a: a.id)
            self.assertLen(artifacts, 2)
            self.assertProtoPartiallyEquals("""
          id: 1
          type_id: 4
          state: LIVE
          uri: '/original_path/subdir_1'
          custom_properties {
            key: 'prop'
            value {int_value: 1}
          }""",
                                            artifacts[0],
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])
            self.assertProtoPartiallyEquals("""
          id: 2
          type_id: 4
          state: LIVE
          uri: '/original_path/subdir_2'
          custom_properties {
            key: 'prop'
            value {int_value: 2}
          }""",
                                            artifacts[1],
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])
            # Same ordering concern as above: sort events by the artifact they
            # refer to before asserting on them positionally.
            events = sorted(
                m.store.get_events_by_execution_ids([execution.id]),
                key=lambda e: e.artifact_id)
            self.assertLen(events, 2)
            self.assertProtoPartiallyEquals(
                """
          artifact_id: 1
          execution_id: 1
          path {
            steps {
              key: 'examples'
            }
            steps {
              index: 0
            }
          }
          type: OUTPUT
          """,
                events[0],
                ignored_fields=['milliseconds_since_epoch'])
            self.assertProtoPartiallyEquals(
                """
          artifact_id: 2
          execution_id: 1
          path {
            steps {
              key: 'examples'
            }
            steps {
              index: 1
            }
          }
          type: OUTPUT
          """,
                events[1],
                ignored_fields=['milliseconds_since_epoch'])
            # Verifies the context-execution edges are set up.
            self.assertCountEqual([c.id for c in contexts], [
                c.id for c in m.store.get_contexts_by_execution(execution.id)
            ])
            # Verifies the context-artifact edges for every published artifact.
            # A distinct loop name avoids rebinding `output_example`, which
            # holds the system-provided artifact passed to the publish call.
            for artifact_list in output_dict.values():
                for published_artifact in artifact_list:
                    self.assertCountEqual([c.id for c in contexts], [
                        c.id for c in m.store.get_contexts_by_artifact(
                            published_artifact.id)
                    ])
    def test_resolver_task_scheduler(self):
        """Resolver output flows into the downstream consumer's inputs.

        A trainer execution publishes two models. The resolver task scheduler
        is expected to resolve exactly one model (uri 'my_model_uri_2'), and
        the consumer task generated afterwards must receive that same model
        as its single 'resolved_model' input.
        """
        with self._mlmd_connection as m:
            # Publish two trainer models for the downstream resolver.
            model_type = self._trainer.outputs.outputs['model'].artifact_spec.type
            trainer_models = []
            for model_uri in ('my_model_uri_1', 'my_model_uri_2'):
                model = types.Artifact(model_type)
                model.uri = model_uri
                trainer_models.append(model)

            trainer_contexts = context_lib.prepare_contexts(
                m, self._trainer.contexts)
            trainer_execution = execution_publish_utils.register_execution(
                m, self._trainer.node_info.type, trainer_contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, trainer_execution.id, trainer_contexts,
                {'model': trainer_models})

        task_queue = tq.TaskQueue()
        # Arguments shared by both task-generation rounds below.
        generator_kwargs = dict(
            test_case=self,
            mlmd_connection=self._mlmd_connection,
            generator_class=sptg.SyncPipelineTaskGenerator,
            pipeline=self._pipeline,
            task_queue=task_queue,
            use_task_queue=False,
            service_job_manager=None,
            num_tasks_generated=1,
            num_new_executions=1,
            num_active_executions=1,
            ignore_update_node_state_tasks=True)

        # Round 1: exactly one task, for the resolver node, is generated.
        [resolver_task] = test_utils.run_generator_and_test(
            num_initial_executions=1,
            expected_exec_nodes=[self._resolver_node],
            **generator_kwargs)

        with self._mlmd_connection as m:
            # Run the resolver task scheduler and publish its results.
            scheduler = resolver_task_scheduler.ResolverTaskScheduler(
                mlmd_handle=m, pipeline=self._pipeline, task=resolver_task)
            ts_result = scheduler.schedule()
            self.assertEqual(status_lib.Code.OK, ts_result.status.code)
            self.assertIsInstance(ts_result.output,
                                  task_scheduler.ResolverNodeOutput)
            resolved_inputs = ts_result.output.resolved_input_artifacts
            self.assertCountEqual(['resolved_model'], resolved_inputs.keys())
            resolved_models = resolved_inputs['resolved_model']
            self.assertLen(resolved_models, 1)
            self.assertEqual('my_model_uri_2',
                             resolved_models[0].mlmd_artifact.uri)
            tm._publish_execution_results(m, resolver_task, ts_result)

        # Round 2: the consumer node is scheduled with the resolver's output.
        [consumer_task] = test_utils.run_generator_and_test(
            num_initial_executions=2,
            expected_exec_nodes=[self._consumer_node],
            **generator_kwargs)
        consumer_inputs = consumer_task.input_artifacts
        self.assertCountEqual(['resolved_model'], consumer_inputs.keys())
        input_models = consumer_inputs['resolved_model']
        self.assertLen(input_models, 1)
        self.assertEqual('my_model_uri_2', input_models[0].mlmd_artifact.uri)
Example #20
0
 def testGetCachedOutputArtifacts(self):
     """Cached outputs come from the most recent execution in a cache context.

     Two executions are published under the same cache context, each with
     two models and one example. get_cached_outputs is expected to return
     the artifacts published by the second execution.
     """
     output_models_key = 'output_models'
     output_examples_key = 'output_examples'

     def make_artifact(artifact_cls, uri):
         artifact = artifact_cls()
         artifact.uri = uri
         return artifact

     # Outputs of the first execution under the shared cache key.
     first_outputs = {
         output_models_key: [
             make_artifact(standard_artifacts.Model, 'model_one'),
             make_artifact(standard_artifacts.Model, 'model_two'),
         ],
         output_examples_key: [
             make_artifact(standard_artifacts.Examples, 'example_one'),
         ],
     }
     # Outputs of the second execution under the same cache key.
     second_outputs = {
         output_models_key: [
             make_artifact(standard_artifacts.Model, 'model_three'),
             make_artifact(standard_artifacts.Model, 'model_four'),
         ],
         output_examples_key: [
             make_artifact(standard_artifacts.Examples, 'example_two'),
         ],
     }

     with metadata.Metadata(connection_config=self._connection_config) as m:
         cache_context = context_lib.register_context_if_not_exists(
             m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
         # Register and publish one execution per output dict, in order.
         for outputs in (first_outputs, second_outputs):
             execution = execution_publish_utils.register_execution(
                 m, metadata_store_pb2.ExecutionType(name='my_type'),
                 [cache_context])
             execution_publish_utils.publish_succeeded_execution(
                 m,
                 execution.id, [cache_context],
                 output_artifacts=outputs)

         # The cache lookup should surface the second execution's artifacts.
         cached_output = cache_utils.get_cached_outputs(m, cache_context)
         self.assertLen(cached_output, 2)
         self.assertLen(cached_output[output_models_key], 2)
         self.assertLen(cached_output[output_examples_key], 1)
         for key, index in ((output_models_key, 0), (output_models_key, 1),
                            (output_examples_key, 0)):
             self.assertProtoPartiallyEquals(
                 cached_output[key][index].mlmd_artifact,
                 second_outputs[key][index].mlmd_artifact,
                 ignored_fields=[
                     'create_time_since_epoch', 'last_update_time_since_epoch'
                 ])
Example #21
0
    def run(
        self, mlmd_connection: metadata.Metadata,
        pipeline_node: pipeline_pb2.PipelineNode,
        pipeline_info: pipeline_pb2.PipelineInfo,
        pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
    ) -> data_types.ExecutionInfo:
        """Runs Importer specific logic.

        Registers an execution for the importer node and builds output
        artifacts for the configured source URI. The artifacts are tagged
        with a version. The execution is then published as CACHED when the
        artifacts were reimported, or as COMPLETE otherwise.

        Args:
          mlmd_connection: ML metadata connection.
          pipeline_node: The specification of the node that this launcher
            launches.
          pipeline_info: The information of the pipeline that this node runs
            in.
          pipeline_runtime_spec: The runtime information of the pipeline that
            this node runs in. Not read by this method.

        Returns:
          The execution info of the run.
        """
        logging.info('Running as an importer node.')
        with mlmd_connection as m:
            # 1. Prepare all contexts.
            contexts = context_lib.prepare_contexts(
                metadata_handler=m, node_contexts=pipeline_node.contexts)

            # 2. Resolve execution properties. Importers have no inputs, so
            # the node parameters are the only thing to resolve.
            resolved_parameters = inputs_utils.resolve_parameters_with_schema(
                node_parameters=pipeline_node.parameters)
            exec_properties = data_types_utils.build_parsed_value_dict(
                resolved_parameters)

            # 3. Register the execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=pipeline_node.node_info.type,
                contexts=contexts,
                exec_properties=exec_properties)

            # 4. Generate output artifacts that represent the imported
            # artifacts.
            artifact_spec = pipeline_node.outputs.outputs[
                importer.IMPORT_RESULT_KEY].artifact_spec
            properties = self._extract_proto_map(
                artifact_spec.additional_properties)
            custom_properties = self._extract_proto_map(
                artifact_spec.additional_custom_properties)
            artifact_class = types.Artifact(artifact_spec.type).type
            output_artifacts = importer.generate_output_dict(
                metadata_handler=m,
                uri=str(exec_properties[importer.SOURCE_URI_KEY]),
                properties=properties,
                custom_properties=custom_properties,
                reimport=bool(exec_properties[importer.REIMPORT_OPTION_KEY]),
                output_artifact_class=artifact_class,
                mlmd_artifact_type=artifact_spec.type)

            result = data_types.ExecutionInfo(
                execution_id=execution.id,
                input_dict={},
                output_dict=output_artifacts,
                exec_properties=exec_properties,
                pipeline_node=pipeline_node,
                pipeline_info=pipeline_info)

            # TODO(b/182316162): consider letting the launcher level do the
            # publish for system nodes, so that the version tagging logic
            # doesn't need to be handled per system node.
            outputs_utils.tag_output_artifacts_with_version(result.output_dict)

            # 5. Publish the output artifacts. Reimported artifacts publish the
            # execution as CACHED; otherwise it is published as COMPLETE.
            if _is_artifact_reimported(output_artifacts):
                publish = execution_publish_utils.publish_cached_execution
            else:
                publish = execution_publish_utils.publish_succeeded_execution
            publish(
                metadata_handler=m,
                contexts=contexts,
                execution_id=execution.id,
                output_artifacts=output_artifacts)

            return result