Example #1
File: compiler.py Project: jay90099/tfx
  def _compile_node_outputs(self, tfx_node: base_node.BaseNode,
                            node_pb: pipeline_pb2.PipelineNode):
    """Compiles the outputs of a node/component."""
    for key, value in tfx_node.outputs.items():
      output_spec = node_pb.outputs.outputs[key]
      artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
      output_spec.artifact_spec.type.CopyFrom(artifact_type)

      # Attach additional properties for artifacts produced by importer nodes.
      for property_name, property_value in value.additional_properties.items():
        _check_property_value_type(property_name, property_value, artifact_type)
        value_field = output_spec.artifact_spec.additional_properties[
            property_name].field_value
        try:
          data_types_utils.set_metadata_value(value_field, property_value)
        except ValueError as e:
          raise ValueError(
              "Component {} got unsupported parameter {} with type {}.".format(
                  tfx_node.id, property_name,
                  type(property_value))) from e

      for property_name, property_value in (
          value.additional_custom_properties.items()):
        value_field = output_spec.artifact_spec.additional_custom_properties[
            property_name].field_value
        try:
          data_types_utils.set_metadata_value(value_field, property_value)
        except ValueError as e:
          raise ValueError(
              "Component {} got unsupported parameter {} with type {}.".format(
                  tfx_node.id, property_name,
                  type(property_value))) from e
Example #2
def prepare_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    state: metadata_store_pb2.Execution.State,
    exec_properties: Optional[Mapping[Text, types.Property]] = None,
) -> metadata_store_pb2.Execution:
    """Creates an execution proto based on the information provided.

    Args:
      metadata_handler: A handler to access MLMD store.
      execution_type: A metadata_pb2.ExecutionType message describing the type of
        the execution.
      state: The state of the execution.
      exec_properties: Execution properties that need to be attached.

    Returns:
      A metadata_store_pb2.Execution message.
    """
    execution = metadata_store_pb2.Execution()
    execution.last_known_state = state
    execution.type_id = common_utils.register_type_if_not_exist(
        metadata_handler, execution_type).id

    exec_properties = exec_properties or {}
    # For every execution property, put it in execution.properties if its key is
    # in execution type schema. Otherwise, put it in execution.custom_properties.
    for k, v in exec_properties.items():
        if (execution_type.properties.get(k) ==
                data_types_utils.get_metadata_value_type(v)):
            data_types_utils.set_metadata_value(execution.properties[k], v)
        else:
            data_types_utils.set_metadata_value(execution.custom_properties[k],
                                                v)
    logging.debug('Prepared EXECUTION:\n %s', execution)
    return execution
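A minimal companion sketch (module import paths assumed, not taken from the example above) of the core pattern prepare_execution relies on: set_metadata_value writes a Python value into the metadata_store_pb2.Value proto held inside an Execution's property maps. The key 'num_steps' is purely hypothetical.

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration import data_types_utils

execution = metadata_store_pb2.Execution()
# 'num_steps' is a hypothetical key; real code derives keys from the execution type schema.
data_types_utils.set_metadata_value(execution.custom_properties['num_steps'], 100)
print(execution.custom_properties['num_steps'].int_value)  # 100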
Example #3
def substitute_runtime_parameter(
        msg: message.Message,
        parameter_bindings: Mapping[str, types.Property]) -> None:
    """Utility function to substitute runtime parameter placeholders with values.

    Args:
      msg: The original message to change. Only messages defined under
        pipeline_pb2 are supported; other message types result in a no-op.
      parameter_bindings: A dict of parameter keys to parameter values that will
        be used to substitute the runtime parameter placeholder.

    Returns:
      None
    """
    if not isinstance(msg, message.Message):
        return

    # If the message is a pipeline_pb2.Value instance, try to find a substitute
    # with runtime parameter bindings.
    if isinstance(msg, pipeline_pb2.Value):
        value = cast(pipeline_pb2.Value, msg)
        which = value.WhichOneof('value')
        if which == 'runtime_parameter':
            real_value = _get_runtime_parameter_value(value.runtime_parameter,
                                                      parameter_bindings)
            if real_value is None:
                return
            value.Clear()
            data_types_utils.set_metadata_value(
                metadata_value=value.field_value, value=real_value)
        if which == 'structural_runtime_parameter':
            real_value = _get_structural_runtime_parameter_value(
                value.structural_runtime_parameter, parameter_bindings)
            if real_value is None:
                return
            value.Clear()
            data_types_utils.set_metadata_value(
                metadata_value=value.field_value, value=real_value)

        return

    # For other cases, recursively call into sub-messages if any.
    for field, sub_message in msg.ListFields():
        # No-op for non-message types.
        if field.type != descriptor.FieldDescriptor.TYPE_MESSAGE:
            continue
        # Evaluates every value in a map.
        elif (field.message_type.has_options
              and field.message_type.GetOptions().map_entry):
            for key in sub_message:
                substitute_runtime_parameter(sub_message[key],
                                             parameter_bindings)
        # Evaluates every entry in a list.
        elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
            for element in sub_message:
                substitute_runtime_parameter(element, parameter_bindings)
        # Evaluates sub-message.
        else:
            substitute_runtime_parameter(sub_message, parameter_bindings)
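A small self-contained illustration (pipeline_pb2 import path assumed) of the oneof check this function performs before substituting: a Value proto carrying a runtime_parameter placeholder reports 'runtime_parameter' from WhichOneof('value'), and only such placeholders get replaced with a concrete field_value.

from google.protobuf import text_format
from tfx.proto.orchestration import pipeline_pb2

value = pipeline_pb2.Value()
# Same placeholder shape as in the runtime-parameter test in Example #8 below.
text_format.Parse("runtime_parameter { name: 'learning_rate' }", value)
print(value.WhichOneof('value'))  # 'runtime_parameter'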
Example #4
    def test_pipeline_failure_strategies(self, fail_fast):
        """Tests pipeline failure strategies."""
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        self._run_next(False,
                       expect_nodes=[self._stats_gen],
                       fail_fast=fail_fast)
        self._run_next(False,
                       expect_nodes=[self._schema_gen],
                       fail_fast=fail_fast)

        # Both example-validator and transform are ready to execute.
        [example_validator_task,
         transform_task] = self._generate(False, True, fail_fast=fail_fast)
        self.assertEqual(self._example_validator.node_info.id,
                         example_validator_task.node_uid.node_id)
        self.assertEqual(self._transform.node_info.id,
                         transform_task.node_uid.node_id)

        # Simulate Transform success.
        self._finish_node_execution(False, transform_task)

        # But fail example-validator.
        with self._mlmd_connection as m:
            with mlmd_state.mlmd_execution_atomic_op(
                    m, example_validator_task.execution_id) as ev_exec:
                # Fail example-validator execution.
                ev_exec.last_known_state = metadata_store_pb2.Execution.FAILED
                data_types_utils.set_metadata_value(
                    ev_exec.custom_properties[
                        constants.EXECUTION_ERROR_MSG_KEY],
                    'example-validator error')

        if fail_fast:
            # Pipeline run should immediately fail because example-validator failed.
            [finalize_task] = self._generate(False, True, fail_fast=fail_fast)
            self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
            self.assertEqual(status_lib.Code.ABORTED,
                             finalize_task.status.code)
        else:
            # Trainer and downstream nodes can execute as transform has finished.
            # example-validator failure does not impact them as it is not upstream.
            # Pipeline run will still fail but when no more progress can be made.
            self._run_next(False,
                           expect_nodes=[self._trainer],
                           fail_fast=fail_fast)
            self._run_next(False,
                           expect_nodes=[self._chore_a],
                           fail_fast=fail_fast)
            self._run_next(False,
                           expect_nodes=[self._chore_b],
                           fail_fast=fail_fast)
            [finalize_task] = self._generate(False, True, fail_fast=fail_fast)
            self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
            self.assertEqual(status_lib.Code.ABORTED,
                             finalize_task.status.code)
Example #5
def _update_execution_state_in_mlmd(
    mlmd_handle: metadata.Metadata, execution: metadata_store_pb2.Execution,
    new_state: metadata_store_pb2.Execution.State, error_msg: str) -> None:
  updated_execution = copy.deepcopy(execution)
  updated_execution.last_known_state = new_state
  if error_msg:
    data_types_utils.set_metadata_value(
        updated_execution.custom_properties[constants.EXECUTION_ERROR_MSG_KEY],
        error_msg)
  mlmd_handle.store.put_executions([updated_execution])
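A rough sketch of the same update pattern without an MLMD store (import paths assumed; the literal 'error_msg' merely stands in for constants.EXECUTION_ERROR_MSG_KEY): copy the execution, flip its state, attach the error message, and hand the copy back via put_executions.

import copy

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration import data_types_utils

execution = metadata_store_pb2.Execution(id=42)
updated = copy.deepcopy(execution)
updated.last_known_state = metadata_store_pb2.Execution.FAILED
# 'error_msg' is a placeholder for constants.EXECUTION_ERROR_MSG_KEY.
data_types_utils.set_metadata_value(
    updated.custom_properties['error_msg'], 'worker crashed')
# mlmd_handle.store.put_executions([updated]) would then persist the change.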
Example #6
def _update_execution_state_in_mlmd(
        mlmd_handle: metadata.Metadata, execution_id: int,
        new_state: metadata_store_pb2.Execution.State, error_msg: str) -> None:
    with mlmd_state.mlmd_execution_atomic_op(mlmd_handle,
                                             execution_id) as execution:
        execution.last_known_state = new_state
        if error_msg:
            data_types_utils.set_metadata_value(
                execution.custom_properties[constants.EXECUTION_ERROR_MSG_KEY],
                error_msg)
Example #7
 def testSetMetadataValueWithTfxValue(self):
     tfx_value = pipeline_pb2.Value()
     metadata_property = metadata_store_pb2.Value()
     text_format.Parse(
         """
     field_value {
         int_value: 1
     }""", tfx_value)
     data_types_utils.set_metadata_value(metadata_value=metadata_property,
                                         value=tfx_value)
     self.assertProtoEquals('int_value: 1', metadata_property)
Example #8
 def testSetMetadataValueWithTfxValueFailed(self):
     tfx_value = pipeline_pb2.Value()
     metadata_property = metadata_store_pb2.Value()
     text_format.Parse(
         """
     runtime_parameter {
       name: 'rp'
     }""", tfx_value)
     with self.assertRaisesRegex(ValueError,
                                 'Expecting field_value but got'):
         data_types_utils.set_metadata_value(
             metadata_value=metadata_property, value=tfx_value)
Example #9
    def initiate_update(
        self,
        updated_pipeline: pipeline_pb2.Pipeline,
        update_options: pipeline_pb2.UpdateOptions,
    ) -> None:
        """Initiates pipeline update process."""
        self._check_context()

        if self.pipeline.execution_mode != updated_pipeline.execution_mode:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.INVALID_ARGUMENT,
                message=(
                    'Updating execution_mode of an active pipeline is not '
                    'supported'))

        if self.pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
            updated_pipeline_run_id = (
                updated_pipeline.runtime_spec.pipeline_run_id.field_value.
                string_value)
            if self.pipeline_run_id != updated_pipeline_run_id:
                raise status_lib.StatusNotOkError(
                    code=status_lib.Code.INVALID_ARGUMENT,
                    message=
                    (f'For sync pipeline, pipeline_run_id should match; found '
                     f'mismatch: {self.pipeline_run_id} (existing) vs. '
                     f'{updated_pipeline_run_id} (updated)'))

        # TODO(b/194311197): We require that structure of the updated pipeline
        # exactly matches the original. There is scope to relax this restriction.

        def _structure(
            pipeline: pipeline_pb2.Pipeline
        ) -> List[Tuple[str, List[str], List[str]]]:
            return [(node.node_info.id, list(node.upstream_nodes),
                     list(node.downstream_nodes))
                    for node in get_all_pipeline_nodes(pipeline)]

        if _structure(self.pipeline) != _structure(updated_pipeline):
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.INVALID_ARGUMENT,
                message=(
                    'Updated pipeline should have the same structure as the '
                    'original.'))

        data_types_utils.set_metadata_value(
            self._execution.custom_properties[_UPDATED_PIPELINE_IR],
            _base64_encode(updated_pipeline))
        data_types_utils.set_metadata_value(
            self._execution.custom_properties[_UPDATE_OPTIONS],
            _base64_encode(update_options))
Example #10
    def new(cls, mlmd_handle: metadata.Metadata,
            pipeline: pipeline_pb2.Pipeline) -> 'PipelineState':
        """Creates a `PipelineState` object for a new pipeline.

        No active pipeline with the same pipeline uid should exist for the call to
        be successful.

        Args:
          mlmd_handle: A handle to the MLMD db.
          pipeline: IR of the pipeline.

        Returns:
          A `PipelineState` object.

        Raises:
          status_lib.StatusNotOkError: If a pipeline with same UID already exists.
        """
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        context = context_lib.register_context_if_not_exists(
            mlmd_handle,
            context_type_name=_ORCHESTRATOR_RESERVED_ID,
            context_name=orchestrator_context_name(pipeline_uid))

        executions = mlmd_handle.store.get_executions_by_context(context.id)
        if any(e for e in executions if execution_lib.is_execution_active(e)):
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.ALREADY_EXISTS,
                message=f'Pipeline with uid {pipeline_uid} already active.')

        execution = execution_lib.prepare_execution(
            mlmd_handle,
            _ORCHESTRATOR_EXECUTION_TYPE,
            metadata_store_pb2.Execution.NEW,
            exec_properties={
                _PIPELINE_IR:
                base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
            },
        )
        if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
            data_types_utils.set_metadata_value(
                execution.custom_properties[_PIPELINE_RUN_ID],
                pipeline.runtime_spec.pipeline_run_id.field_value.string_value)

        execution = execution_lib.put_execution(mlmd_handle, execution,
                                                [context])
        record_state_change_time()

        return cls(mlmd_handle=mlmd_handle,
                   pipeline=pipeline,
                   execution_id=execution.id)
Example #11
 def apply_pipeline_update(self) -> None:
     """Applies pipeline update that was previously initiated."""
     self._check_context()
     updated_pipeline_ir = _get_metadata_value(
         self._execution.custom_properties.get(_UPDATED_PIPELINE_IR))
     if not updated_pipeline_ir:
         raise status_lib.StatusNotOkError(
             code=status_lib.Code.INVALID_ARGUMENT,
             message='No updated pipeline IR to apply')
     data_types_utils.set_metadata_value(
         self._execution.properties[_PIPELINE_IR], updated_pipeline_ir)
     del self._execution.custom_properties[_UPDATED_PIPELINE_IR]
     del self._execution.custom_properties[_UPDATE_OPTIONS]
     self.pipeline = _base64_decode_pipeline(updated_pipeline_ir)
Example #12
    def test_restart_node_cancelled_due_to_stopping(self):
        """Tests that a node previously cancelled due to stopping can be restarted."""
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        [stats_gen_task
         ] = self._generate_and_test(False,
                                     num_initial_executions=1,
                                     num_tasks_generated=1,
                                     num_new_executions=1,
                                     num_active_executions=1,
                                     ignore_update_node_state_tasks=True)
        node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline,
                                                       self._stats_gen)
        self.assertEqual(node_uid, stats_gen_task.node_uid)

        # Simulate stopping the node while it is under execution, which leads to
        # the node execution being cancelled.
        with self._mlmd_connection as m:
            with mlmd_state.mlmd_execution_atomic_op(
                    m, stats_gen_task.execution_id) as stats_gen_exec:
                stats_gen_exec.last_known_state = metadata_store_pb2.Execution.CANCELED
                data_types_utils.set_metadata_value(
                    stats_gen_exec.custom_properties[
                        constants.EXECUTION_ERROR_MSG_KEY], 'manually stopped')

        # Change state of node to STARTING.
        with self._mlmd_connection as m:
            pipeline_state = test_utils.get_or_create_pipeline_state(
                m, self._pipeline)
            with pipeline_state:
                with pipeline_state.node_state_update_context(
                        node_uid) as node_state:
                    node_state.update(pstate.NodeState.STARTING)

        # New execution should be created for any previously canceled node when the
        # node state is STARTING.
        [update_node_state_task,
         stats_gen_task] = self._generate_and_test(False,
                                                   num_initial_executions=2,
                                                   num_tasks_generated=2,
                                                   num_new_executions=1,
                                                   num_active_executions=1)
        self.assertTrue(
            task_lib.is_update_node_state_task(update_node_state_task))
        self.assertEqual(node_uid, update_node_state_task.node_uid)
        self.assertEqual(pstate.NodeState.RUNNING,
                         update_node_state_task.state)
        self.assertEqual(node_uid, stats_gen_task.node_uid)
Example #13
 def initiate_node_stop(self, node_uid: task_lib.NodeUid) -> None:
     """Updates pipeline state to signal that a node should be stopped."""
     if self.pipeline.execution_mode != pipeline_pb2.Pipeline.ASYNC:
         raise status_lib.StatusNotOkError(
             code=status_lib.Code.UNIMPLEMENTED,
              message='Node can be stopped only for async pipelines.')
     if not _is_node_uid_in_pipeline(node_uid, self.pipeline):
         raise status_lib.StatusNotOkError(
             code=status_lib.Code.INVALID_ARGUMENT,
             message=(
                 f'Node given by uid {node_uid} does not belong to pipeline '
                 f'given by uid {self.pipeline_uid}'))
     data_types_utils.set_metadata_value(
         self.execution.custom_properties[_node_stop_initiated_property(
             node_uid)], 1)
     self._commit = True
Example #14
  def test_node_failed(self, use_task_queue):
    """Tests task generation when a node registers a failed execution."""
    otu.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1, 1)

    def _ensure_node_services(unused_pipeline_state, node_id):
      self.assertEqual(self._example_gen.node_info.id, node_id)
      return service_jobs.ServiceStatus.SUCCESS

    self._mock_service_job_manager.ensure_node_services.side_effect = (
        _ensure_node_services)

    tasks, active_executions = self._generate_and_test(
        use_task_queue,
        num_initial_executions=1,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1)
    self.assertEqual(
        task_lib.NodeUid.from_pipeline_node(self._pipeline, self._stats_gen),
        tasks[0].node_uid)
    stats_gen_exec = active_executions[0]

    # Fail stats-gen execution.
    stats_gen_exec.last_known_state = metadata_store_pb2.Execution.FAILED
    data_types_utils.set_metadata_value(
        stats_gen_exec.custom_properties[constants.EXECUTION_ERROR_MSG_KEY],
        'foobar error')
    with self._mlmd_connection as m:
      m.store.put_executions([stats_gen_exec])
    if use_task_queue:
      task = self._task_queue.dequeue()
      self._task_queue.task_done(task)

    # Test generation of FinalizePipelineTask.
    tasks, _ = self._generate_and_test(
        True,
        num_initial_executions=2,
        num_tasks_generated=1,
        num_new_executions=0,
        num_active_executions=0)
    self.assertLen(tasks, 1)
    self.assertTrue(task_lib.is_finalize_pipeline_task(tasks[0]))
    self.assertEqual(status_lib.Code.ABORTED, tasks[0].status.code)
    self.assertRegexMatch(tasks[0].status.message, ['foobar error'])
Example #15
    def test_node_failed(self, fail_fast):
        """Tests task generation when a node registers a failed execution."""
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        [stats_gen_task
         ] = self._generate_and_test(False,
                                     num_initial_executions=1,
                                     num_tasks_generated=1,
                                     num_new_executions=1,
                                     num_active_executions=1,
                                     ignore_update_node_state_tasks=True,
                                     fail_fast=fail_fast)
        self.assertEqual(
            task_lib.NodeUid.from_pipeline_node(self._pipeline,
                                                self._stats_gen),
            stats_gen_task.node_uid)
        with self._mlmd_connection as m:
            with mlmd_state.mlmd_execution_atomic_op(
                    m, stats_gen_task.execution_id) as stats_gen_exec:
                # Fail stats-gen execution.
                stats_gen_exec.last_known_state = metadata_store_pb2.Execution.FAILED
                data_types_utils.set_metadata_value(
                    stats_gen_exec.custom_properties[
                        constants.EXECUTION_ERROR_MSG_KEY], 'foobar error')

        # Test generation of FinalizePipelineTask.
        [update_node_state_task,
         finalize_task] = self._generate_and_test(True,
                                                  num_initial_executions=2,
                                                  num_tasks_generated=2,
                                                  num_new_executions=0,
                                                  num_active_executions=0,
                                                  fail_fast=fail_fast)
        self.assertTrue(
            task_lib.is_update_node_state_task(update_node_state_task))
        self.assertEqual('my_statistics_gen',
                         update_node_state_task.node_uid.node_id)
        self.assertEqual(pstate.NodeState.FAILED, update_node_state_task.state)
        self.assertRegexMatch(update_node_state_task.status.message,
                              ['foobar error'])
        self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
        self.assertEqual(status_lib.Code.ABORTED, finalize_task.status.code)
        self.assertRegexMatch(finalize_task.status.message, ['foobar error'])
Example #16
def prepare_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    state: metadata_store_pb2.Execution.State,
    exec_properties: Optional[Mapping[str, types.ExecPropertyTypes]] = None,
) -> metadata_store_pb2.Execution:
    """Creates an execution proto based on the information provided.

    Args:
      metadata_handler: A handler to access MLMD store.
      execution_type: A metadata_pb2.ExecutionType message describing the type of
        the execution.
      state: The state of the execution.
      exec_properties: Execution properties that need to be attached.

    Returns:
      A metadata_store_pb2.Execution message.
    """
    execution = metadata_store_pb2.Execution()
    execution.last_known_state = state
    execution.type_id = common_utils.register_type_if_not_exist(
        metadata_handler, execution_type).id

    exec_properties = exec_properties or {}
    # For every execution property, put it in execution.properties if its key is
    # in execution type schema. Otherwise, put it in execution.custom_properties.
    for k, v in exec_properties.items():
        value = pipeline_pb2.Value()
        value = data_types_utils.set_parameter_value(value, v)

        if value.HasField('schema'):
            # Stores schema in custom_properties for non-primitive types to allow
            # parsing in later stages.
            data_types_utils.set_metadata_value(
                execution.custom_properties[get_schema_key(k)],
                proto_utils.proto_to_json(value.schema))

        if (execution_type.properties.get(k) ==
                data_types_utils.get_metadata_value_type(v)):
            execution.properties[k].CopyFrom(value.field_value)
        else:
            execution.custom_properties[k].CopyFrom(value.field_value)
    logging.debug('Prepared EXECUTION:\n %s', execution)
    return execution
Example #17
    def run(
        self, execution_info: portable_data_types.ExecutionInfo
    ) -> driver_output_pb2.DriverOutput:

        # Populate exec_properties
        result = driver_output_pb2.DriverOutput()
        # PipelineInfo and ComponentInfo are not actually used; two fake ones are
        # created just to be compatible with the old API.
        pipeline_info = data_types.PipelineInfo('', '')
        component_info = data_types.ComponentInfo('', '', pipeline_info)
        exec_properties = self.resolve_exec_properties(
            execution_info.exec_properties, pipeline_info, component_info)
        for k, v in exec_properties.items():
            if v is not None:
                data_types_utils.set_metadata_value(result.exec_properties[k],
                                                    v)

        # Populate output_dict
        output_example = copy.deepcopy(execution_info.output_dict[
            standard_component_specs.EXAMPLES_KEY][0].mlmd_artifact)
        update_output_artifact(exec_properties, output_example)
        result.output_artifacts[standard_component_specs.
                                EXAMPLES_KEY].artifacts.append(output_example)
        return result
Example #18
def _generate_context_proto(
        metadata_handler: metadata.Metadata,
        context_spec: pipeline_pb2.ContextSpec) -> metadata_store_pb2.Context:
    """Generates metadata_pb2.Context based on the ContextSpec message.

    Args:
      metadata_handler: A handler to access MLMD store.
      context_spec: A pipeline_pb2.ContextSpec message that instructs registering
        of a context.

    Returns:
      A metadata_store_pb2.Context message.

    Raises:
      RuntimeError: When actual property type does not match provided metadata
        type schema.
    """
    context_type = common_utils.register_type_if_not_exist(
        metadata_handler, context_spec.type)
    context_name = data_types_utils.get_value(context_spec.name)
    assert isinstance(context_name, Text), 'context name should be string.'
    result = metadata_store_pb2.Context(type_id=context_type.id,
                                        name=context_name)
    for k, v in context_spec.properties.items():
        if k in context_type.properties:
            actual_property_type = data_types_utils.get_metadata_value_type(v)
            if context_type.properties.get(k) == actual_property_type:
                data_types_utils.set_metadata_value(result.properties[k], v)
            else:
                raise RuntimeError(
                    'Property type %s different from provided metadata type property type %s for key %s'
                    %
                    (actual_property_type, context_type.properties.get(k), k))
        else:
            data_types_utils.set_metadata_value(result.custom_properties[k], v)
    return result
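A hedged illustration (import paths assumed) of the routing above: values for keys declared in the context type land in Context.properties, everything else in Context.custom_properties; either way set_metadata_value fills the underlying Value proto. The 'owner' key is hypothetical.

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration import data_types_utils

context = metadata_store_pb2.Context(name='my_pipeline')
# 'owner' is a hypothetical custom property used only for illustration.
data_types_utils.set_metadata_value(context.custom_properties['owner'], 'alice')
print(context.custom_properties['owner'].string_value)  # alice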
Example #19
 def initiate_stop(self, status: status_lib.Status) -> None:
     """Updates pipeline state to signal stopping pipeline execution."""
     self._check_context()
     data_types_utils.set_metadata_value(
         self._execution.custom_properties[_STOP_INITIATED], 1)
     data_types_utils.set_metadata_value(
         self._execution.custom_properties[_PIPELINE_STATUS_CODE],
         int(status.code))
     if status.message:
         data_types_utils.set_metadata_value(
             self._execution.custom_properties[_PIPELINE_STATUS_MSG],
             status.message)
Example #20
File: compiler.py Project: jasonz1112/tfx
    def _compile_node(
        self,
        tfx_node: base_node.BaseNode,
        compile_context: _CompilerContext,
        deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
        enable_cache: bool,
    ) -> pipeline_pb2.PipelineNode:
        """Compiles an individual TFX node into a PipelineNode proto.

        Args:
          tfx_node: A TFX node.
          compile_context: Resources needed to compile the node.
          deployment_config: Intermediate deployment config to set. Will include
            related specs for executors, drivers and platform specific configs.
          enable_cache: Whether cache is enabled.

        Raises:
          TypeError: When supplied tfx_node has values of invalid type.

        Returns:
          A PipelineNode proto that encodes information of the node.
        """
        node = pipeline_pb2.PipelineNode()

        # Step 1: Node info
        node.node_info.type.name = tfx_node.type
        node.node_info.id = tfx_node.id

        # Step 2: Node Context
        # Context for the pipeline, across pipeline runs.
        pipeline_context_pb = node.contexts.contexts.add()
        pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
        pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

        # Context for the current pipeline run.
        if compile_context.is_sync_mode:
            pipeline_run_context_pb = node.contexts.contexts.add()
            pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
            compiler_utils.set_runtime_parameter_pb(
                pipeline_run_context_pb.name.runtime_parameter,
                constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

        # Context for the node, across pipeline runs.
        node_context_pb = node.contexts.contexts.add()
        node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
        node_context_pb.name.field_value.string_value = "{}.{}".format(
            compile_context.pipeline_info.pipeline_context_name,
            node.node_info.id)

        # Pre Step 3: Alter graph topology if needed.
        if compile_context.is_async_mode:
            tfx_node_inputs = self._compile_resolver_config(
                compile_context, tfx_node, node)
        else:
            tfx_node_inputs = tfx_node.inputs

        # Step 3: Node inputs
        for key, value in tfx_node_inputs.items():
            input_spec = node.inputs.inputs[key]
            channel = input_spec.channels.add()
            if value.producer_component_id:
                channel.producer_node_query.id = value.producer_component_id

                # Here we rely on pipeline.components to be topologically sorted.
                assert value.producer_component_id in compile_context.node_pbs, (
                    "producer component should have already been compiled.")
                producer_pb = compile_context.node_pbs[
                    value.producer_component_id]
                for producer_context in producer_pb.contexts.contexts:
                    if (not compiler_utils.is_resolver(tfx_node)
                            or producer_context.name.runtime_parameter.name !=
                            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
                        context_query = channel.context_queries.add()
                        context_query.type.CopyFrom(producer_context.type)
                        context_query.name.CopyFrom(producer_context.name)
            else:
                # Caveat: portable core requires every channel to have at least one
                # Context. But for cases like system nodes and producer-consumer
                # pipelines, a channel may not have contexts at all. In these cases,
                # we want to use the pipeline level context as the input channel
                # context.
                context_query = channel.context_queries.add()
                context_query.type.CopyFrom(pipeline_context_pb.type)
                context_query.name.CopyFrom(pipeline_context_pb.name)

            artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
            channel.artifact_query.type.CopyFrom(artifact_type)
            channel.artifact_query.type.ClearField("properties")

            if value.output_key:
                channel.output_key = value.output_key

            # TODO(b/158712886): Calculate min_count based on if inputs are optional.
            # min_count = 0 stands for optional input and 1 stands for required input.

        # Step 3.1: Special treatment for Resolver node.
        if compiler_utils.is_resolver(tfx_node):
            assert compile_context.is_sync_mode
            node.inputs.resolver_config.resolver_steps.extend(
                _convert_to_resolver_steps(tfx_node))

        # Step 4: Node outputs
        if isinstance(tfx_node, base_component.BaseComponent):
            for key, value in tfx_node.outputs.items():
                output_spec = node.outputs.outputs[key]
                artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
                output_spec.artifact_spec.type.CopyFrom(artifact_type)
                for prop_key, prop_value in value.additional_properties.items(
                ):
                    _check_property_value_type(prop_key, prop_value,
                                               output_spec.artifact_spec.type)
                    data_types_utils.set_metadata_value(
                        output_spec.artifact_spec.
                        additional_properties[prop_key].field_value,
                        prop_value)
                for prop_key, prop_value in value.additional_custom_properties.items(
                ):
                    data_types_utils.set_metadata_value(
                        output_spec.artifact_spec.
                        additional_custom_properties[prop_key].field_value,
                        prop_value)

        # TODO(b/170694459): Refactor special nodes as plugins.
        # Step 4.1: Special treatment for Importer node
        if compiler_utils.is_importer(tfx_node):
            self._compile_importer_node_outputs(tfx_node, node)

        # Step 5: Node parameters
        if not compiler_utils.is_resolver(tfx_node):
            for key, value in tfx_node.exec_properties.items():
                if value is None:
                    continue
                # Ignore the following two properties for an importer node, because they are
                # already attached to the artifacts produced by the importer node.
                if compiler_utils.is_importer(tfx_node) and (
                        key == importer.PROPERTIES_KEY
                        or key == importer.CUSTOM_PROPERTIES_KEY):
                    continue
                parameter_value = node.parameters.parameters[key]

                # Order matters, because runtime parameter can be in serialized string.
                if isinstance(value, data_types.RuntimeParameter):
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, value.name,
                        value.ptype, value.default)
                elif isinstance(value, str) and re.search(
                        data_types.RUNTIME_PARAMETER_PATTERN, value):
                    runtime_param = json.loads(value)
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, runtime_param.name,
                        runtime_param.ptype, runtime_param.default)
                else:
                    try:
                        data_types_utils.set_metadata_value(
                            parameter_value.field_value, value)
                    except ValueError:
                        raise ValueError(
                            "Component {} got unsupported parameter {} with type {}."
                            .format(tfx_node.id, key, type(value)))

        # Step 6: Executor spec and optional driver spec for components
        if isinstance(tfx_node, base_component.BaseComponent):
            executor_spec = tfx_node.executor_spec.encode(
                component_spec=tfx_node.spec)
            deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

            # TODO(b/163433174): Remove specialized logic once generalization of
            # driver spec is done.
            if tfx_node.driver_class != base_driver.BaseDriver:
                driver_class_path = "{}.{}".format(
                    tfx_node.driver_class.__module__,
                    tfx_node.driver_class.__name__)
                driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
                driver_spec.class_path = driver_class_path
                deployment_config.custom_driver_specs[tfx_node.id].Pack(
                    driver_spec)

        # Step 7: Upstream/Downstream nodes
        # Note: the order of tfx_node.upstream_nodes is inconsistent from
        # run to run. We sort them so that compiler generates consistent results.
        # For ASYNC mode upstream/downstream node information is not set as
        # compiled IR graph topology can be different from that on pipeline
        # authoring time; for example ResolverNode is removed.
        if compile_context.is_sync_mode:
            node.upstream_nodes.extend(
                sorted(node.id for node in tfx_node.upstream_nodes))
            node.downstream_nodes.extend(
                sorted(node.id for node in tfx_node.downstream_nodes))

        # Step 8: Node execution options
        node.execution_options.caching_options.enable_cache = enable_cache

        # Step 9: Per-node platform config
        if isinstance(tfx_node, base_component.BaseComponent):
            tfx_component = cast(base_component.BaseComponent, tfx_node)
            if tfx_component.platform_config:
                deployment_config.node_level_platform_configs[
                    tfx_node.id].Pack(tfx_component.platform_config)

        return node
Example #21
 def testSetMetadataValueUnsupportedType(self):
     pb = metadata_store_pb2.Value()
     with self.assertRaises(ValueError):
         data_types_utils.set_metadata_value(pb, {'a': 1})
Example #22
 def testSetMetadataValueWithPrimitiveValue(self, value, expected_pb):
     pb = metadata_store_pb2.Value()
     data_types_utils.set_metadata_value(pb, value)
     self.assertEqual(pb, expected_pb)
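A quick companion sketch to the parameterized test above (import paths assumed), showing which field of metadata_store_pb2.Value set_metadata_value is expected to populate for each primitive Python type.

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration import data_types_utils

for raw, field in [(1, 'int_value'), (0.5, 'double_value'), ('abc', 'string_value')]:
    pb = metadata_store_pb2.Value()
    data_types_utils.set_metadata_value(pb, raw)
    print(field, getattr(pb, field))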
Example #23
 def set_mlmd_value(
         self, value: metadata_store_pb2.Value) -> metadata_store_pb2.Value:
     data_types_utils.set_metadata_value(value, json_utils.dumps(self))
     return value
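The same idea sketched with the standard json module instead of tfx's json_utils (import paths assumed): a JSON-serialized object is just a string, so set_metadata_value stores it in string_value.

import json

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration import data_types_utils

value = metadata_store_pb2.Value()
data_types_utils.set_metadata_value(value, json.dumps({'split': 'train'}))
print(value.string_value)  # {"split": "train"}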
Example #24
    def new(
        cls,
        mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline,
        pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
    ) -> 'PipelineState':
        """Creates a `PipelineState` object for a new pipeline.

        No active pipeline with the same pipeline uid should exist for the call to
        be successful.

        Args:
          mlmd_handle: A handle to the MLMD db.
          pipeline: IR of the pipeline.
          pipeline_run_metadata: Pipeline run metadata.

        Returns:
          A `PipelineState` object.

        Raises:
          status_lib.StatusNotOkError: If a pipeline with same UID already exists.
        """
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        context = context_lib.register_context_if_not_exists(
            mlmd_handle,
            context_type_name=_ORCHESTRATOR_RESERVED_ID,
            context_name=orchestrator_context_name(pipeline_uid))

        executions = mlmd_handle.store.get_executions_by_context(context.id)
        if any(e for e in executions if execution_lib.is_execution_active(e)):
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.ALREADY_EXISTS,
                message=f'Pipeline with uid {pipeline_uid} already active.')

        exec_properties = {_PIPELINE_IR: _base64_encode(pipeline)}
        if pipeline_run_metadata:
            exec_properties[_PIPELINE_RUN_METADATA] = json_utils.dumps(
                pipeline_run_metadata)

        execution = execution_lib.prepare_execution(
            mlmd_handle,
            _ORCHESTRATOR_EXECUTION_TYPE,
            metadata_store_pb2.Execution.NEW,
            exec_properties=exec_properties)
        if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
            data_types_utils.set_metadata_value(
                execution.custom_properties[_PIPELINE_RUN_ID],
                pipeline.runtime_spec.pipeline_run_id.field_value.string_value)
            # Set the node state to COMPLETE for any nodes that are marked to be
            # skipped in a partial pipeline run.
            node_states_dict = {}
            for node in get_all_pipeline_nodes(pipeline):
                if node.execution_options.HasField('skip'):
                    logging.info('Node %s is skipped in this partial run.',
                                 node.node_info.id)
                    node_states_dict[node.node_info.id] = NodeState(
                        state=NodeState.COMPLETE)
            if node_states_dict:
                _save_node_states_dict(execution, node_states_dict)

        execution = execution_lib.put_execution(mlmd_handle, execution,
                                                [context])
        record_state_change_time()

        return cls(mlmd_handle=mlmd_handle,
                   pipeline=pipeline,
                   execution_id=execution.id)
Example #25
 def initiate_stop(self):
     """Updates pipeline state to signal stopping pipeline execution."""
     data_types_utils.set_metadata_value(
         self.execution.custom_properties[_STOP_INITIATED], 1)
     self._commit = True
Example #26
def _save_node_states_dict(pipeline_execution: metadata_store_pb2.Execution,
                           node_states: Dict[str, NodeState]) -> None:
    data_types_utils.set_metadata_value(
        pipeline_execution.custom_properties[_NODE_STATES],
        json_utils.dumps(node_states))