Example #1
    def fakeUpstreamOutputs(mlmd_connection: metadata.Metadata,
                            example_gen: pipeline_pb2.PipelineNode,
                            transform: pipeline_pb2.PipelineNode):

        with mlmd_connection as m:
            if example_gen:
                # Publishes ExampleGen output.
                output_example = types.Artifact(
                    example_gen.outputs.outputs['output_examples'].
                    artifact_spec.type)
                output_example.uri = 'my_examples_uri'
                contexts = context_lib.prepare_contexts(
                    m, example_gen.contexts)
                execution = execution_publish_utils.register_execution(
                    m, example_gen.node_info.type, contexts)
                execution_publish_utils.publish_succeeded_execution(
                    m, execution.id, contexts, {
                        'output_examples': [output_example],
                    })

            if transform:
                # Publishes Transform output.
                output_transform_graph = types.Artifact(
                    transform.outputs.outputs['transform_graph'].artifact_spec.
                    type)
                output_transform_graph.uri = 'my_transform_graph_uri'
                contexts = context_lib.prepare_contexts(m, transform.contexts)
                execution = execution_publish_utils.register_execution(
                    m, transform.node_info.type, contexts)
                execution_publish_utils.publish_succeeded_execution(
                    m, execution.id, contexts, {
                        'transform_graph': [output_transform_graph],
                    })
Example #2
def generate_resolved_info(metadata_handler: metadata.Metadata,
                           node: pipeline_pb2.PipelineNode) -> ResolvedInfo:
  """Returns a `ResolvedInfo` object for executing the node.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to generate resolved info.

  Returns:
    A `ResolvedInfo` with input resolutions.
  """
  # Register node contexts.
  contexts = context_lib.prepare_contexts(
      metadata_handler=metadata_handler, node_contexts=node.contexts)

  # Resolve execution properties.
  exec_properties = inputs_utils.resolve_parameters(
      node_parameters=node.parameters)

  # Resolve inputs.
  input_artifacts = inputs_utils.resolve_input_artifacts(
      metadata_handler=metadata_handler, node_inputs=node.inputs)

  return ResolvedInfo(
      contexts=contexts,
      exec_properties=exec_properties,
      input_artifacts=input_artifacts)
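A hedged sketch of how a caller might consume the returned `ResolvedInfo`, patterned on the flow in Example #16 (skip when required inputs are missing, then register an execution); `mlmd_connection` and `node` are hypothetical fixtures, not part of the example above:

    def register_node_execution(mlmd_connection, node):
        with mlmd_connection as m:
            resolved = generate_resolved_info(m, node)
            if resolved.input_artifacts is None:
                # Not all required inputs are ready; skip registration.
                return None
            # Register an execution bound to the resolved contexts, inputs,
            # and execution properties (mirrors Example #16, step 4).
            return execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=node.node_info.type,
                contexts=resolved.contexts,
                input_artifacts=resolved.input_artifacts,
                exec_properties=resolved.exec_properties)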
Example #3
    def testRegisterContexts(self):
        node_contexts = pipeline_pb2.NodeContexts()
        self.load_proto_from_text(
            os.path.join(self._testdata_dir, 'node_context_spec.pbtxt'),
            node_contexts)
        with metadata.Metadata(connection_config=self._connection_config) as m:
            context_lib.prepare_contexts(metadata_handler=m,
                                         node_contexts=node_contexts)
            # Duplicated call should succeed.
            contexts = context_lib.prepare_contexts(
                metadata_handler=m, node_contexts=node_contexts)

            got_context_type_one = m.store.get_context_type(
                'my_context_type_one')
            got_context_type_one.ClearField('id')
            self.assertProtoEquals(
                """
          name: 'my_context_type_one'
          """, got_context_type_one)
            got_context_type_two = m.store.get_context_type(
                'my_context_type_two')
            got_context_type_two.ClearField('id')

            self.assertProtoEquals(
                """
          name: 'my_context_type_two'
          """, got_context_type_two)
            self.assertEqual(
                contexts[0],
                m.store.get_context_by_type_and_name('my_context_type_one',
                                                     'my_context_one'))
            self.assertEqual(
                contexts[1],
                m.store.get_context_by_type_and_name('my_context_type_one',
                                                     'my_context_two'))
            self.assertEqual(
                contexts[2],
                m.store.get_context_by_type_and_name('my_context_type_two',
                                                     'my_context_three'))
            self.assertEqual(
                contexts[0].custom_properties['property_a'].int_value, 1)
            self.assertEqual(
                contexts[1].custom_properties['property_a'].int_value, 2)
            self.assertEqual(
                contexts[2].custom_properties['property_a'].int_value, 3)
            self.assertEqual(
                contexts[2].custom_properties['property_b'].string_value, '4')
Example #4
 def fake_execute(self, metadata_handler, pipeline_node, input_map,
                  output_map):
     contexts = context_lib.prepare_contexts(metadata_handler,
                                             pipeline_node.contexts)
     execution = execution_publish_utils.register_execution(
         metadata_handler, pipeline_node.node_info.type, contexts,
         input_map)
     return execution_publish_utils.publish_succeeded_execution(
         metadata_handler, execution.id, contexts, output_map)
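A usage sketch for the helper above, assuming the in-memory sqlite MLMD setup that Examples #14 and #15 use; `pipeline_node` is a hypothetical test fixture, and `None` maps stand in for real artifact dicts:

    connection_config = metadata_store_pb2.ConnectionConfig()
    connection_config.sqlite.SetInParent()  # in-memory MLMD, as in Example #14
    with metadata.Metadata(connection_config=connection_config) as m:
        # Registers contexts plus an execution for the node, then publishes
        # the execution as succeeded.
        self.fake_execute(m, pipeline_node, input_map=None, output_map=None)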
Example #5
def fake_cached_execution(mlmd_connection, cache_context, component):
    """Writes cached execution; MLMD must have previous execution associated with cache_context."""
    with mlmd_connection as m:
        cached_outputs = cache_utils.get_cached_outputs(
            m, cache_context=cache_context)
        contexts = context_lib.prepare_contexts(m, component.contexts)
        execution = execution_publish_utils.register_execution(
            m, component.node_info.type, contexts)
        execution_publish_utils.publish_cached_execution(
            m,
            contexts=contexts,
            execution_id=execution.id,
            output_artifacts=cached_outputs)
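The docstring's precondition, that MLMD already holds a previous execution associated with `cache_context`, might be satisfied along the lines of Example #16, where the cache context is appended to the node's contexts before a successful execution is published. A condensed sketch; `component` and `output_map` are hypothetical arguments:

    def seed_cache(m, component, cache_context, output_map):
        # Associate a successful execution with cache_context so that a later
        # cache_utils.get_cached_outputs() call can find its outputs.
        contexts = context_lib.prepare_contexts(m, component.contexts)
        contexts.append(cache_context)
        execution = execution_publish_utils.register_execution(
            m, component.node_info.type, contexts)
        execution_publish_utils.publish_succeeded_execution(
            m, execution.id, contexts, output_map)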
Example #6
    def run(
        self, mlmd_connection: metadata.Metadata,
        pipeline_node: pipeline_pb2.PipelineNode,
        pipeline_info: pipeline_pb2.PipelineInfo,
        pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
    ) -> data_types.ExecutionInfo:
        """Runs Resolver specific logic.

    Args:
      mlmd_connection: ML metadata connection.
      pipeline_node: The specification of the node that this launcher launches.
      pipeline_info: The information of the pipeline that this node runs in.
      pipeline_runtime_spec: The runtime information of the pipeline that this
        node runs in.

    Returns:
      The execution of the run.
    """
        logging.info('Running as a resolver node.')
        with mlmd_connection as m:
            # 1. Prepares all contexts.
            contexts = context_lib.prepare_contexts(
                metadata_handler=m, node_contexts=pipeline_node.contexts)

            # 2. Resolves inputs and execution properties.
            exec_properties = inputs_utils.resolve_parameters(
                node_parameters=pipeline_node.parameters)
            input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=pipeline_node.inputs)

            # 3. Registers execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=pipeline_node.node_info.type,
                contexts=contexts,
                exec_properties=exec_properties)

            # 4. Publishes the execution as a cached execution with the
            # resolved input artifacts as the output artifacts.
            execution_publish_utils.publish_internal_execution(
                metadata_handler=m,
                contexts=contexts,
                execution_id=execution.id,
                output_artifacts=input_artifacts)

            return data_types.ExecutionInfo(execution_id=execution.id,
                                            input_dict=input_artifacts,
                                            output_dict=input_artifacts,
                                            exec_properties=exec_properties,
                                            pipeline_node=pipeline_node,
                                            pipeline_info=pipeline_info)
Example #7
def fake_trainer_output(mlmd_connection, trainer, execution=None):
  """Writes fake trainer output and execution to MLMD."""
  with mlmd_connection as m:
    output_trainer_model = types.Artifact(
        trainer.outputs.outputs['model'].artifact_spec.type)
    output_trainer_model.uri = 'my_trainer_model_uri'
    contexts = context_lib.prepare_contexts(m, trainer.contexts)
    if not execution:
      execution = execution_publish_utils.register_execution(
          m, trainer.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'model': [output_trainer_model],
        })
Example #8
def fake_example_gen_run_with_handle(mlmd_handle, example_gen, span, version):
  """Writes fake example_gen output and successful execution to MLMD."""
  output_example = types.Artifact(
      example_gen.outputs.outputs['output_examples'].artifact_spec.type)
  output_example.set_int_custom_property('span', span)
  output_example.set_int_custom_property('version', version)
  output_example.uri = 'my_examples_uri'
  contexts = context_lib.prepare_contexts(mlmd_handle, example_gen.contexts)
  execution = execution_publish_utils.register_execution(
      mlmd_handle, example_gen.node_info.type, contexts)
  execution_publish_utils.publish_succeeded_execution(
      mlmd_handle, execution.id, contexts, {
          'output_examples': [output_example],
      })
Example #9
def fake_transform_output(mlmd_connection, transform, execution=None):
  """Writes fake transform output and execution to MLMD."""
  with mlmd_connection as m:
    output_transform_graph = types.Artifact(
        transform.outputs.outputs['transform_graph'].artifact_spec.type)
    output_transform_graph.uri = 'my_transform_graph_uri'
    contexts = context_lib.prepare_contexts(m, transform.contexts)
    if not execution:
      execution = execution_publish_utils.register_execution(
          m, transform.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'transform_graph': [output_transform_graph],
        })
Example #10
def fake_component_output_with_handle(mlmd_handle,
                                      component,
                                      execution=None,
                                      active=False):
    """Writes fake component output and execution to MLMD."""
    output_key, output_value = next(iter(component.outputs.outputs.items()))
    output = types.Artifact(output_value.artifact_spec.type)
    output.uri = str(uuid.uuid4())
    contexts = context_lib.prepare_contexts(mlmd_handle, component.contexts)
    if not execution:
        execution = execution_publish_utils.register_execution(
            mlmd_handle, component.node_info.type, contexts)
    if not active:
        execution_publish_utils.publish_succeeded_execution(
            mlmd_handle, execution.id, contexts, {output_key: [output]})
Example #11
def generate_resolved_info(
        metadata_handler: metadata.Metadata,
        node: pipeline_pb2.PipelineNode) -> Optional[ResolvedInfo]:
    """Returns a `ResolvedInfo` object for executing the node or `None` to skip.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to generate resolved info.

  Returns:
    A `ResolvedInfo` with input resolutions or `None` if execution should be
    skipped.

  Raises:
    NotImplementedError: If inputs_utils.resolve_input_artifacts_v2 returns
      multiple dicts, which is currently not supported.
  """
    # Register node contexts.
    contexts = context_lib.prepare_contexts(metadata_handler=metadata_handler,
                                            node_contexts=node.contexts)

    # Resolve execution properties.
    exec_properties = data_types_utils.build_parsed_value_dict(
        inputs_utils.resolve_parameters_with_schema(
            node_parameters=node.parameters))

    # Resolve inputs.
    try:
        resolved_input_artifacts = inputs_utils.resolve_input_artifacts_v2(
            metadata_handler=metadata_handler, pipeline_node=node)
    except exceptions.InputResolutionError as e:
        logging.warning(
            'Input resolution error raised for node: %s; error: %s',
            node.node_info.id, e)
        resolved_input_artifacts = None
    else:
        if isinstance(resolved_input_artifacts, inputs_utils.Skip):
            return None
        assert isinstance(resolved_input_artifacts, inputs_utils.Trigger)
        assert resolved_input_artifacts
        # TODO(b/197741942): Support multiple dicts.
        if len(resolved_input_artifacts) > 1:
            raise NotImplementedError(
                'Handling more than one input dict is not implemented.')

    return ResolvedInfo(contexts=contexts,
                        exec_properties=exec_properties,
                        input_artifacts=resolved_input_artifacts)
Example #12
 def fakeExampleGenOutput(mlmd_connection: metadata.Metadata,
                          example_gen: pipeline_pb2.PipelineNode, span: int,
                          version: int):
   with mlmd_connection as m:
     output_example = types.Artifact(
         example_gen.outputs.outputs['output_examples'].artifact_spec.type)
     output_example.set_int_custom_property('span', span)
     output_example.set_int_custom_property('version', version)
     output_example.uri = 'my_examples_uri'
     contexts = context_lib.prepare_contexts(m, example_gen.contexts)
     execution = execution_publish_utils.register_execution(
         m, example_gen.node_info.type, contexts)
     execution_publish_utils.publish_succeeded_execution(
         m, execution.id, contexts, {
             'output_examples': [output_example],
         })
Example #13
 def _generate_contexts(self, metadata_handler):
     context_spec = pipeline_pb2.NodeContexts()
     text_format.Parse(
         """
     contexts {
       type {name: 'pipeline_context'}
       name {
         field_value {string_value: 'my_pipeline'}
       }
     }
     contexts {
       type {name: 'component_context'}
       name {
         field_value {string_value: 'my_component'}
       }
     }""", context_spec)
     return context_lib.prepare_contexts(metadata_handler, context_spec)
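Putting the recurring pieces together: a minimal, self-contained sketch that combines the inline-pbtxt pattern above with the in-memory sqlite connection from Examples #14 and #15. The import paths are assumptions that follow the TFX source tree, and the context type/name are hypothetical:

    from google.protobuf import text_format
    from ml_metadata.proto import metadata_store_pb2
    from tfx.orchestration import metadata
    from tfx.orchestration.portable.mlmd import context_lib
    from tfx.proto.orchestration import pipeline_pb2

    def prepare_contexts_standalone():
        # Build a NodeContexts spec with one (hypothetical) pipeline context.
        node_contexts = pipeline_pb2.NodeContexts()
        text_format.Parse(
            """
        contexts {
          type {name: 'pipeline_context'}
          name {
            field_value {string_value: 'my_pipeline'}
          }
        }""", node_contexts)
        # In-memory sqlite MLMD, as in Examples #14 and #15.
        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # Registers the context types and contexts (idempotently) and
            # returns the registered MLMD Context protos.
            return context_lib.prepare_contexts(m, node_contexts)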
Example #14
    def testResolverWithResolverPolicy(self):
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(self._testdata_dir,
                         'pipeline_for_input_resolver_test.pbtxt'), pipeline)
        my_example_gen = pipeline.nodes[0].pipeline_node
        my_transform = pipeline.nodes[2].pipeline_node

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # Publishes first ExampleGen with two output channels. `output_examples`
            # will be consumed by downstream Transform.
            output_example_1 = types.Artifact(
                my_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            output_example_1.uri = 'my_examples_uri_1'

            output_example_2 = types.Artifact(
                my_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            output_example_2.uri = 'my_examples_uri_2'

            contexts = context_lib.prepare_contexts(m, my_example_gen.contexts)
            execution = execution_publish_utils.register_execution(
                m, my_example_gen.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'output_examples': [output_example_1, output_example_2],
                })

            my_transform.inputs.resolver_config.resolver_policy = (
                pipeline_pb2.ResolverConfig.LATEST_ARTIFACT)

            # Gets inputs for transform. Should get back what the first ExampleGen
            # published in the `output_examples` channel.
            transform_inputs = inputs_utils.resolve_input_artifacts(
                m, my_transform.inputs)
            self.assertEqual(len(transform_inputs), 1)
            self.assertEqual(len(transform_inputs['examples']), 1)
            self.assertProtoPartiallyEquals(
                transform_inputs['examples'][0].mlmd_artifact,
                output_example_2.mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
Example #15
    def testResolveInputArtifactsOutputKeyUnset(self):
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(
                self._testdata_dir,
                'pipeline_for_input_resolver_test_output_key_unset.pbtxt'),
            pipeline)
        my_trainer = pipeline.nodes[0].pipeline_node
        my_pusher = pipeline.nodes[1].pipeline_node

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # Publishes Trainer with one output channel. `output_model`
            # will be consumed by the Pusher in a different run.
            output_model = types.Artifact(
                my_trainer.outputs.outputs['model'].artifact_spec.type)
            output_model.uri = 'my_output_model_uri'
            contexts = context_lib.prepare_contexts(m, my_trainer.contexts)
            execution = execution_publish_utils.register_execution(
                m, my_trainer.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'model': [output_model],
                })
            # Gets inputs for pusher. Should get back the model the Trainer
            # published in the `model` channel.
            pusher_inputs = inputs_utils.resolve_input_artifacts(
                m, my_pusher.inputs)
            self.assertEqual(len(pusher_inputs), 1)
            self.assertEqual(len(pusher_inputs['model']), 1)
            self.assertProtoPartiallyEquals(
                output_model.mlmd_artifact,
                pusher_inputs['model'][0].mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
Example #16
    def _prepare_execution(self) -> _PrepareExecutionResult:
        """Prepares inputs, outputs and execution properties for actual execution."""
        # TODO(b/150979622): handle the edge case where the component gets
        # evicted between a successful publish and the stateful working dir
        # being cleaned up. Otherwise subsequent retries will keep failing
        # because of duplicate publishes.
        with self._mlmd_connection as m:
            # 1. Prepares all contexts.
            contexts = context_lib.prepare_contexts(
                metadata_handler=m, node_contexts=self._pipeline_node.contexts)

            # 2. Resolves inputs and execution properties.
            exec_properties = inputs_utils.resolve_parameters(
                node_parameters=self._pipeline_node.parameters)
            input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=self._pipeline_node.inputs)
            # 3. If not all required inputs are met, return ExecutionInfo with
            # is_execution_needed set to false. No publish will happen, so
            # downstream nodes won't be triggered.
            if input_artifacts is None:
                logging.info(
                    'Not all required inputs are ready, abandoning execution.')
                return _PrepareExecutionResult(
                    execution_info=data_types.ExecutionInfo(),
                    contexts=contexts,
                    is_execution_needed=False)

            # 4. Registers execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=self._pipeline_node.node_info.type,
                contexts=contexts,
                input_artifacts=input_artifacts,
                exec_properties=exec_properties)

            # 5. Resolves output artifacts.
            output_artifacts = self._output_resolver.generate_output_artifacts(
                execution.id)

        # If there is a custom driver, runs it.
        if self._driver_operator:
            driver_output = self._driver_operator.run_driver(
                data_types.ExecutionInfo(
                    input_dict=input_artifacts,
                    output_dict=output_artifacts,
                    exec_properties=exec_properties,
                    execution_output_uri=self._output_resolver.
                    get_driver_output_uri()))
            self._update_with_driver_output(driver_output, exec_properties,
                                            output_artifacts)

        # We reconnect to MLMD here because the custom driver closes the MLMD
        # connection on returning.
        with self._mlmd_connection as m:
            # 6. Check cached result
            cache_context = cache_utils.get_cache_context(
                metadata_handler=m,
                pipeline_node=self._pipeline_node,
                pipeline_info=self._pipeline_info,
                executor_spec=self._executor_spec,
                input_artifacts=input_artifacts,
                output_artifacts=output_artifacts,
                parameters=exec_properties)
            contexts.append(cache_context)
            cached_outputs = cache_utils.get_cached_outputs(
                metadata_handler=m, cache_context=cache_context)

            # 7. Should cache be used?
            if (self._pipeline_node.execution_options.caching_options.
                    enable_cache and cached_outputs):
                # Publishes cached result.
                execution_publish_utils.publish_cached_execution(
                    metadata_handler=m,
                    contexts=contexts,
                    execution_id=execution.id,
                    output_artifacts=cached_outputs)
                logging.info('A cached execution %d is used.', execution.id)
                return _PrepareExecutionResult(
                    execution_info=data_types.ExecutionInfo(
                        execution_id=execution.id),
                    execution_metadata=execution,
                    contexts=contexts,
                    is_execution_needed=False)

            pipeline_run_id = (self._pipeline_runtime_spec.pipeline_run_id.
                               field_value.string_value)

            # 8. Going to trigger executor.
            logging.info('Going to run a new execution %d', execution.id)
            return _PrepareExecutionResult(
                execution_info=data_types.ExecutionInfo(
                    execution_id=execution.id,
                    input_dict=input_artifacts,
                    output_dict=output_artifacts,
                    exec_properties=exec_properties,
                    execution_output_uri=self._output_resolver.
                    get_executor_output_uri(execution.id),
                    stateful_working_dir=(self._output_resolver.
                                          get_stateful_working_directory()),
                    tmp_dir=self._output_resolver.make_tmp_dir(execution.id),
                    pipeline_node=self._pipeline_node,
                    pipeline_info=self._pipeline_info,
                    pipeline_run_id=pipeline_run_id),
                execution_metadata=execution,
                contexts=contexts,
                is_execution_needed=True)
Example #17
    def testLauncher_ReEntry(self):
        # Some executors or runtime environments may reschedule the launcher job
        # before the launcher job can publish any results of the execution to MLMD.
        # The launcher should reuse the previous execution and proceed to a
        # successful execution.
        self.reloadPipelineWithNewRunId()
        LauncherTest.fakeUpstreamOutputs(self._mlmd_connection,
                                         self._example_gen, self._transform)

        def create_test_launcher(executor_operators):
            return launcher.Launcher(
                pipeline_node=self._trainer,
                mlmd_connection=self._mlmd_connection,
                pipeline_info=self._pipeline_info,
                pipeline_runtime_spec=self._pipeline_runtime_spec,
                executor_spec=self._trainer_executor_spec,
                custom_executor_operators=executor_operators)

        test_launcher = create_test_launcher(
            {_PYTHON_CLASS_EXECUTABLE_SPEC: _FakeCrashingExecutorOperator})

        # The first launch simulates the launcher being restarted by preventing the
        # publishing of any results to MLMD.
        with contextlib.ExitStack() as stack:
            stack.enter_context(
                mock.patch.object(test_launcher, '_publish_failed_execution'))
            stack.enter_context(
                mock.patch.object(test_launcher,
                                  '_clean_up_stateless_execution_info'))
            stack.enter_context(self.assertRaises(FakeError))
            test_launcher.launch()

        # Retrieve execution from the first launch, which should be in RUNNING
        # state.
        with self._mlmd_connection as m:
            contexts = context_lib.prepare_contexts(
                metadata_handler=m,
                node_contexts=test_launcher._pipeline_node.contexts)
            exec_properties = data_types_utils.build_parsed_value_dict(
                inputs_utils.resolve_parameters_with_schema(
                    node_parameters=test_launcher._pipeline_node.parameters))
            input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m,
                node_inputs=test_launcher._pipeline_node.inputs)
            first_execution = test_launcher._register_or_reuse_execution(
                metadata_handler=m,
                contexts=contexts,
                input_artifacts=input_artifacts,
                exec_properties=exec_properties)
            self.assertEqual(first_execution.last_known_state,
                             metadata_store_pb2.Execution.RUNNING)

        # Create a second test launcher. It should reuse the previous execution.
        second_test_launcher = create_test_launcher(
            {_PYTHON_CLASS_EXECUTABLE_SPEC: _FakeExecutorOperator})
        execution_info = second_test_launcher.launch()

        with self._mlmd_connection as m:
            self.assertEqual(first_execution.id, execution_info.execution_id)
            executions = m.store.get_executions_by_id(
                [execution_info.execution_id])
            self.assertLen(executions, 1)
            self.assertEqual(executions.pop().last_known_state,
                             metadata_store_pb2.Execution.COMPLETE)

        # Create a third test launcher. It should not require an execution.
        third_test_launcher = create_test_launcher(
            {_PYTHON_CLASS_EXECUTABLE_SPEC: _FakeExecutorOperator})
        execution_preparation_result = third_test_launcher._prepare_execution()
        self.assertFalse(execution_preparation_result.is_execution_needed)
Example #18
    def test_resolver_task_scheduler(self):
        with self._mlmd_connection as m:
            # Publishes two models which will be consumed by downstream resolver.
            output_model_1 = types.Artifact(
                self._trainer.outputs.outputs['model'].artifact_spec.type)
            output_model_1.uri = 'my_model_uri_1'

            output_model_2 = types.Artifact(
                self._trainer.outputs.outputs['model'].artifact_spec.type)
            output_model_2.uri = 'my_model_uri_2'

            contexts = context_lib.prepare_contexts(m, self._trainer.contexts)
            execution = execution_publish_utils.register_execution(
                m, self._trainer.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'model': [output_model_1, output_model_2],
                })

        task_queue = tq.TaskQueue()

        # Verify that resolver task is generated.
        [resolver_task] = test_utils.run_generator_and_test(
            test_case=self,
            mlmd_connection=self._mlmd_connection,
            generator_class=sptg.SyncPipelineTaskGenerator,
            pipeline=self._pipeline,
            task_queue=task_queue,
            use_task_queue=False,
            service_job_manager=None,
            num_initial_executions=1,
            num_tasks_generated=1,
            num_new_executions=1,
            num_active_executions=1,
            expected_exec_nodes=[self._resolver_node],
            ignore_update_node_state_tasks=True)

        with self._mlmd_connection as m:
            # Run resolver task scheduler and publish results.
            ts_result = resolver_task_scheduler.ResolverTaskScheduler(
                mlmd_handle=m, pipeline=self._pipeline,
                task=resolver_task).schedule()
            self.assertEqual(status_lib.Code.OK, ts_result.status.code)
            self.assertIsInstance(ts_result.output,
                                  task_scheduler.ResolverNodeOutput)
            self.assertCountEqual(
                ['resolved_model'],
                ts_result.output.resolved_input_artifacts.keys())
            models = ts_result.output.resolved_input_artifacts[
                'resolved_model']
            self.assertLen(models, 1)
            self.assertEqual('my_model_uri_2', models[0].mlmd_artifact.uri)
            tm._publish_execution_results(m, resolver_task, ts_result)

        # Verify resolver node output is input to the downstream consumer node.
        [consumer_task] = test_utils.run_generator_and_test(
            test_case=self,
            mlmd_connection=self._mlmd_connection,
            generator_class=sptg.SyncPipelineTaskGenerator,
            pipeline=self._pipeline,
            task_queue=task_queue,
            use_task_queue=False,
            service_job_manager=None,
            num_initial_executions=2,
            num_tasks_generated=1,
            num_new_executions=1,
            num_active_executions=1,
            expected_exec_nodes=[self._consumer_node],
            ignore_update_node_state_tasks=True)
        self.assertCountEqual(['resolved_model'],
                              consumer_task.input_artifacts.keys())
        input_models = consumer_task.input_artifacts['resolved_model']
        self.assertLen(input_models, 1)
        self.assertEqual('my_model_uri_2', input_models[0].mlmd_artifact.uri)
Example #19
    def testSuccess(self):
        with self._mlmd_connection as m:
            # Publishes two models which will be consumed by downstream resolver.
            output_model_1 = types.Artifact(
                self._my_trainer.outputs.outputs['model'].artifact_spec.type)
            output_model_1.uri = 'my_model_uri_1'

            output_model_2 = types.Artifact(
                self._my_trainer.outputs.outputs['model'].artifact_spec.type)
            output_model_2.uri = 'my_model_uri_2'

            contexts = context_lib.prepare_contexts(m,
                                                    self._my_trainer.contexts)
            execution = execution_publish_utils.register_execution(
                m, self._my_trainer.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'model': [output_model_1, output_model_2],
                })

        handler = resolver_node_handler.ResolverNodeHandler()
        execution_metadata = handler.run(
            mlmd_connection=self._mlmd_connection,
            pipeline_node=self._resolver_node,
            pipeline_info=self._pipeline_info,
            pipeline_runtime_spec=self._pipeline_runtime_spec)

        with self._mlmd_connection as m:
            # There is no way to directly verify the output artifact of the resolver,
            # so here a fake downstream component is created which listens to the
            # resolver's output, and we verify its input.
            down_stream_node = text_format.Parse(
                """
        inputs {
          inputs {
            key: "input_models"
            value {
              channels {
                producer_node_query {
                  id: "my_resolver"
                }
                context_queries {
                  type {
                    name: "pipeline"
                  }
                  name {
                    field_value {
                      string_value: "my_pipeline"
                    }
                  }
                }
                context_queries {
                  type {
                    name: "component"
                  }
                  name {
                    field_value {
                      string_value: "my_resolver"
                    }
                  }
                }
                artifact_query {
                  type {
                    name: "Model"
                  }
                }
                output_key: "models"
              }
              min_count: 1
            }
          }
        }
        upstream_nodes: "my_resolver"
        """, pipeline_pb2.PipelineNode())
            downstream_input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=down_stream_node.inputs)
            downstream_input_model = downstream_input_artifacts['input_models']
            self.assertLen(downstream_input_model, 1)
            self.assertProtoPartiallyEquals(
                """
          id: 2
          type_id: 5
          uri: "my_model_uri_2"
          state: LIVE""",
                downstream_input_model[0].mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
            [execution] = m.store.get_executions_by_id([execution_metadata.id])

            self.assertProtoPartiallyEquals("""
          id: 2
          type_id: 6
          last_known_state: COMPLETE
          """,
                                            execution,
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])
Example #20
    def testResolverInputsArtifacts(self):
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(self._testdata_dir,
                         'pipeline_for_input_resolver_test.pbtxt'), pipeline)
        my_example_gen = pipeline.nodes[0].pipeline_node
        another_example_gen = pipeline.nodes[1].pipeline_node
        my_transform = pipeline.nodes[2].pipeline_node
        my_trainer = pipeline.nodes[3].pipeline_node

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # Publishes first ExampleGen with two output channels. `output_examples`
            # will be consumed by downstream Transform.
            output_example = types.Artifact(
                my_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            output_example.uri = 'my_examples_uri'
            side_examples = types.Artifact(
                my_example_gen.outputs.outputs['side_examples'].artifact_spec.
                type)
            side_examples.uri = 'side_examples_uri'
            contexts = context_lib.prepare_contexts(m, my_example_gen.contexts)
            execution = execution_publish_utils.register_execution(
                m, my_example_gen.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'output_examples': [output_example],
                    'another_examples': [side_examples]
                })

            # Publishes second ExampleGen with one output channel with the same output
            # key as the first ExampleGen. However, this is not consumed by downstream
            # nodes.
            another_output_example = types.Artifact(
                another_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            another_output_example.uri = 'another_examples_uri'
            contexts = context_lib.prepare_contexts(
                m, another_example_gen.contexts)
            execution = execution_publish_utils.register_execution(
                m, another_example_gen.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'output_examples': [another_output_example],
                })

            # Gets inputs for transform. Should get back what the first ExampleGen
            # published in the `output_examples` channel.
            transform_inputs = inputs_utils.resolve_input_artifacts(
                m, my_transform.inputs)
            self.assertEqual(len(transform_inputs), 1)
            self.assertEqual(len(transform_inputs['examples']), 1)
            self.assertProtoPartiallyEquals(
                transform_inputs['examples'][0].mlmd_artifact,
                output_example.mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])

            # Tries to resolve inputs for trainer. As trainer also requires min_count
            # for both input channels (from example_gen and from transform) but we did
            # not publish anything from transform, it should return nothing.
            self.assertIsNone(
                inputs_utils.resolve_input_artifacts(m, my_trainer.inputs))
Example #21
    def run(
        self, mlmd_connection: metadata.Metadata,
        pipeline_node: pipeline_pb2.PipelineNode,
        pipeline_info: pipeline_pb2.PipelineInfo,
        pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
    ) -> data_types.ExecutionInfo:
        """Runs Resolver specific logic.

    Args:
      mlmd_connection: ML metadata connection.
      pipeline_node: The specification of the node that this launcher launches.
      pipeline_info: The information of the pipeline that this node runs in.
      pipeline_runtime_spec: The runtime information of the pipeline that this
        node runs in.

    Returns:
      The execution of the run.
    """
        logging.info('Running as a resolver node.')
        with mlmd_connection as m:
            # 1. Prepares all contexts.
            contexts = context_lib.prepare_contexts(
                metadata_handler=m, node_contexts=pipeline_node.contexts)

            # 2. Resolves inputs and execution properties.
            exec_properties = data_types_utils.build_parsed_value_dict(
                inputs_utils.resolve_parameters_with_schema(
                    node_parameters=pipeline_node.parameters))
            try:
                resolved_inputs = inputs_utils.resolve_input_artifacts_v2(
                    pipeline_node=pipeline_node, metadata_handler=m)
            except exceptions.InputResolutionError as e:
                execution = execution_publish_utils.register_execution(
                    metadata_handler=m,
                    execution_type=pipeline_node.node_info.type,
                    contexts=contexts,
                    exec_properties=exec_properties)
                execution_publish_utils.publish_failed_execution(
                    metadata_handler=m,
                    contexts=contexts,
                    execution_id=execution.id,
                    executor_output=self._build_error_output(
                        code=e.grpc_code_value))
                return data_types.ExecutionInfo(
                    execution_id=execution.id,
                    exec_properties=exec_properties,
                    pipeline_node=pipeline_node,
                    pipeline_info=pipeline_info)

            # 2a. If Skip (i.e., inside a conditional), no execution should be made.
            # TODO(b/197907821): Publish special execution for Skip?
            if isinstance(resolved_inputs, inputs_utils.Skip):
                return data_types.ExecutionInfo()

            # 3. Registers execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=pipeline_node.node_info.type,
                contexts=contexts,
                exec_properties=exec_properties)

            # TODO(b/197741942): Support len > 1.
            if len(resolved_inputs) > 1:
                execution_publish_utils.publish_failed_execution(
                    metadata_handler=m,
                    contexts=contexts,
                    execution_id=execution.id,
                    executor_output=self._build_error_output(
                        _ERROR_CODE_UNIMPLEMENTED,
                        'Handling more than one input dict is not implemented yet.'
                    ))
                return data_types.ExecutionInfo(
                    execution_id=execution.id,
                    exec_properties=exec_properties,
                    pipeline_node=pipeline_node,
                    pipeline_info=pipeline_info)

            input_artifacts = resolved_inputs[0]

            # 4. Publishes the execution as a cached execution with the
            # resolved input artifacts as the output artifacts.
            execution_publish_utils.publish_internal_execution(
                metadata_handler=m,
                contexts=contexts,
                execution_id=execution.id,
                output_artifacts=input_artifacts)

            return data_types.ExecutionInfo(execution_id=execution.id,
                                            input_dict=input_artifacts,
                                            output_dict=input_artifacts,
                                            exec_properties=exec_properties,
                                            pipeline_node=pipeline_node,
                                            pipeline_info=pipeline_info)
Example #22
    def run(
        self, mlmd_connection: metadata.Metadata,
        pipeline_node: pipeline_pb2.PipelineNode,
        pipeline_info: pipeline_pb2.PipelineInfo,
        pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
    ) -> data_types.ExecutionInfo:
        """Runs Importer specific logic.

    Args:
      mlmd_connection: ML metadata connection.
      pipeline_node: The specification of the node that this launcher launches.
      pipeline_info: The information of the pipeline that this node runs in.
      pipeline_runtime_spec: The runtime information of the pipeline that this
        node runs in.

    Returns:
      The execution of the run.
    """
        logging.info('Running as an importer node.')
        with mlmd_connection as m:
            # 1. Prepares all contexts.
            contexts = context_lib.prepare_contexts(
                metadata_handler=m, node_contexts=pipeline_node.contexts)

            # 2. Resolves execution properties; note that importers have no
            # inputs.
            exec_properties = data_types_utils.build_parsed_value_dict(
                inputs_utils.resolve_parameters_with_schema(
                    node_parameters=pipeline_node.parameters))

            # 3. Registers execution in metadata.
            execution = execution_publish_utils.register_execution(
                metadata_handler=m,
                execution_type=pipeline_node.node_info.type,
                contexts=contexts,
                exec_properties=exec_properties)

            # 4. Generate output artifacts to represent the imported artifacts.
            output_spec = pipeline_node.outputs.outputs[
                importer.IMPORT_RESULT_KEY]
            properties = self._extract_proto_map(
                output_spec.artifact_spec.additional_properties)
            custom_properties = self._extract_proto_map(
                output_spec.artifact_spec.additional_custom_properties)
            output_artifact_class = types.Artifact(
                output_spec.artifact_spec.type).type
            output_artifacts = importer.generate_output_dict(
                metadata_handler=m,
                uri=str(exec_properties[importer.SOURCE_URI_KEY]),
                properties=properties,
                custom_properties=custom_properties,
                reimport=bool(exec_properties[importer.REIMPORT_OPTION_KEY]),
                output_artifact_class=output_artifact_class,
                mlmd_artifact_type=output_spec.artifact_spec.type)

            result = data_types.ExecutionInfo(execution_id=execution.id,
                                              input_dict={},
                                              output_dict=output_artifacts,
                                              exec_properties=exec_properties,
                                              pipeline_node=pipeline_node,
                                              pipeline_info=pipeline_info)

            # TODO(b/182316162): consider letting the launcher level do the
            # publish for system nodes, so that the version tagging logic
            # doesn't need to be handled per system node.
            outputs_utils.tag_output_artifacts_with_version(result.output_dict)

            # 5. Publish the output artifacts. If artifacts are reimported, the
            # execution is published as CACHED. Otherwise it is published as COMPLETE.
            if _is_artifact_reimported(output_artifacts):
                execution_publish_utils.publish_cached_execution(
                    metadata_handler=m,
                    contexts=contexts,
                    execution_id=execution.id,
                    output_artifacts=output_artifacts)

            else:
                execution_publish_utils.publish_succeeded_execution(
                    metadata_handler=m,
                    execution_id=execution.id,
                    contexts=contexts,
                    output_artifacts=output_artifacts)

            return result
Example #23
    def testRegisterContexts(self):
        node_contexts = pipeline_pb2.NodeContexts()
        self.load_proto_from_text(
            os.path.join(self._testdata_dir, 'node_context_spec.pbtxt'),
            node_contexts)
        with metadata.Metadata(connection_config=self._connection_config) as m:
            context_lib.prepare_contexts(metadata_handler=m,
                                         node_contexts=node_contexts)
            # Duplicated call should succeed.
            contexts = context_lib.prepare_contexts(
                metadata_handler=m, node_contexts=node_contexts)

            self.assertProtoEquals(
                """
          id: 1
          name: 'my_context_type_one'
          """, m.store.get_context_type('my_context_type_one'))
            self.assertProtoEquals(
                """
          id: 2
          name: 'my_context_type_two'
          """, m.store.get_context_type('my_context_type_two'))
            self.assertProtoEquals(
                """
          type_id: 1
          name: "my_context_one"
          custom_properties {
            key: "property_a"
            value {
              int_value: 1
            }
          }
          """, contexts[0])
            self.assertProtoEquals(
                """
          type_id: 1
          name: "my_context_two"
          custom_properties {
            key: "property_a"
            value {
              int_value: 2
            }
          }
          """, contexts[1])
            self.assertProtoEquals(
                """
          type_id: 2
          name: "my_context_three"
          custom_properties {
            key: "property_a"
            value {
              int_value: 3
            }
          }
          custom_properties {
            key: "property_b"
            value {
              string_value: '4'
            }
          }
          """, contexts[2])
            self.assertEqual(
                contexts[0].custom_properties['property_a'].int_value, 1)
            self.assertEqual(
                contexts[1].custom_properties['property_a'].int_value, 2)
            self.assertEqual(
                contexts[2].custom_properties['property_a'].int_value, 3)
            self.assertEqual(
                contexts[2].custom_properties['property_b'].string_value, '4')
Example #24
  def run(
      self, mlmd_connection: metadata.Metadata,
      pipeline_node: pipeline_pb2.PipelineNode,
      pipeline_info: pipeline_pb2.PipelineInfo,
      pipeline_runtime_spec: pipeline_pb2.PipelineRuntimeSpec
  ) -> data_types.ExecutionInfo:
    """Runs Importer specific logic.

    Args:
      mlmd_connection: ML metadata connection.
      pipeline_node: The specification of the node that this launcher launches.
      pipeline_info: The information of the pipeline that this node runs in.
      pipeline_runtime_spec: The runtime information of the pipeline that this
        node runs in.

    Returns:
      The execution of the run.
    """
    logging.info('Running as an importer node.')
    with mlmd_connection as m:
      # 1. Prepares all contexts.
      contexts = context_lib.prepare_contexts(
          metadata_handler=m, node_contexts=pipeline_node.contexts)

      # 2. Resolves execution properties; note that importers have no
      # inputs.
      exec_properties = inputs_utils.resolve_parameters(
          node_parameters=pipeline_node.parameters)

      # 3. Registers execution in metadata.
      execution = execution_publish_utils.register_execution(
          metadata_handler=m,
          execution_type=pipeline_node.node_info.type,
          contexts=contexts,
          exec_properties=exec_properties)

      # 4. Generate output artifacts to represent the imported artifacts.
      output_spec = pipeline_node.outputs.outputs[importer.IMPORT_RESULT_KEY]
      properties = self._extract_proto_map(
          output_spec.artifact_spec.additional_properties)
      custom_properties = self._extract_proto_map(
          output_spec.artifact_spec.additional_custom_properties)
      output_artifact_class = types.Artifact(
          output_spec.artifact_spec.type).type
      output_artifacts = importer.generate_output_dict(
          metadata_handler=m,
          uri=str(exec_properties[importer.SOURCE_URI_KEY]),
          properties=properties,
          custom_properties=custom_properties,
          reimport=bool(exec_properties[importer.REIMPORT_OPTION_KEY]),
          output_artifact_class=output_artifact_class,
          mlmd_artifact_type=output_spec.artifact_spec.type)

      # 5. Publish the output artifacts.
      execution_publish_utils.publish_succeeded_execution(
          metadata_handler=m,
          execution_id=execution.id,
          contexts=contexts,
          output_artifacts=output_artifacts)

      return data_types.ExecutionInfo(
          execution_id=execution.id,
          input_dict={},
          output_dict=output_artifacts,
          exec_properties=exec_properties,
          pipeline_node=pipeline_node,
          pipeline_info=pipeline_info)