Example No. 1
 def run_executor(
     self, execution_info: data_types.ExecutionInfo
 ) -> execution_result_pb2.ExecutorOutput:
     self._exec_properties = execution_info.exec_properties
     # The Launcher expects the ExecutorOperator to return an ExecutorOutput
     # structure with the output artifact information for MLMD publishing.
     output_dict = copy.deepcopy(execution_info.output_dict)
     result = execution_result_pb2.ExecutorOutput()
     for key, artifact_list in output_dict.items():
         artifacts = execution_result_pb2.ExecutorOutput.ArtifactList()
         for artifact in artifact_list:
             artifacts.artifacts.append(artifact.mlmd_artifact)
         result.output_artifacts[key].CopyFrom(artifacts)
     return result
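The `ExecutorOutput` returned above is a plain protobuf, so the launcher can walk its `output_artifacts` map directly before publishing to MLMD. A minimal sketch (the variable `result` stands for the value returned by `run_executor`; names are illustrative):

for key, artifact_list in result.output_artifacts.items():
    for mlmd_artifact in artifact_list.artifacts:
        # Each entry is an ml-metadata Artifact proto ready for publishing.
        print(key, mlmd_artifact.uri)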
Example No. 2
    def run_executor(
        self, execution_info: data_types.ExecutionInfo
    ) -> execution_result_pb2.ExecutorOutput:
        """Invokers executors given input from the Launcher.

    Args:
      execution_info: A wrapper of the details of this execution.

    Returns:
      The output from executor.
    """
        # TODO(b/156000550): We should not specialize `Context` to embed beam
        # pipeline args. Instead, the `Context` should consist of general-purpose
        # `extra_flags` which can be interpreted differently by different
        # implementations of executors.
        context = base_executor.BaseExecutor.Context(
            beam_pipeline_args=self.extra_flags,
            tmp_dir=execution_info.tmp_dir,
            unique_id=str(execution_info.execution_id),
            executor_output_uri=execution_info.execution_output_uri,
            stateful_working_dir=execution_info.stateful_working_dir)
        executor = self._executor_cls(context=context)

        for _, artifact_list in execution_info.input_dict.items():
            for artifact in artifact_list:
                if isinstance(artifact, ValueArtifact):
                    # Read ValueArtifact into memory.
                    artifact.read()

        result = executor.Do(execution_info.input_dict,
                             execution_info.output_dict,
                             execution_info.exec_properties)
        if not result:
            # If result is not returned from the Do function, then try to
            # read it from the executor_output_uri.
            try:
                with fileio.open(execution_info.execution_output_uri,
                                 'rb') as f:
                    result = execution_result_pb2.ExecutorOutput.FromString(
                        f.read())
            except tf.errors.NotFoundError:
                # Old-style TFX executors don't return executor_output; they
                # modify output_dict and exec_properties in place. For backward
                # compatibility, we use their output_dict and exec_properties
                # to construct the ExecutorOutput.
                result = execution_result_pb2.ExecutorOutput()
                _populate_output_artifact(result, execution_info.output_dict)
                _populate_exec_properties(result,
                                          execution_info.exec_properties)
        return result
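Note the three-level fallback above: the value returned by the executor's `Do` wins; failing that, a serialized `ExecutorOutput` is read from `execution_output_uri`; and as a last resort the output is reconstructed from the in-place mutations that old-style executors perform on `output_dict` and `exec_properties`.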
Example No. 3
 def testPopulateOutputArtifact(self):
     executor_output = execution_result_pb2.ExecutorOutput()
     output_dict = {'output_key': [standard_artifacts.Model()]}
     outputs_utils.populate_output_artifact(executor_output, output_dict)
     self.assertProtoEquals(
         """
     output_artifacts {
       key: "output_key"
       value {
         artifacts {
         }
       }
     }
     """, executor_output)
Example No. 4
    def testPublishSuccessExecutionFailChangedType(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            execution_id = execution_publish_utils.register_execution(
                m, self._execution_type, contexts).id
            executor_output = execution_result_pb2.ExecutorOutput()
            executor_output.output_artifacts['examples'].artifacts.add(
            ).type_id = 10

            with self.assertRaisesRegex(RuntimeError, 'change artifact type'):
                execution_publish_utils.publish_succeeded_execution(
                    m, execution_id, contexts,
                    {'examples': [
                        standard_artifacts.Examples(),
                    ]}, executor_output)
Example No. 5
 def testPublishSuccessExecutionUpdatesCustomProperties(self):
     with metadata.Metadata(connection_config=self._connection_config) as m:
         executor_output = text_format.Parse(
             """
       execution_properties {
       key: "int"
       value {
         int_value: 1
       }
       }
       execution_properties {
         key: "string"
         value {
           string_value: "string_value"
         }
       }
        """, execution_result_pb2.ExecutorOutput())
         contexts = self._generate_contexts(m)
         execution_id = execution_publish_utils.register_execution(
             m, self._execution_type, contexts).id
         execution_publish_utils.publish_succeeded_execution(
             m, execution_id, contexts, {}, executor_output)
         [execution] = m.store.get_executions_by_id([execution_id])
         self.assertProtoPartiallyEquals("""
       id: 1
       last_known_state: COMPLETE
       custom_properties {
         key: "int"
         value {
           int_value: 1
         }
       }
       custom_properties {
         key: "string"
         value {
           string_value: "string_value"
         }
       }
       """,
                                         execution,
                                         ignored_fields=[
                                             'type_id',
                                             'create_time_since_epoch',
                                             'last_update_time_since_epoch'
                                         ])
Example No. 6
 def run_executor(
     self, execution_info: data_types.ExecutionInfo
 ) -> execution_result_pb2.ExecutorOutput:
     output_dict = copy.deepcopy(execution_info.output_dict)
     result = execution_result_pb2.ExecutorOutput()
     for key, artifact_list in output_dict.items():
         artifacts = execution_result_pb2.ExecutorOutput.ArtifactList()
         for artifact in artifact_list:
             artifacts.artifacts.append(artifact.mlmd_artifact)
         result.output_artifacts[key].CopyFrom(artifacts)
     # Although the following removals are typically not expected, there is no
     # way to prevent them from happening. We should make sure that the
     # launcher can handle the double cleanup gracefully.
     fileio.rmtree(
         os.path.abspath(
             os.path.join(execution_info.stateful_working_dir, os.pardir)))
     fileio.rmtree(execution_info.tmp_dir)
     return result
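The comment above obliges the launcher to tolerate a second cleanup of the same directories. One hypothetical way to make the removal idempotent is to probe with `fileio.exists` first (a sketch, not the launcher's actual code):

def _rmtree_if_exists(path: str) -> None:
    # A missing path is simply skipped, so repeated cleanup is harmless.
    if fileio.exists(path):
        fileio.rmtree(path)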
Example No. 7
    def testPublishSuccessExecutionFailChangedUriDir(self):
        output_example = standard_artifacts.Examples()
        output_example.uri = '/my/original_uri'
        output_dict = {'examples': [output_example]}
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            execution_id = execution_publish_utils.register_execution(
                m, self._execution_type, contexts).id
            executor_output = execution_result_pb2.ExecutorOutput()
            new_example = executor_output.output_artifacts[
                'examples'].artifacts.add()
            new_example.uri = '/my/new_uri/1'

            with self.assertRaisesRegex(
                    RuntimeError,
                    'When there is one artifact to publish, the URI of it should be '
                    'identical to the URI of system generated artifact.'):
                execution_publish_utils.publish_succeeded_execution(
                    m, execution_id, contexts, output_dict, executor_output)
Example No. 8
def run_with_executor(
    execution_info: data_types.ExecutionInfo,
    executor: base_executor.BaseExecutor
) -> execution_result_pb2.ExecutorOutput:
    """Invokes executors given an executor instance and input from the Launcher.

  Args:
    execution_info: A wrapper of the details of this execution.
    executor: An executor instance.

  Returns:
    The output from executor.
  """
    # In cases where output directories are not empty due to a previous or
    # unrelated execution, clear the directories to ensure consistency.
    outputs_utils.clear_output_dirs(execution_info.output_dict)

    for _, artifact_list in execution_info.input_dict.items():
        for artifact in artifact_list:
            if isinstance(artifact, ValueArtifact):
                # Read ValueArtifact into memory.
                artifact.read()

    output_dict = copy.deepcopy(execution_info.output_dict)
    result = executor.Do(execution_info.input_dict, output_dict,
                         execution_info.exec_properties)
    if not result:
        # If result is not returned from the Do function, then try to
        # read from the executor_output_uri.
        if fileio.exists(execution_info.execution_output_uri):
            result = execution_result_pb2.ExecutorOutput.FromString(
                fileio.open(execution_info.execution_output_uri, 'rb').read())
        else:
            # Old style TFX executor doesn't return executor_output, but modify
            # output_dict and exec_properties in place. For backward compatibility,
            # we use their executor_output and exec_properties to construct
            # ExecutorOutput.
            result = execution_result_pb2.ExecutorOutput()
            outputs_utils.populate_output_artifact(result, output_dict)
            outputs_utils.populate_exec_properties(
                result, execution_info.exec_properties)
    return result
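A hypothetical call site for `run_with_executor`, with illustrative values; `MyExecutor` stands in for any concrete `base_executor.BaseExecutor` subclass and is not from the source:

execution_info = data_types.ExecutionInfo(
    execution_id=1,
    input_dict={'input_key': [standard_artifacts.Examples()]},
    output_dict={'output_key': [standard_artifacts.Model()]},
    exec_properties={'string': 'value'},
    execution_output_uri='/tmp/executor_output',  # illustrative path
    stateful_working_dir='/tmp/stateful_working_dir')
executor_output = run_with_executor(execution_info, MyExecutor())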
Example No. 9
 def testRunExecutor_with_InplaceUpdateExecutor(self):
     executor_spec = text_format.Parse(
         """
         class_path: "tfx.orchestration.portable.python_executor_operator_test.InplaceUpdateExecutor"
         """, executable_spec_pb2.PythonClassExecutableSpec())
     operator = python_executor_operator.PythonExecutorOperator(executor_spec)
     input_dict = {'input_key': [standard_artifacts.Examples()]}
     output_dict = {'output_key': [standard_artifacts.Model()]}
     exec_properties = {
         'string': 'value',
         'int': 1,
         'float': 0.0,
         # This should not happen in production and will be dropped.
         'proto': execution_result_pb2.ExecutorOutput()
     }
     stateful_working_dir = os.path.join(self.tmp_dir,
                                         'stateful_working_dir')
     executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
     executor_output = operator.run_executor(
         data_types.ExecutionInfo(execution_id=1,
                                  input_dict=input_dict,
                                  output_dict=output_dict,
                                  exec_properties=exec_properties,
                                  stateful_working_dir=stateful_working_dir,
                                  execution_output_uri=executor_output_uri))
     self.assertProtoPartiallyEquals(
         """
       output_artifacts {
         key: "output_key"
         value {
           artifacts {
             custom_properties {
               key: "name"
               value {
                 string_value: "my_model"
               }
             }
           }
         }
       }""", executor_output)
Example No. 10
 def testPopulateExecProperties(self):
     executor_output = execution_result_pb2.ExecutorOutput()
     exec_properties = {'string_value': 'string', 'int_value': 1}
     outputs_utils.populate_exec_properties(executor_output,
                                            exec_properties)
     self.assertProtoEquals(
         """
     execution_properties {
       key: "string_value"
       value {
         string_value: "string"
       }
     }
     execution_properties {
       key: "int_value"
       value {
         int_value: 1
       }
     }
     """, executor_output)
Example No. 11
    def testPublishSuccessExecutionFailTooManyLayerOfSubDir(self):
        output_example = standard_artifacts.Examples()
        output_example.uri = '/my/original_uri'
        output_dict = {'examples': [output_example]}
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            execution_id = execution_publish_utils.register_execution(
                m, self._execution_type, contexts).id
            executor_output = execution_result_pb2.ExecutorOutput()
            new_example = executor_output.output_artifacts[
                'examples'].artifacts.add()
            new_example.uri = '/my/original_uri/1/1'

            with self.assertRaisesRegex(
                    RuntimeError,
                    'The URI of executor generated artifacts should either be identical '
                    'to the URI of system generated artifact or be a direct sub-dir of '
                    'it.'):
                execution_publish_utils.publish_succeeded_execution(
                    m, execution_id, contexts, output_dict, executor_output)
Example No. 12
 def testPublishFailedExecution(self):
   with metadata.Metadata(connection_config=self._connection_config) as m:
     executor_output = text_format.Parse(
         """
       execution_result {
         code: 1
         result_message: 'error message.'
        }
     """, execution_result_pb2.ExecutorOutput())
     contexts = self._generate_contexts(m)
     execution_id = execution_publish_utils.register_execution(
         m, self._execution_type, contexts).id
     execution_publish_utils.publish_failed_execution(m, contexts,
                                                      execution_id,
                                                      executor_output)
     [execution] = m.store.get_executions_by_id([execution_id])
     self.assertProtoPartiallyEquals(
         """
         id: 1
         last_known_state: FAILED
         custom_properties {
           key: '__execution_result__'
           value {
             string_value: '{\\n  "resultMessage": "error message.",\\n  "code": 1\\n}'
           }
         }
         """,
         execution,
         ignored_fields=[
             'type_id', 'create_time_since_epoch',
             'last_update_time_since_epoch'
         ])
     # No events because there is no artifact published.
     events = m.store.get_events_by_execution_ids([execution.id])
     self.assertEmpty(events)
     # Verifies the context-execution edges are set up.
     self.assertCountEqual(
         [c.id for c in contexts],
         [c.id for c in m.store.get_contexts_by_execution(execution.id)])
Example No. 13
    def run_executor(
        self, execution_info: base_executor_operator.ExecutionInfo
    ) -> execution_result_pb2.ExecutorOutput:
        """Invokers executors given input from the Launcher.

    Args:
      execution_info: A wrapper of the details of this execution.

    Returns:
      The output from executor.
    """
        # TODO(b/162980675): Set arguments for Beam when it is available.
        context = base_executor.BaseExecutor.Context(
            executor_output_uri=execution_info.executor_output_uri,
            stateful_working_dir=execution_info.stateful_working_dir)
        executor = self._executor_cls(context=context)

        result = executor.Do(execution_info.input_dict,
                             execution_info.output_dict,
                             execution_info.exec_properties)
        if not result:
            # If result is not returned from the Do function, then try to
            # read it from the executor_output_uri.
            try:
                with tf.io.gfile.GFile(execution_info.executor_output_uri,
                                       'rb') as f:
                    result = execution_result_pb2.ExecutorOutput.FromString(
                        f.read())
            except tf.errors.NotFoundError:
                # Old-style TFX executors don't return executor_output; they
                # modify output_dict and exec_properties in place. For backward
                # compatibility, we use their output_dict and exec_properties
                # to construct the ExecutorOutput.
                result = execution_result_pb2.ExecutorOutput()
                _populate_output_artifact(result, execution_info.output_dict)
                _populate_exec_properties(result,
                                          execution_info.exec_properties)
        return result
Example No. 14
    def testPublishSuccessExecutionFailInvalidUri(self, invalid_uri):
        output_example = standard_artifacts.Examples()
        output_example.uri = '/my/original_uri'
        output_dict = {'examples': [output_example]}
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            execution_id = execution_publish_utils.register_execution(
                m, self._execution_type, contexts).id
            executor_output = execution_result_pb2.ExecutorOutput()
            system_generated_artifact = executor_output.output_artifacts[
                'examples'].artifacts.add()
            system_generated_artifact.uri = '/my/original_uri/0'
            new_artifact = executor_output.output_artifacts[
                'examples'].artifacts.add()
            new_artifact.uri = invalid_uri

            with self.assertRaisesRegex(
                    RuntimeError,
                    'When there are multiple artifacts to publish, their URIs should be '
                    'direct sub-directories of the URI of the system generated artifact.'
            ):
                execution_publish_utils.publish_succeeded_execution(
                    m, execution_id, contexts, output_dict, executor_output)
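Taken together, Examples No. 7, 11, and 14 pin down the URI contract that `publish_succeeded_execution` enforces: a single executor-produced artifact must reuse the system-generated URI, while multiple artifacts must each live in a direct sub-directory of it. A sketch of URIs that would pass, reusing the system URI from these tests:

system_uri = '/my/original_uri'
valid_single_uri = system_uri                # one artifact: identical URI
valid_multi_uris = [system_uri + '/0', system_uri + '/1']  # direct sub-dirs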
Example No. 15
 def testPublishSuccessExecutionDropsEmptyResult(self):
     with metadata.Metadata(connection_config=self._connection_config) as m:
         executor_output = text_format.Parse(
             """
     execution_result {
       code: 0
      }
   """, execution_result_pb2.ExecutorOutput())
         contexts = self._generate_contexts(m)
         execution_id = execution_publish_utils.register_execution(
             m, self._execution_type, contexts).id
         execution_publish_utils.publish_failed_execution(
             m, contexts, execution_id, executor_output)
         [execution] = m.store.get_executions_by_id([execution_id])
         self.assertProtoPartiallyEquals("""
       id: 1
       type_id: 3
       last_known_state: FAILED
       """,
                                         execution,
                                         ignored_fields=[
                                             'create_time_since_epoch',
                                             'last_update_time_since_epoch'
                                         ])
Example No. 16
 def run_executor(
     self, execution_info: base_executor_operator.ExecutionInfo
 ) -> execution_result_pb2.ExecutorOutput:
     self._exec_properties = execution_info.exec_properties
     return execution_result_pb2.ExecutorOutput()
Example No. 17
 def run_executor(
     self, execution_info: data_types.ExecutionInfo
 ) -> execution_result_pb2.ExecutorOutput:
     self._exec_properties = execution_info.exec_properties
     return execution_result_pb2.ExecutorOutput()
Example No. 18
 def run_executor(
     self, execution_info: base_executor_operator.ExecutionInfo
 ) -> execution_result_pb2.ExecutorOutput:
     return execution_result_pb2.ExecutorOutput()
Example No. 19
 def schedule(self):
   return ts.TaskSchedulerResult(
       executor_output=execution_result_pb2.ExecutorOutput())
Example No. 20
 def testPublishSuccessfulExecution(self):
   with metadata.Metadata(connection_config=self._connection_config) as m:
     contexts = self._generate_contexts(m)
     execution_id = execution_publish_utils.register_execution(
         m, self._execution_type, contexts).id
     output_key = 'examples'
     output_example = standard_artifacts.Examples()
     executor_output = execution_result_pb2.ExecutorOutput()
     text_format.Parse(
         """
         uri: 'examples_uri'
         custom_properties {
           key: 'prop'
           value {int_value: 1}
         }
         """, executor_output.output_artifacts[output_key].artifacts.add())
     execution_publish_utils.publish_succeeded_execution(
         m, execution_id, contexts, {output_key: [output_example]},
         executor_output)
     [execution] = m.store.get_executions()
     self.assertProtoPartiallyEquals(
         """
         id: 1
         type_id: 3
         last_known_state: COMPLETE
         """,
         execution,
         ignored_fields=[
             'create_time_since_epoch', 'last_update_time_since_epoch'
         ])
     [artifact] = m.store.get_artifacts()
     self.assertProtoPartiallyEquals(
         """
         id: 1
         type_id: 4
         state: LIVE
         uri: 'examples_uri'
         custom_properties {
           key: 'prop'
           value {int_value: 1}
         }""",
         artifact,
         ignored_fields=[
             'create_time_since_epoch', 'last_update_time_since_epoch'
         ])
     [event] = m.store.get_events_by_execution_ids([execution.id])
     self.assertProtoPartiallyEquals(
         """
         artifact_id: 1
         execution_id: 1
         path {
           steps {
             key: 'examples'
           }
           steps {
             index: 0
           }
         }
         type: OUTPUT
         """,
         event,
         ignored_fields=['milliseconds_since_epoch'])
     # Verifies the context-execution edges are set up.
     self.assertCountEqual(
         [c.id for c in contexts],
         [c.id for c in m.store.get_contexts_by_execution(execution.id)])
     self.assertCountEqual(
         [c.id for c in contexts],
         [c.id for c in m.store.get_contexts_by_artifact(output_example.id)])
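The event assertion above shows how MLMD addresses an artifact within an execution's outputs: a `path` of two steps, the output key followed by the list index. A sketch of the equivalent Python-side lookup, with illustrative names:

key, index = 'examples', 0          # mirrors the event's two path steps
artifact = output_dict[key][index]  # the artifact the OUTPUT event refers to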
Example No. 21
    def testPublishSuccessExecutionExecutorEditedOutputDict(self):
        # There is one artifact in the system-provided output_dict, while there
        # are two artifacts in the executor output. We expect both artifacts to
        # be published.
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            execution_id = execution_publish_utils.register_execution(
                m, self._execution_type, contexts).id

            output_example = standard_artifacts.Examples()
            output_example.uri = '/original_path'

            executor_output = execution_result_pb2.ExecutorOutput()
            output_key = 'examples'
            text_format.Parse(
                """
          uri: '/original_path/subdir_1'
          custom_properties {
            key: 'prop'
            value {int_value: 1}
          }
          """, executor_output.output_artifacts[output_key].artifacts.add())
            text_format.Parse(
                """
          uri: '/original_path/subdir_2'
          custom_properties {
            key: 'prop'
            value {int_value: 2}
          }
          """, executor_output.output_artifacts[output_key].artifacts.add())

            output_dict = execution_publish_utils.publish_succeeded_execution(
                m, execution_id, contexts, {output_key: [output_example]},
                executor_output)
            [execution] = m.store.get_executions()
            self.assertProtoPartiallyEquals("""
          id: 1
          type_id: 3
          last_known_state: COMPLETE
          """,
                                            execution,
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])
            artifacts = m.store.get_artifacts()
            self.assertLen(artifacts, 2)
            self.assertProtoPartiallyEquals("""
          id: 1
          type_id: 4
          state: LIVE
          uri: '/original_path/subdir_1'
          custom_properties {
            key: 'prop'
            value {int_value: 1}
          }""",
                                            artifacts[0],
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])
            self.assertProtoPartiallyEquals("""
          id: 2
          type_id: 4
          state: LIVE
          uri: '/original_path/subdir_2'
          custom_properties {
            key: 'prop'
            value {int_value: 2}
          }""",
                                            artifacts[1],
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])
            events = m.store.get_events_by_execution_ids([execution.id])
            self.assertLen(events, 2)
            self.assertProtoPartiallyEquals(
                """
          artifact_id: 1
          execution_id: 1
          path {
            steps {
              key: 'examples'
            }
            steps {
              index: 0
            }
          }
          type: OUTPUT
          """,
                events[0],
                ignored_fields=['milliseconds_since_epoch'])
            self.assertProtoPartiallyEquals(
                """
          artifact_id: 2
          execution_id: 1
          path {
            steps {
              key: 'examples'
            }
            steps {
              index: 1
            }
          }
          type: OUTPUT
          """,
                events[1],
                ignored_fields=['milliseconds_since_epoch'])
            # Verifies the context-execution edges are set up.
            self.assertCountEqual([c.id for c in contexts], [
                c.id for c in m.store.get_contexts_by_execution(execution.id)
            ])
            for artifact_list in output_dict.values():
                for output_example in artifact_list:
                    self.assertCountEqual([c.id for c in contexts], [
                        c.id for c in m.store.get_contexts_by_artifact(
                            output_example.id)
                    ])
Example No. 22
    def testOverrideRegisterExecution(self):
        # Mock all real operations of driver / executor / MLMD accesses.
        mock_targets = (  # (cls, method, return_value)
            (beam_executor_operator.BeamExecutorOperator, '__init__', None),
            (beam_executor_operator.BeamExecutorOperator, 'run_executor',
             execution_result_pb2.ExecutorOutput()),
            (python_driver_operator.PythonDriverOperator, '__init__', None),
            (python_driver_operator.PythonDriverOperator, 'run_driver',
             driver_output_pb2.DriverOutput()),
            (metadata.Metadata, '__init__', None),
            (metadata.Metadata, '__exit__', None),
            (launcher.Launcher, '_publish_successful_execution', None),
            (launcher.Launcher, '_clean_up_stateless_execution_info', None),
            (launcher.Launcher, '_clean_up_stateful_execution_info', None),
            (outputs_utils, 'OutputsResolver', mock.MagicMock()),
            (execution_lib, 'get_executions_associated_with_all_contexts', []),
            (container_entrypoint, '_dump_ui_metadata', None),
        )
        for cls, method, return_value in mock_targets:
            self.enter_context(
                mock.patch.object(cls,
                                  method,
                                  autospec=True,
                                  return_value=return_value))

        mock_mlmd = self.enter_context(
            mock.patch.object(metadata.Metadata, '__enter__',
                              autospec=True)).return_value
        mock_mlmd.store.return_value.get_executions_by_id.return_value = [
            metadata_store_pb2.Execution()
        ]

        self._set_required_env_vars({
            'WORKFLOW_ID': 'workflow-id-42',
            'METADATA_GRPC_SERVICE_HOST': 'metadata-grpc',
            'METADATA_GRPC_SERVICE_PORT': '8080',
            container_entrypoint._KFP_POD_NAME_ENV_KEY: 'test_pod_name'
        })

        mock_register_execution = self.enter_context(
            mock.patch.object(execution_publish_utils,
                              'register_execution',
                              autospec=True))

        test_ir_file = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'testdata',
            'two_step_pipeline_post_dehydrate_ir.json')
        test_ir = io_utils.read_string_file(test_ir_file)

        argv = [
            '--pipeline_root',
            'dummy',
            '--kubeflow_metadata_config',
            json_format.MessageToJson(
                kubeflow_dag_runner.get_default_kubeflow_metadata_config()),
            '--tfx_ir',
            test_ir,
            '--node_id',
            'BigQueryExampleGen',
            '--runtime_parameter',
            'pipeline-run-id=STRING:my-run-id',
        ]
        container_entrypoint.main(argv)

        mock_register_execution.assert_called_once()
        kwargs = mock_register_execution.call_args[1]
        self.assertEqual(
            kwargs['exec_properties']
            [container_entrypoint._KFP_POD_NAME_PROPERTY_KEY], 'test_pod_name')
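The patching pattern in this test leans on `autospec=True`, which makes each mock enforce the real method's signature. A self-contained sketch of the same pattern outside TFX (the class and values here are illustrative):

import unittest.mock as mock

class Service:
    def run(self, x: int) -> int:
        return x * 2

with mock.patch.object(Service, 'run', autospec=True, return_value=42) as run:
    assert Service().run(1) == 42  # signature-checked, returns the stub value
    run.assert_called_once()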
Example No. 23
    def run_executor(
        self, execution_info: data_types.ExecutionInfo
    ) -> execution_result_pb2.ExecutorOutput:
        """Execute underlying component implementation.

    Runs executor container in a Kubernetes Pod and wait until it goes into
    `Succeeded` or `Failed` state.

    Args:
      execution_info: All the information that the launcher provides.

    Raises:
      RuntimeError: when the pod is in `Failed` state or unexpected failure from
      Kubernetes API.

    Returns:
      An ExecutorOutput instance

    """

        context = placeholder_utils.ResolutionContext(
            exec_info=execution_info,
            executor_spec=self._executor_spec,
            platform_config=self._platform_config)

        container_spec = executor_specs.TemplatedExecutorContainerSpec(
            image=self._container_executor_spec.image,
            command=[
                placeholder_utils.resolve_placeholder_expression(cmd, context)
                for cmd in self._container_executor_spec.commands
            ] or None,
            args=[
                placeholder_utils.resolve_placeholder_expression(arg, context)
                for arg in self._container_executor_spec.args
            ] or None,
        )

        pod_name = self._build_pod_name(execution_info)
        # TODO(hongyes): replace the default value from component config.
        try:
            namespace = kube_utils.get_kfp_namespace()
        except RuntimeError:
            namespace = 'kubeflow'

        pod_manifest = self._build_pod_manifest(pod_name, container_spec)
        core_api = kube_utils.make_core_v1_api()

        if kube_utils.is_inside_kfp():
            launcher_pod = kube_utils.get_current_kfp_pod(core_api)
            pod_manifest['spec'][
                'serviceAccount'] = launcher_pod.spec.service_account
            pod_manifest['spec'][
                'serviceAccountName'] = launcher_pod.spec.service_account_name
            pod_manifest['metadata'][
                'ownerReferences'] = container_common.to_swagger_dict(
                    launcher_pod.metadata.owner_references)
        else:
            pod_manifest['spec'][
                'serviceAccount'] = kube_utils.TFX_SERVICE_ACCOUNT
            pod_manifest['spec'][
                'serviceAccountName'] = kube_utils.TFX_SERVICE_ACCOUNT

        logging.info('Looking for pod "%s:%s".', namespace, pod_name)
        resp = kube_utils.get_pod(core_api, pod_name, namespace)
        if not resp:
            logging.info('Pod "%s:%s" does not exist. Creating it...',
                         namespace, pod_name)
            logging.info('Pod manifest: %s', pod_manifest)
            try:
                resp = core_api.create_namespaced_pod(namespace=namespace,
                                                      body=pod_manifest)
            except client.rest.ApiException as e:
                raise RuntimeError(
                    'Failed to create container executor pod!\nReason: %s\nBody: %s'
                    % (e.reason, e.body))

        # Wait up to 300 seconds for the pod to move from pending to another status.
        logging.info('Waiting for pod "%s:%s" to start.', namespace, pod_name)
        kube_utils.wait_pod(
            core_api,
            pod_name,
            namespace,
            exit_condition_lambda=kube_utils.pod_is_not_pending,
            condition_description='non-pending status',
            timeout_sec=300)

        logging.info('Start log streaming for pod "%s:%s".', namespace,
                     pod_name)
        try:
            logs = core_api.read_namespaced_pod_log(
                name=pod_name,
                namespace=namespace,
                container=kube_utils.ARGO_MAIN_CONTAINER_NAME,
                follow=True,
                _preload_content=False).stream()
        except client.rest.ApiException as e:
            raise RuntimeError(
                'Failed to stream the logs from the pod!\nReason: %s\nBody: %s'
                % (e.reason, e.body))

        for log in logs:
            logging.info(log.decode().rstrip('\n'))

        # Wait indefinitely for the pod to complete.
        resp = kube_utils.wait_pod(
            core_api,
            pod_name,
            namespace,
            exit_condition_lambda=kube_utils.pod_is_done,
            condition_description='done state')

        if resp.status.phase == kube_utils.PodPhase.FAILED.value:
            raise RuntimeError('Pod "%s:%s" failed with status "%s".' %
                               (namespace, pod_name, resp.status))

        logging.info('Pod "%s:%s" is done.', namespace, pod_name)

        return execution_result_pb2.ExecutorOutput()
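`wait_pod` above is driven by exit-condition callbacks such as `kube_utils.pod_is_not_pending` and `kube_utils.pod_is_done`. A sketch of what such predicates amount to, assuming the standard Kubernetes pod phases; the bodies here are illustrative, not the real `kube_utils` code:

def pod_is_not_pending(resp) -> bool:
    return resp.status.phase != 'Pending'

def pod_is_done(resp) -> bool:
    return resp.status.phase in ('Succeeded', 'Failed')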