Example #1
    def testResolveInputsArtifacts(self):
        pipeline = self.load_pipeline_proto(
            'pipeline_for_input_resolver_test.pbtxt')
        my_example_gen = pipeline.nodes[0].pipeline_node
        another_example_gen = pipeline.nodes[1].pipeline_node
        my_transform = pipeline.nodes[2].pipeline_node
        my_trainer = pipeline.nodes[3].pipeline_node

        with self.get_metadata() as m:
            # Publishes first ExampleGen with two output channels. `output_examples`
            # will be consumed by downstream Transform.
            output_example = self.make_examples(uri='my_examples_uri')
            side_examples = self.make_examples(uri='side_examples_uri')
            output_artifacts = self.fake_execute(m,
                                                 my_example_gen,
                                                 input_map=None,
                                                 output_map={
                                                     'output_examples':
                                                     [output_example],
                                                     'another_examples':
                                                     [side_examples]
                                                 })
            output_example = output_artifacts['output_examples'][0]
            side_examples = output_artifacts['another_examples'][0]

            # Publishes second ExampleGen with one output channel that uses the same
            # output key as the first ExampleGen. However, this output is not consumed
            # by downstream nodes.
            another_output_example = self.make_examples(
                uri='another_examples_uri')
            self.fake_execute(
                m,
                another_example_gen,
                input_map=None,
                output_map={'output_examples': [another_output_example]})

            # Gets inputs for transform. Should get back what the first ExampleGen
            # published in the `output_examples` channel.
            transform_inputs = inputs_utils.resolve_input_artifacts(
                m, my_transform.inputs)
            self.assertArtifactMapEqual({'examples': [output_example]},
                                        transform_inputs)

            # Tries to resolve inputs for the trainer. Since the trainer requires
            # min_count to be met for both input channels (from example_gen and from
            # transform) but nothing was published from transform, it should return
            # nothing.
            self.assertIsNone(
                inputs_utils.resolve_input_artifacts(m, my_trainer.inputs))

            # Tries to resolve inputs for transform after adding a new context query
            # to the input spec that refers to a non-existent context. Inputs cannot
            # be resolved in this case.
            context_query = my_transform.inputs.inputs['examples'].channels[
                0].context_queries.add()
            context_query.type.name = 'non_existent_context'
            context_query.name.field_value.string_value = 'non_existent_context'
            transform_inputs = inputs_utils.resolve_input_artifacts(
                m, my_transform.inputs)
            self.assertIsNone(transform_inputs)
Example #2
    def testResolveInputArtifacts_NonDictArg(self):
        self._setup_pipeline_for_input_resolver_test()
        self._append_resolver_step(self._my_transform, DuplicateOp)
        self._append_resolver_step(self._my_transform, IdentityStrategy)

        with self.assertRaisesRegex(TypeError, 'Invalid argument type'):
            inputs_utils.resolve_input_artifacts(self._metadata_handler,
                                                 self._my_transform.inputs)
Example #3
    def testResolveInputArtifacts_NonDictOutput(self):
        self._setup_pipeline_for_input_resolver_test()
        self._append_resolver_step(self._my_transform, BadOutputStrategy)

        with self.assertRaisesRegex(TypeError,
                                    'Invalid input resolution result'):
            inputs_utils.resolve_input_artifacts(self._metadata_handler,
                                                 self._my_transform.inputs)
Example #4
    def testResolverWithLatestArtifactStrategy(self):
        pipeline = self.load_pipeline_proto(
            'pipeline_for_input_resolver_test.pbtxt')
        my_example_gen = pipeline.nodes[0].pipeline_node
        my_transform = pipeline.nodes[2].pipeline_node

        with self.get_metadata() as m:
            # Publishes the ExampleGen output with two artifacts in the
            # `output_examples` channel, which will be consumed by the downstream
            # Transform.
            output_example_1 = self.make_examples(uri='my_examples_uri_1')
            output_example_2 = self.make_examples(uri='my_examples_uri_2')
            output_artifacts = self.fake_execute(
                m,
                my_example_gen,
                input_map=None,
                output_map={
                    'output_examples': [output_example_1, output_example_2]
                })
            output_example_1 = output_artifacts['output_examples'][0]
            output_example_2 = output_artifacts['output_examples'][1]

            transform_resolver = (
                my_transform.inputs.resolver_config.resolver_steps.add())
            transform_resolver.class_path = (
                'tfx.dsl.input_resolution.strategies.latest_artifact_strategy'
                '.LatestArtifactStrategy')
            transform_resolver.config_json = '{}'

            # Gets inputs for transform. With LatestArtifactStrategy applied, only the
            # latest artifact published in the `output_examples` channel should be
            # returned.
            transform_inputs = inputs_utils.resolve_input_artifacts(
                m, my_transform.inputs)
            self.assertArtifactMapEqual({'examples': [output_example_2]},
                                        transform_inputs)
Example #5
    def testResolveInputArtifacts_SkippingStrategy(self):
        self._setup_pipeline_for_input_resolver_test()
        self._append_resolver_step(self._my_transform, SkippingStrategy)

        result = inputs_utils.resolve_input_artifacts(
            self._metadata_handler, self._my_transform.inputs)
        self.assertIsNone(result)
Example #6
    def run(self, input_dict: Dict[Text, List[types.Artifact]],
            output_dict: Dict[Text, List[types.Artifact]],
            exec_properties: Dict[Text,
                                  Any]) -> driver_output_pb2.DriverOutput:
        # Fake a constant span number, which, in production, is usually calculated
        # based on the date.
        span = 2
        with self._mlmd_connection as m:
            previous_output = inputs_utils.resolve_input_artifacts(
                m, self._self_output)

            # Version should be the max of the existing versions + 1 if the span
            # already exists, otherwise 0.
            version = 0
            if previous_output:
                version = max([
                    artifact.get_int_custom_property('version')
                    for artifact in previous_output['examples']
                    if artifact.get_int_custom_property('span') == span
                ] or [-1]) + 1

        output_example = copy.deepcopy(
            output_dict['output_examples'][0].mlmd_artifact)
        output_example.custom_properties['span'].int_value = span
        output_example.custom_properties['version'].int_value = version
        result = driver_output_pb2.DriverOutput()
        result.output_artifacts['output_examples'].artifacts.append(
            output_example)

        result.exec_properties['span'].int_value = span
        result.exec_properties['version'].int_value = version
        return result
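A minimal sketch of the version bump idiom used in the driver above, with hypothetical prior versions for the span; the `or [-1]` fallback ensures that a span with no prior artifacts starts at version 0:

    # Hypothetical prior artifacts for the span carry versions 0 and 1.
    existing_versions = [0, 1]
    version = max(existing_versions or [-1]) + 1  # -> 2
    # With no prior artifacts for the span, the [-1] fallback yields version 0.
    version = max([] or [-1]) + 1  # -> 0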
Example #7
def generate_resolved_info(metadata_handler: metadata.Metadata,
                           node: pipeline_pb2.PipelineNode) -> ResolvedInfo:
    """Returns a `ResolvedInfo` object for executing the node.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to generate the resolved info.

  Returns:
    A `ResolvedInfo` with input resolutions.
  """
    # Register node contexts.
    contexts = context_lib.register_contexts_if_not_exists(
        metadata_handler=metadata_handler, node_contexts=node.contexts)

    # Resolve execution properties.
    exec_properties = inputs_utils.resolve_parameters(
        node_parameters=node.parameters)

    # Resolve inputs.
    input_artifacts = inputs_utils.resolve_input_artifacts(
        metadata_handler=metadata_handler, node_inputs=node.inputs)

    return ResolvedInfo(contexts=contexts,
                        exec_properties=exec_properties,
                        input_artifacts=input_artifacts)
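A minimal usage sketch for `generate_resolved_info`, assuming a `metadata.Metadata` connection config and a `PipelineNode` proto (here called `my_transform`) are already available; `run_node` is a hypothetical downstream call, not part of the library:

    with metadata.Metadata(connection_config=connection_config) as m:
        resolved = generate_resolved_info(metadata_handler=m, node=my_transform)
        # input_artifacts is None when the node's required inputs cannot be resolved.
        if resolved.input_artifacts is not None:
            run_node(contexts=resolved.contexts,  # hypothetical executor trigger
                     exec_properties=resolved.exec_properties,
                     input_artifacts=resolved.input_artifacts)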
Example #8
    def testResolveInputArtifacts_MixedStrategyAndOp(self):
        self._setup_pipeline_for_input_resolver_test()
        self._append_resolver_step(self._my_transform, IdentityStrategy)
        self._append_resolver_step(self._my_transform, IdentityOp)

        result = inputs_utils.resolve_input_artifacts(
            self._metadata_handler, self._my_transform.inputs)
        self.assertArtifactMapEqual({'examples': self._examples}, result)
Example #9
    def testLatestArtifacts_withInputKeys(self):
        pipeline = self.load_pipeline_proto(
            'pipeline_for_input_resolver_test.pbtxt')
        my_example_gen = pipeline.nodes[0].pipeline_node
        my_transform = pipeline.nodes[2].pipeline_node
        my_trainer = pipeline.nodes[3].pipeline_node

        # Use LatestArtifactStrategy for TransformGraph only.
        resolver = my_trainer.inputs.resolver_config.resolver_steps.add()
        resolver.class_path = (
            'tfx.dsl.input_resolution.strategies.latest_artifact_strategy'
            '.LatestArtifactStrategy')
        resolver.config_json = '{}'
        resolver.input_keys.append('transform_graph')

        with self.get_metadata() as m:
            ex1 = self.make_examples(uri='examples/1')
            ex2 = self.make_examples(uri='examples/2')
            tf1 = self.make_transform_graph(uri='transform_graph/1')
            tf2 = self.make_transform_graph(uri='transform_graph/2')
            output_artifacts = self.fake_execute(
                m,
                my_example_gen,
                input_map=None,
                output_map={'output_examples': [ex1]})
            ex1 = output_artifacts['output_examples'][0]
            output_artifacts = self.fake_execute(
                m,
                my_example_gen,
                input_map=None,
                output_map={'output_examples': [ex2]})
            ex2 = output_artifacts['output_examples'][0]
            output_artifacts = self.fake_execute(
                m,
                my_transform,
                input_map={'examples': [ex1]},
                output_map={'transform_graph': [tf1]})
            tf1 = output_artifacts['transform_graph'][0]
            output_artifacts = self.fake_execute(
                m,
                my_transform,
                input_map={'examples': [ex2]},
                output_map={'transform_graph': [tf2]})
            tf2 = output_artifacts['transform_graph'][0]
            result = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=my_trainer.inputs)

        # The "examples" input channel doesn't go through the resolver and its order
        # is nondeterministic. Sort artifacts by ID for testing convenience.
        result['examples'].sort(key=lambda a: a.id)

        self.assertArtifactMapEqual(
            {
                'examples': [ex1, ex2],
                'transform_graph': [tf2]
            }, result)
Example #10
    def run(
        self, mlmd_connection: metadata.Metadata,
        pipeline_node: pipeline_pb2.PipelineNode,
        pipeline_info: pipeline_pb2.PipelineInfo,
        pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
    ) -> data_types.ExecutionInfo:
        """Runs Resolver specific logic.

    Args:
      mlmd_connection: ML metadata connection.
      pipeline_node: The specification of the node that this launcher launches.
      pipeline_info: The information of the pipeline that this node runs in.
      pipeline_runtime_spec: The runtime information of the pipeline that this
        node runs in.

    Returns:
      The execution of the run.
    """
        logging.info('Running as a resolver node.')
        with mlmd_connection as m:
            # 1. Prepares all contexts.
            contexts = context_lib.prepare_contexts(
                metadata_handler=m, node_contexts=pipeline_node.contexts)

            # 2. Resolves inputs and execution properties.
            exec_properties = inputs_utils.resolve_parameters(
                node_parameters=pipeline_node.parameters)
            input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=pipeline_node.inputs)

            # 3. Registers execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=pipeline_node.node_info.type,
                contexts=contexts,
                exec_properties=exec_properties)

            # 4. Publishes the execution as a cached execution with the
            # resolved input artifacts as the output artifacts.
            execution_publish_utils.publish_internal_execution(
                metadata_handler=m,
                contexts=contexts,
                execution_id=execution.id,
                output_artifacts=input_artifacts)

            return data_types.ExecutionInfo(execution_id=execution.id,
                                            input_dict=input_artifacts,
                                            output_dict=input_artifacts,
                                            exec_properties=exec_properties,
                                            pipeline_node=pipeline_node,
                                            pipeline_info=pipeline_info)
Example #11
    def testResolverWithResolverPolicy(self):
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(self._testdata_dir,
                         'pipeline_for_input_resolver_test.pbtxt'), pipeline)
        my_example_gen = pipeline.nodes[0].pipeline_node
        my_transform = pipeline.nodes[2].pipeline_node

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # Publishes the ExampleGen output with two artifacts in the
            # `output_examples` channel, which will be consumed by the downstream
            # Transform.
            output_example_1 = types.Artifact(
                my_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            output_example_1.uri = 'my_examples_uri_1'

            output_example_2 = types.Artifact(
                my_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            output_example_2.uri = 'my_examples_uri_2'

            contexts = context_lib.register_contexts_if_not_exists(
                m, my_example_gen.contexts)
            execution = execution_publish_utils.register_execution(
                m, my_example_gen.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'output_examples': [output_example_1, output_example_2],
                })

            my_transform.inputs.resolver_config.resolver_policy = (
                pipeline_pb2.ResolverConfig.LATEST_ARTIFACT)

            # Gets inputs for transform. With the LATEST_ARTIFACT resolver policy, only
            # the latest artifact published in the `output_examples` channel should be
            # returned.
            transform_inputs = inputs_utils.resolve_input_artifacts(
                m, my_transform.inputs)
            self.assertEqual(len(transform_inputs), 1)
            self.assertEqual(len(transform_inputs['examples']), 1)
            self.assertProtoPartiallyEquals(
                transform_inputs['examples'][0].mlmd_artifact,
                output_example_2.mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
Example #12
    def testLatestUnprocessedArtifacts_NoneIfEverythingProcessed(self):
        pipeline = self.load_pipeline_proto(
            'pipeline_for_input_resolver_test.pbtxt')
        my_example_gen = pipeline.nodes[0].pipeline_node
        my_transform = pipeline.nodes[2].pipeline_node

        resolver1 = my_transform.inputs.resolver_config.resolver_steps.add()
        resolver1.class_path = (
            'tfx.dsl.resolvers.unprocessed_artifacts_resolver'
            '.UnprocessedArtifactsResolver')
        resolver1.config_json = '{"execution_type_name": "Transform"}'
        resolver2 = my_transform.inputs.resolver_config.resolver_steps.add()
        resolver2.class_path = (
            'tfx.dsl.input_resolution.strategies.latest_artifact_strategy'
            '.LatestArtifactStrategy')
        resolver2.config_json = '{}'

        with self.get_metadata() as m:
            ex1 = self.make_examples(uri='a')
            ex2 = self.make_examples(uri='b')
            output_artifacts = self.fake_execute(
                m,
                my_example_gen,
                input_map=None,
                output_map={'output_examples': [ex1]})
            ex1 = output_artifacts['output_examples'][0]
            output_artifacts = self.fake_execute(
                m,
                my_example_gen,
                input_map=None,
                output_map={'output_examples': [ex2]})
            ex2 = output_artifacts['output_examples'][0]
            self.fake_execute(m,
                              my_transform,
                              input_map={'examples': [ex1]},
                              output_map=None)
            self.fake_execute(m,
                              my_transform,
                              input_map={'examples': [ex2]},
                              output_map=None)

            result = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=my_transform.inputs)

        self.assertIsNone(result)
Example #13
    def testResolveInputArtifactsOutputKeyUnset(self):
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(
                self._testdata_dir,
                'pipeline_for_input_resolver_test_output_key_unset.pbtxt'),
            pipeline)
        my_trainer = pipeline.nodes[0].pipeline_node
        my_pusher = pipeline.nodes[1].pipeline_node

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # Publishes the Trainer with one output channel. `output_model`
            # will be consumed by the Pusher in a different run.
            output_model = types.Artifact(
                my_trainer.outputs.outputs['model'].artifact_spec.type)
            output_model.uri = 'my_output_model_uri'
            contexts = context_lib.register_contexts_if_not_exists(
                m, my_trainer.contexts)
            execution = execution_publish_utils.register_execution(
                m, my_trainer.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'model': [output_model],
                })
            # Gets inputs for the pusher. Should get back the model that the Trainer
            # published in the `output_model` channel.
            pusher_inputs = inputs_utils.resolve_input_artifacts(
                m, my_pusher.inputs)
            self.assertEqual(len(pusher_inputs), 1)
            self.assertEqual(len(pusher_inputs['model']), 1)
            self.assertProtoPartiallyEquals(
                output_model.mlmd_artifact,
                pusher_inputs['model'][0].mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
Example #14
    def testResolveInputArtifactsOutputKeyUnset(self):
        pipeline = self.load_pipeline_proto(
            'pipeline_for_input_resolver_test_output_key_unset.pbtxt')
        my_trainer = pipeline.nodes[0].pipeline_node
        my_pusher = pipeline.nodes[1].pipeline_node

        with self.get_metadata() as m:
            # Publishes the Trainer with one output channel. `output_model`
            # will be consumed by the Pusher in a different run.
            output_model = self.make_model(uri='my_output_model_uri')
            output_artifacts = self.fake_execute(
                m,
                my_trainer,
                input_map=None,
                output_map={'model': [output_model]})
            output_model = output_artifacts['model'][0]

            # Gets inputs for the pusher. Should get back the model that the Trainer
            # published in the `output_model` channel.
            pusher_inputs = inputs_utils.resolve_input_artifacts(
                m, my_pusher.inputs)
            self.assertArtifactMapEqual({'model': [output_model]},
                                        pusher_inputs)
Example #15
    def _prepare_execution(self) -> _PrepareExecutionResult:
        """Prepares inputs, outputs and execution properties for actual execution."""
        # TODO(b/150979622): handle the edge case where the component gets evicted
        # between a successful publish and the stateful working dir being cleaned up.
        # Otherwise, subsequent retries will keep failing because of duplicate
        # publishes.
        with self._mlmd_connection as m:
            # 1. Prepares all contexts.
            contexts = context_lib.register_contexts_if_not_exists(
                metadata_handler=m, node_contexts=self._pipeline_node.contexts)

            # 2. Resolves inputs and execution properties.
            exec_properties = inputs_utils.resolve_parameters(
                node_parameters=self._pipeline_node.parameters)
            input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=self._pipeline_node.inputs)
            # 3. If not all required inputs are met, return ExecutionInfo with
            # is_execution_needed set to False. No publish will happen, so downstream
            # nodes won't be triggered.
            if input_artifacts is None:
                return _PrepareExecutionResult(
                    execution_info=data_types.ExecutionInfo(),
                    contexts=contexts,
                    is_execution_needed=False)

            # 4. Registers execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=self._pipeline_node.node_info.type,
                contexts=contexts,
                input_artifacts=input_artifacts,
                exec_properties=exec_properties)

            # 5. Resolves output artifacts.
            output_artifacts = self._output_resolver.generate_output_artifacts(
                execution.id)

        # If there is a custom driver, runs it.
        if self._driver_operator:
            driver_output = self._driver_operator.run_driver(
                data_types.ExecutionInfo(
                    input_dict=input_artifacts,
                    output_dict=output_artifacts,
                    exec_properties=exec_properties,
                    execution_output_uri=self._output_resolver.
                    get_driver_output_uri()))
            self._update_with_driver_output(driver_output, exec_properties,
                                            output_artifacts)

        # We reconnect to MLMD here because the custom driver closes the MLMD
        # connection when it returns.
        with self._mlmd_connection as m:
            # 6. Checks for a cached result.
            cache_context = cache_utils.get_cache_context(
                metadata_handler=m,
                pipeline_node=self._pipeline_node,
                pipeline_info=self._pipeline_info,
                input_artifacts=input_artifacts,
                output_artifacts=output_artifacts,
                parameters=exec_properties)
            contexts.append(cache_context)
            cached_outputs = cache_utils.get_cached_outputs(
                metadata_handler=m, cache_context=cache_context)

            # 7. Should cache be used?
            if (self._pipeline_node.execution_options.caching_options.
                    enable_cache and cached_outputs):
                # Publishes the cached result.
                execution_publish_utils.publish_cached_execution(
                    metadata_handler=m,
                    contexts=contexts,
                    execution_id=execution.id,
                    output_artifacts=cached_outputs)
                return _PrepareExecutionResult(
                    execution_info=data_types.ExecutionInfo(
                        execution_id=execution.id),
                    execution_metadata=execution,
                    contexts=contexts,
                    is_execution_needed=False)

            pipeline_run_id = (self._pipeline_runtime_spec.pipeline_run_id.
                               field_value.string_value)

            # 8. Going to trigger executor.
            return _PrepareExecutionResult(
                execution_info=data_types.ExecutionInfo(
                    execution_id=execution.id,
                    input_dict=input_artifacts,
                    output_dict=output_artifacts,
                    exec_properties=exec_properties,
                    execution_output_uri=self._output_resolver.
                    get_executor_output_uri(execution.id),
                    stateful_working_dir=(self._output_resolver.
                                          get_stateful_working_directory()),
                    tmp_dir=self._output_resolver.make_tmp_dir(execution.id),
                    pipeline_node=self._pipeline_node,
                    pipeline_info=self._pipeline_info,
                    pipeline_run_id=pipeline_run_id),
                execution_metadata=execution,
                contexts=contexts,
                is_execution_needed=True)
Example #16
    def testResolverInputsArtifacts(self):
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(self._testdata_dir,
                         'pipeline_for_input_resolver_test.pbtxt'), pipeline)
        my_example_gen = pipeline.nodes[0].pipeline_node
        another_example_gen = pipeline.nodes[1].pipeline_node
        my_transform = pipeline.nodes[2].pipeline_node
        my_trainer = pipeline.nodes[3].pipeline_node

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # Publishes first ExampleGen with two output channels. `output_examples`
            # will be consumed by downstream Transform.
            output_example = types.Artifact(
                my_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            output_example.uri = 'my_examples_uri'
            side_examples = types.Artifact(
                my_example_gen.outputs.outputs['side_examples'].artifact_spec.
                type)
            side_examples.uri = 'side_examples_uri'
            contexts = context_lib.register_contexts_if_not_exists(
                m, my_example_gen.contexts)
            execution = execution_publish_utils.register_execution(
                m, my_example_gen.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'output_examples': [output_example],
                    'another_examples': [side_examples]
                })

            # Publishes second ExampleGen with one output channel that uses the same
            # output key as the first ExampleGen. However, this output is not consumed
            # by downstream nodes.
            another_output_example = types.Artifact(
                another_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            another_output_example.uri = 'another_examples_uri'
            contexts = context_lib.register_contexts_if_not_exists(
                m, another_example_gen.contexts)
            execution = execution_publish_utils.register_execution(
                m, another_example_gen.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'output_examples': [another_output_example],
                })

            # Gets inputs for transform. Should get back what the first ExampleGen
            # published in the `output_examples` channel.
            transform_inputs = inputs_utils.resolve_input_artifacts(
                m, my_transform.inputs)
            self.assertEqual(len(transform_inputs), 1)
            self.assertEqual(len(transform_inputs['examples']), 1)
            self.assertProtoPartiallyEquals(
                transform_inputs['examples'][0].mlmd_artifact,
                output_example.mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])

            # Tries to resolve inputs for the trainer. Since the trainer requires
            # min_count to be met for both input channels (from example_gen and from
            # transform) but nothing was published from transform, it should return
            # nothing.
            self.assertIsNone(
                inputs_utils.resolve_input_artifacts(m, my_trainer.inputs))
Example #17
    def testSuccess(self):
        with self._mlmd_connection as m:
            # Publishes two models which will be consumed by the downstream resolver.
            output_model_1 = types.Artifact(
                self._my_trainer.outputs.outputs['model'].artifact_spec.type)
            output_model_1.uri = 'my_model_uri_1'

            output_model_2 = types.Artifact(
                self._my_trainer.outputs.outputs['model'].artifact_spec.type)
            output_model_2.uri = 'my_model_uri_2'

            contexts = context_lib.prepare_contexts(m,
                                                    self._my_trainer.contexts)
            execution = execution_publish_utils.register_execution(
                m, self._my_trainer.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'model': [output_model_1, output_model_2],
                })

        handler = resolver_node_handler.ResolverNodeHandler()
        execution_metadata = handler.run(
            mlmd_connection=self._mlmd_connection,
            pipeline_node=self._resolver_node,
            pipeline_info=self._pipeline_info,
            pipeline_runtime_spec=self._pipeline_runtime_spec)

        with self._mlmd_connection as m:
            # There is no way to directly verify the output artifact of the resolver,
            # so a fake downstream component is created here which listens to the
            # resolver's output, and we verify its input.
            down_stream_node = text_format.Parse(
                """
        inputs {
          inputs {
            key: "input_models"
            value {
              channels {
                producer_node_query {
                  id: "my_resolver"
                }
                context_queries {
                  type {
                    name: "pipeline"
                  }
                  name {
                    field_value {
                      string_value: "my_pipeline"
                    }
                  }
                }
                context_queries {
                  type {
                    name: "component"
                  }
                  name {
                    field_value {
                      string_value: "my_resolver"
                    }
                  }
                }
                artifact_query {
                  type {
                    name: "Model"
                  }
                }
                output_key: "models"
              }
              min_count: 1
            }
          }
        }
        upstream_nodes: "my_resolver"
        """, pipeline_pb2.PipelineNode())
            downstream_input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=down_stream_node.inputs)
            downstream_input_model = downstream_input_artifacts['input_models']
            self.assertLen(downstream_input_model, 1)
            self.assertProtoPartiallyEquals(
                """
          id: 2
          type_id: 5
          uri: "my_model_uri_2"
          state: LIVE""",
                downstream_input_model[0].mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
            [execution] = m.store.get_executions_by_id([execution_metadata.id])

            self.assertProtoPartiallyEquals("""
          id: 2
          type_id: 6
          last_known_state: COMPLETE
          """,
                                            execution,
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])
Example #18
    def testLauncher_ReEntry(self):
        # Some executors or runtime environments may reschedule the launcher job
        # before the launcher job can publish any results of the execution to MLMD.
        # The launcher should reuse the previous execution and proceed to a
        # successful execution.
        self.reloadPipelineWithNewRunId()
        LauncherTest.fakeUpstreamOutputs(self._mlmd_connection,
                                         self._example_gen, self._transform)

        def create_test_launcher(executor_operators):
            return launcher.Launcher(
                pipeline_node=self._trainer,
                mlmd_connection=self._mlmd_connection,
                pipeline_info=self._pipeline_info,
                pipeline_runtime_spec=self._pipeline_runtime_spec,
                executor_spec=self._trainer_executor_spec,
                custom_executor_operators=executor_operators)

        test_launcher = create_test_launcher(
            {_PYTHON_CLASS_EXECUTABLE_SPEC: _FakeCrashingExecutorOperator})

        # The first launch simulates the launcher being restarted by preventing the
        # publishing of any results to MLMD.
        with contextlib.ExitStack() as stack:
            stack.enter_context(
                mock.patch.object(test_launcher, '_publish_failed_execution'))
            stack.enter_context(
                mock.patch.object(test_launcher,
                                  '_clean_up_stateless_execution_info'))
            stack.enter_context(self.assertRaises(FakeError))
            test_launcher.launch()

        # Retrieve execution from the first launch, which should be in RUNNING
        # state.
        with self._mlmd_connection as m:
            contexts = context_lib.prepare_contexts(
                metadata_handler=m,
                node_contexts=test_launcher._pipeline_node.contexts)
            exec_properties = data_types_utils.build_parsed_value_dict(
                inputs_utils.resolve_parameters_with_schema(
                    node_parameters=test_launcher._pipeline_node.parameters))
            input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m,
                node_inputs=test_launcher._pipeline_node.inputs)
            first_execution = test_launcher._register_or_reuse_execution(
                metadata_handler=m,
                contexts=contexts,
                input_artifacts=input_artifacts,
                exec_properties=exec_properties)
            self.assertEqual(first_execution.last_known_state,
                             metadata_store_pb2.Execution.RUNNING)

        # Create a second test launcher. It should reuse the previous execution.
        second_test_launcher = create_test_launcher(
            {_PYTHON_CLASS_EXECUTABLE_SPEC: _FakeExecutorOperator})
        execution_info = second_test_launcher.launch()

        with self._mlmd_connection as m:
            self.assertEqual(first_execution.id, execution_info.execution_id)
            executions = m.store.get_executions_by_id(
                [execution_info.execution_id])
            self.assertLen(executions, 1)
            self.assertEqual(executions.pop().last_known_state,
                             metadata_store_pb2.Execution.COMPLETE)

        # Create a third test launcher. It should not require an execution.
        third_test_launcher = create_test_launcher(
            {_PYTHON_CLASS_EXECUTABLE_SPEC: _FakeExecutorOperator})
        execution_preparation_result = third_test_launcher._prepare_execution()
        self.assertFalse(execution_preparation_result.is_execution_needed)