Example #1
    def fakeUpstreamOutputs(mlmd_connection: metadata.Metadata,
                            example_gen: pipeline_pb2.PipelineNode,
                            transform: pipeline_pb2.PipelineNode):

        with mlmd_connection as m:
            if example_gen:
                # Publishes ExampleGen output.
                output_example = types.Artifact(
                    example_gen.outputs.outputs['output_examples'].
                    artifact_spec.type)
                output_example.uri = 'my_examples_uri'
                contexts = context_lib.register_contexts_if_not_exists(
                    m, example_gen.contexts)
                execution = execution_publish_utils.register_execution(
                    m, example_gen.node_info.type, contexts)
                execution_publish_utils.publish_succeeded_execution(
                    m, execution.id, contexts, {
                        'output_examples': [output_example],
                    })

            if transform:
                # Publishes Transform output.
                output_transform_graph = types.Artifact(
                    transform.outputs.outputs['transform_graph'].artifact_spec.
                    type)
                output_transform_graph.uri = 'my_transform_graph_uri'
                contexts = context_lib.register_contexts_if_not_exists(
                    m, transform.contexts)
                execution = execution_publish_utils.register_execution(
                    m, transform.node_info.type, contexts)
                execution_publish_utils.publish_succeeded_execution(
                    m, execution.id, contexts, {
                        'transform_graph': [output_transform_graph],
                    })
Example #2
def generate_resolved_info(metadata_handler: metadata.Metadata,
                           node: pipeline_pb2.PipelineNode) -> ResolvedInfo:
    """Returns a `ResolvedInfo` object for executing the node.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to generate.

  Returns:
    A `ResolvedInfo` with input resolutions.
  """
    # Register node contexts.
    contexts = context_lib.register_contexts_if_not_exists(
        metadata_handler=metadata_handler, node_contexts=node.contexts)

    # Resolve execution properties.
    exec_properties = inputs_utils.resolve_parameters(
        node_parameters=node.parameters)

    # Resolve inputs.
    input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=metadata_handler, node_inputs=node.inputs)

    return ResolvedInfo(contexts=contexts,
                        exec_properties=exec_properties,
                        input_artifacts=input_artifacts)
Example #3
    def testRegisterContexts(self):
        node_contexts = pipeline_pb2.NodeContexts()
        self.load_proto_from_text('node_context_spec.pbtxt', node_contexts)
        with metadata.Metadata(connection_config=self._connection_config) as m:
            context_lib.register_contexts_if_not_exists(
                metadata_handler=m, node_contexts=node_contexts)
            # Duplicated call should succeed.
            contexts = context_lib.register_contexts_if_not_exists(
                metadata_handler=m, node_contexts=node_contexts)

            self.assertProtoEquals(
                """
          id: 1
          name: 'my_context_type_one'
          """, m.store.get_context_type('my_context_type_one'))
            self.assertProtoEquals(
                """
          id: 2
          name: 'my_context_type_two'
          """, m.store.get_context_type('my_context_type_two'))
            self.assertEqual(
                contexts[0],
                m.store.get_context_by_type_and_name('my_context_type_one',
                                                     'my_context_one'))
            self.assertEqual(
                contexts[1],
                m.store.get_context_by_type_and_name('my_context_type_one',
                                                     'my_context_two'))
            self.assertEqual(
                contexts[2],
                m.store.get_context_by_type_and_name('my_context_type_two',
                                                     'my_context_three'))
            self.assertEqual(
                contexts[0].custom_properties['property_a'].int_value, 1)
            self.assertEqual(
                contexts[1].custom_properties['property_a'].int_value, 2)
            self.assertEqual(
                contexts[2].custom_properties['property_a'].int_value, 3)
            self.assertEqual(
                contexts[2].custom_properties['property_b'].string_value, '4')
Example #4
def fake_trainer_output(mlmd_connection, trainer, execution=None):
    """Writes fake trainer output and execution to MLMD."""
    with mlmd_connection as m:
        output_trainer_model = types.Artifact(
            trainer.outputs.outputs['model'].artifact_spec.type)
        output_trainer_model.uri = 'my_trainer_model_uri'
        contexts = context_lib.register_contexts_if_not_exists(
            m, trainer.contexts)
        if not execution:
            execution = execution_publish_utils.register_execution(
                m, trainer.node_info.type, contexts)
        execution_publish_utils.publish_succeeded_execution(
            m, execution.id, contexts, {
                'model': [output_trainer_model],
            })
Example #5
def fake_transform_output(mlmd_connection, transform, execution=None):
    """Writes fake transform output and execution to MLMD."""
    with mlmd_connection as m:
        output_transform_graph = types.Artifact(
            transform.outputs.outputs['transform_graph'].artifact_spec.type)
        output_transform_graph.uri = 'my_transform_graph_uri'
        contexts = context_lib.register_contexts_if_not_exists(
            m, transform.contexts)
        if not execution:
            execution = execution_publish_utils.register_execution(
                m, transform.node_info.type, contexts)
        execution_publish_utils.publish_succeeded_execution(
            m, execution.id, contexts, {
                'transform_graph': [output_transform_graph],
            })
Example #6
def fake_example_gen_run(mlmd_connection, example_gen, span, version):
    """Writes fake example_gen output and successful execution to MLMD."""
    with mlmd_connection as m:
        output_example = types.Artifact(
            example_gen.outputs.outputs['output_examples'].artifact_spec.type)
        output_example.set_int_custom_property('span', span)
        output_example.set_int_custom_property('version', version)
        output_example.uri = 'my_examples_uri'
        contexts = context_lib.register_contexts_if_not_exists(
            m, example_gen.contexts)
        execution = execution_publish_utils.register_execution(
            m, example_gen.node_info.type, contexts)
        execution_publish_utils.publish_succeeded_execution(
            m, execution.id, contexts, {
                'output_examples': [output_example],
            })
Example #7
    def testResolverWithResolverPolicy(self):
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(self._testdata_dir,
                         'pipeline_for_input_resolver_test.pbtxt'), pipeline)
        my_example_gen = pipeline.nodes[0].pipeline_node
        my_transform = pipeline.nodes[2].pipeline_node

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # Publishes ExampleGen output with two artifacts in the `output_examples`
            # channel, which will be consumed by the downstream Transform.
            output_example_1 = types.Artifact(
                my_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            output_example_1.uri = 'my_examples_uri_1'

            output_example_2 = types.Artifact(
                my_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            output_example_2.uri = 'my_examples_uri_2'

            contexts = context_lib.register_contexts_if_not_exists(
                m, my_example_gen.contexts)
            execution = execution_publish_utils.register_execution(
                m, my_example_gen.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'output_examples': [output_example_1, output_example_2],
                })

            my_transform.inputs.resolver_config.resolver_policy = (
                pipeline_pb2.ResolverConfig.LATEST_ARTIFACT)

            # Gets inputs for transform. With the LATEST_ARTIFACT policy, only the
            # latest artifact published in the `output_examples` channel should be
            # returned.
            transform_inputs = inputs_utils.resolve_input_artifacts(
                m, my_transform.inputs)
            self.assertEqual(len(transform_inputs), 1)
            self.assertEqual(len(transform_inputs['examples']), 1)
            self.assertProtoPartiallyEquals(
                transform_inputs['examples'][0].mlmd_artifact,
                output_example_2.mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
Example #8
File: launcher_test.py Project: lre/tfx
 def fakeExampleGenOutput(mlmd_connection: metadata.Metadata,
                          example_gen: pipeline_pb2.PipelineNode, span: int,
                          version: int):
   with mlmd_connection as m:
     output_example = types.Artifact(
         example_gen.outputs.outputs['output_examples'].artifact_spec.type)
     output_example.set_int_custom_property('span', span)
     output_example.set_int_custom_property('version', version)
     output_example.uri = 'my_examples_uri'
     contexts = context_lib.register_contexts_if_not_exists(
         m, example_gen.contexts)
     execution = execution_publish_utils.register_execution(
         m, example_gen.node_info.type, contexts)
     execution_publish_utils.publish_succeeded_execution(
         m, execution.id, contexts, {
             'output_examples': [output_example],
         })
Example #9
 def _generate_contexts(self, metadata_handler):
   context_spec = pipeline_pb2.NodeContexts()
   text_format.Parse(
       """
       contexts {
         type {name: 'pipeline_context'}
         name {
           field_value {string_value: 'my_pipeline'}
         }
       }
       contexts {
         type {name: 'component_context'}
         name {
           field_value {string_value: 'my_component'}
         }
       }""", context_spec)
   return context_lib.register_contexts_if_not_exists(metadata_handler,
                                                      context_spec)
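The contexts returned by this helper can then anchor an execution, mirroring the other examples on this page. A minimal sketch, assuming `self._metadata_handler` is an open `metadata.Metadata` handle and `my_execution_type` is an illustrative `metadata_store_pb2.ExecutionType` (both names are placeholders, not from the original test):

# A minimal sketch; `my_execution_type` is a placeholder ExecutionType proto.
contexts = self._generate_contexts(self._metadata_handler)

# Both the 'pipeline_context' and 'component_context' contexts are now
# registered in MLMD, so an execution can be associated with them.
execution = execution_publish_utils.register_execution(
    self._metadata_handler, my_execution_type, contexts)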
Example #10
    def testResolveInputArtifactsOutputKeyUnset(self):
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(
                self._testdata_dir,
                'pipeline_for_input_resolver_test_output_key_unset.pbtxt'),
            pipeline)
        my_trainer = pipeline.nodes[0].pipeline_node
        my_pusher = pipeline.nodes[1].pipeline_node

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # Publishes Trainer with one output channel. The `model` output will be
            # consumed by the Pusher in a different run.
            output_model = types.Artifact(
                my_trainer.outputs.outputs['model'].artifact_spec.type)
            output_model.uri = 'my_output_model_uri'
            contexts = context_lib.register_contexts_if_not_exists(
                m, my_trainer.contexts)
            execution = execution_publish_utils.register_execution(
                m, my_trainer.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'model': [output_model],
                })
            # Gets inputs for the Pusher. Should get back the Model that the Trainer
            # published in the `model` channel.
            pusher_inputs = inputs_utils.resolve_input_artifacts(
                m, my_pusher.inputs)
            self.assertEqual(len(pusher_inputs), 1)
            self.assertEqual(len(pusher_inputs['model']), 1)
            self.assertProtoPartiallyEquals(
                output_model.mlmd_artifact,
                pusher_inputs['model'][0].mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
Example #11
File: launcher.py Project: hamzamaiot/tfx
    def _prepare_execution(self) -> _PrepareExecutionResult:
        """Prepares inputs, outputs and execution properties for actual execution."""
        # TODO(b/150979622): handle the edge case where the component gets evicted
        # between a successful publish and the stateful working dir being cleaned up.
        # Otherwise the following retries will keep failing because of duplicate
        # publishes.
        with self._mlmd_connection as m:
            # 1. Prepares all contexts.
            contexts = context_lib.register_contexts_if_not_exists(
                metadata_handler=m, node_contexts=self._pipeline_node.contexts)

            # 2. Resolves inputs and execution properties.
            exec_properties = inputs_utils.resolve_parameters(
                node_parameters=self._pipeline_node.parameters)
            input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=self._pipeline_node.inputs)
            # 3. If not all required inputs are met, returns an ExecutionInfo with
            # is_execution_needed set to False. No publish will happen, so downstream
            # nodes won't be triggered.
            if input_artifacts is None:
                return _PrepareExecutionResult(
                    execution_info=data_types.ExecutionInfo(),
                    contexts=contexts,
                    is_execution_needed=False)

            # 4. Registers execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=self._pipeline_node.node_info.type,
                contexts=contexts,
                input_artifacts=input_artifacts,
                exec_properties=exec_properties)

            # 5. Resolves output artifacts.
            output_artifacts = self._output_resolver.generate_output_artifacts(
                execution.id)

        # If there is a custom driver, runs it.
        if self._driver_operator:
            driver_output = self._driver_operator.run_driver(
                data_types.ExecutionInfo(
                    input_dict=input_artifacts,
                    output_dict=output_artifacts,
                    exec_properties=exec_properties,
                    execution_output_uri=self._output_resolver.
                    get_driver_output_uri()))
            self._update_with_driver_output(driver_output, exec_properties,
                                            output_artifacts)

        # We reconnect to MLMD here because the custom driver closes the MLMD
        # connection when it returns.
        with self._mlmd_connection as m:
            # 6. Checks for a cached result.
            cache_context = cache_utils.get_cache_context(
                metadata_handler=m,
                pipeline_node=self._pipeline_node,
                pipeline_info=self._pipeline_info,
                input_artifacts=input_artifacts,
                output_artifacts=output_artifacts,
                parameters=exec_properties)
            contexts.append(cache_context)
            cached_outputs = cache_utils.get_cached_outputs(
                metadata_handler=m, cache_context=cache_context)

            # 7. Should cache be used?
            if (self._pipeline_node.execution_options.caching_options.
                    enable_cache and cached_outputs):
                # Publishes the cached result.
                execution_publish_utils.publish_cached_execution(
                    metadata_handler=m,
                    contexts=contexts,
                    execution_id=execution.id,
                    output_artifacts=cached_outputs)
                return _PrepareExecutionResult(
                    execution_info=data_types.ExecutionInfo(
                        execution_id=execution.id),
                    execution_metadata=execution,
                    contexts=contexts,
                    is_execution_needed=False)

            pipeline_run_id = (self._pipeline_runtime_spec.pipeline_run_id.
                               field_value.string_value)

            # 8. Going to trigger executor.
            return _PrepareExecutionResult(
                execution_info=data_types.ExecutionInfo(
                    execution_id=execution.id,
                    input_dict=input_artifacts,
                    output_dict=output_artifacts,
                    exec_properties=exec_properties,
                    execution_output_uri=self._output_resolver.
                    get_executor_output_uri(execution.id),
                    stateful_working_dir=(self._output_resolver.
                                          get_stateful_working_directory()),
                    tmp_dir=self._output_resolver.make_tmp_dir(execution.id),
                    pipeline_node=self._pipeline_node,
                    pipeline_info=self._pipeline_info,
                    pipeline_run_id=pipeline_run_id),
                execution_metadata=execution,
                contexts=contexts,
                is_execution_needed=True)
Example #12
    def testResolverInputsArtifacts(self):
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(self._testdata_dir,
                         'pipeline_for_input_resolver_test.pbtxt'), pipeline)
        my_example_gen = pipeline.nodes[0].pipeline_node
        another_example_gen = pipeline.nodes[1].pipeline_node
        my_transform = pipeline.nodes[2].pipeline_node
        my_trainer = pipeline.nodes[3].pipeline_node

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # Publishes first ExampleGen with two output channels. `output_examples`
            # will be consumed by downstream Transform.
            output_example = types.Artifact(
                my_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            output_example.uri = 'my_examples_uri'
            side_examples = types.Artifact(
                my_example_gen.outputs.outputs['side_examples'].artifact_spec.
                type)
            side_examples.uri = 'side_examples_uri'
            contexts = context_lib.register_contexts_if_not_exists(
                m, my_example_gen.contexts)
            execution = execution_publish_utils.register_execution(
                m, my_example_gen.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'output_examples': [output_example],
                    'another_examples': [side_examples]
                })

            # Publishes second ExampleGen with one output channel that has the same
            # output key as the first ExampleGen. However, this output is not consumed
            # by downstream nodes.
            another_output_example = types.Artifact(
                another_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            another_output_example.uri = 'another_examples_uri'
            contexts = context_lib.register_contexts_if_not_exists(
                m, another_example_gen.contexts)
            execution = execution_publish_utils.register_execution(
                m, another_example_gen.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'output_examples': [another_output_example],
                })

            # Gets inputs for transform. Should get back what the first ExampleGen
            # published in the `output_examples` channel.
            transform_inputs = inputs_utils.resolve_input_artifacts(
                m, my_transform.inputs)
            self.assertEqual(len(transform_inputs), 1)
            self.assertEqual(len(transform_inputs['examples']), 1)
            self.assertProtoPartiallyEquals(
                transform_inputs['examples'][0].mlmd_artifact,
                output_example.mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])

            # Tries to resolve inputs for the trainer. Since the trainer requires
            # min_count for both input channels (from example_gen and from transform)
            # and we did not publish anything from transform, it should return nothing.
            self.assertIsNone(
                inputs_utils.resolve_input_artifacts(m, my_trainer.inputs))
Example #13
    def run(
        self, mlmd_connection: metadata.Metadata,
        pipeline_node: pipeline_pb2.PipelineNode,
        pipeline_info: pipeline_pb2.PipelineInfo,
        pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
    ) -> metadata_store_pb2.Execution:
        """Runs Importer specific logic.

    Args:
      mlmd_connection: ML metadata connection.
      pipeline_node: The specification of the node that this launcher lauches.
      pipeline_info: The information of the pipeline that this node runs in.
      pipeline_runtime_spec: The runtime information of the pipeline that this
        node runs in.

    Returns:
      The execution of the run.
    """
        logging.info('Running as an importer node.')
        with mlmd_connection as m:
            # 1. Prepares all contexts.
            contexts = context_lib.register_contexts_if_not_exists(
                metadata_handler=m, node_contexts=pipeline_node.contexts)

            # 2. Resolves execution properties. Note that importer nodes have no
            # inputs.
            exec_properties = inputs_utils.resolve_parameters(
                node_parameters=pipeline_node.parameters)

            # 3. Registers execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=pipeline_node.node_info.type,
                contexts=contexts,
                exec_properties=exec_properties)

            # 4. Generates output artifacts to represent the imported artifacts.
            output_spec = pipeline_node.outputs.outputs[
                importer_node.IMPORT_RESULT_KEY]
            properties = self._extract_proto_map(
                output_spec.artifact_spec.additional_properties)
            custom_properties = self._extract_proto_map(
                output_spec.artifact_spec.additional_custom_properties)
            output_artifact_class = types.Artifact(
                output_spec.artifact_spec.type).type
            output_artifacts = importer_node.generate_output_dict(
                metadata_handler=m,
                uri=str(exec_properties[importer_node.SOURCE_URI_KEY]),
                properties=properties,
                custom_properties=custom_properties,
                reimport=bool(
                    exec_properties[importer_node.REIMPORT_OPTION_KEY]),
                output_artifact_class=output_artifact_class,
                mlmd_artifact_type=output_spec.artifact_spec.type)

            # 5. Publishes the output artifacts.
            execution_publish_utils.publish_succeeded_execution(
                metadata_handler=m,
                execution_id=execution.id,
                contexts=contexts,
                output_artifacts=output_artifacts)

            return execution
Example #14
    def testSuccess(self):
        with self._mlmd_connection as m:
            # Publishes two models which will be consumed by downstream resolver.
            output_model_1 = types.Artifact(
                self._my_trainer.outputs.outputs['model'].artifact_spec.type)
            output_model_1.uri = 'my_model_uri_1'

            output_model_2 = types.Artifact(
                self._my_trainer.outputs.outputs['model'].artifact_spec.type)
            output_model_2.uri = 'my_model_uri_2'

            contexts = context_lib.register_contexts_if_not_exists(
                m, self._my_trainer.contexts)
            execution = execution_publish_utils.register_execution(
                m, self._my_trainer.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'model': [output_model_1, output_model_2],
                })

        handler = resolver_node_handler.ResolverNodeHandler()
        execution_metadata = handler.run(
            mlmd_connection=self._mlmd_connection,
            pipeline_node=self._resolver_node,
            pipeline_info=self._pipeline_info,
            pipeline_runtime_spec=self._pipeline_runtime_spec)

        with self._mlmd_connection as m:
            # There is no way to directly verify the output artifact of the resolver.
            # So a fake downstream component is created here which listens to the
            # resolver's output, and we verify its input.
            down_stream_node = text_format.Parse(
                """
        inputs {
          inputs {
            key: "input_models"
            value {
              channels {
                producer_node_query {
                  id: "my_resolver"
                }
                context_queries {
                  type {
                    name: "pipeline"
                  }
                  name {
                    field_value {
                      string_value: "my_pipeline"
                    }
                  }
                }
                context_queries {
                  type {
                    name: "component"
                  }
                  name {
                    field_value {
                      string_value: "my_resolver"
                    }
                  }
                }
                artifact_query {
                  type {
                    name: "Model"
                  }
                }
                output_key: "models"
              }
              min_count: 1
            }
          }
        }
        upstream_nodes: "my_resolver"
        """, pipeline_pb2.PipelineNode())
            downstream_input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=down_stream_node.inputs)
            downstream_input_model = downstream_input_artifacts['input_models']
            self.assertLen(downstream_input_model, 1)
            self.assertProtoPartiallyEquals(
                """
          id: 2
          type_id: 5
          uri: "my_model_uri_2"
          state: LIVE""",
                downstream_input_model[0].mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
            [execution] = m.store.get_executions_by_id([execution_metadata.id])

            self.assertProtoPartiallyEquals("""
          id: 2
          type_id: 6
          last_known_state: COMPLETE
          """,
                                            execution,
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])