Пример #1
0
def prepare_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    state: metadata_store_pb2.Execution.State,
    exec_properties: Optional[Mapping[Text, types.Property]] = None,
) -> metadata_store_pb2.Execution:
    """Builds an Execution proto from the given type, state and properties.

    The execution type is registered through
    `common_utils.register_type_if_not_exist` and its id is stored on the
    returned execution.

    Args:
      metadata_handler: A handler to access MLMD store.
      execution_type: A metadata_pb2.ExecutionType message describing the
        type of the execution.
      state: The state to record as `last_known_state`.
      exec_properties: Optional execution properties to attach. `None` is
        treated as an empty mapping.

    Returns:
      A metadata_store_pb2.Execution message.
    """
    execution = metadata_store_pb2.Execution()
    execution.last_known_state = state
    registered_type = common_utils.register_type_if_not_exist(
        metadata_handler, execution_type)
    execution.type_id = registered_type.id

    # A property goes into `properties` only when the type schema declares its
    # key with a matching value type; everything else is a custom property.
    for key, value in (exec_properties or {}).items():
        declared_type = execution_type.properties.get(key)
        if declared_type == common_utils.get_metadata_value_type(value):
            target = execution.properties[key]
        else:
            target = execution.custom_properties[key]
        common_utils.set_metadata_value(target, value)

    logging.debug('Prepared EXECUTION:\n %s', execution)
    return execution
Пример #2
0
Файл: driver.py Проект: lre/tfx
    def run(
            self, input_dict: Dict[Text, List[types.Artifact]],
            output_dict: Dict[Text, List[types.Artifact]],
            exec_properties: Dict[Text, Any]
    ) -> driver_output_pb2.DriverOutput:
        """Resolves exec properties and assembles the driver output.

        Args:
          input_dict: Input artifacts; not read by this method.
          output_dict: Output artifacts; the first entry under
            `utils.EXAMPLES_KEY` is copied into the result.
          exec_properties: Execution properties passed to
            `resolve_exec_properties`.

        Returns:
          A DriverOutput holding the non-None resolved exec properties and
          the updated examples output artifact.
        """
        # PipelineInfo and ComponentInfo are not actually used; two fake ones
        # are created just to be compatible with the old API.
        pipeline_info = data_types.PipelineInfo('', '')
        component_info = data_types.ComponentInfo('', '', pipeline_info)
        resolved_properties = self.resolve_exec_properties(
            exec_properties, pipeline_info, component_info)

        result = driver_output_pb2.DriverOutput()

        # Populate exec_properties, skipping entries that resolved to None.
        for name, value in resolved_properties.items():
            if value is None:
                continue
            common_utils.set_metadata_value(result.exec_properties[name],
                                            value)

        # Populate output artifacts from a copy, so the caller's artifact is
        # not the object handed to _update_output_artifact.
        examples_artifact = copy.deepcopy(
            output_dict[utils.EXAMPLES_KEY][0].mlmd_artifact)
        _update_output_artifact(resolved_properties, examples_artifact)
        examples_output = result.output_artifacts[utils.EXAMPLES_KEY]
        examples_output.artifacts.append(examples_artifact)
        return result
Пример #3
0
def substitute_runtime_parameter(
        msg: message.Message,
        parameter_bindings: Mapping[str, types.Property]) -> None:
    """Replaces runtime parameter placeholders in `msg` with bound values.

    The message is modified in place. A `pipeline_pb2.Value` holding a
    `runtime_parameter` or `structural_runtime_parameter` is cleared and its
    `field_value` is set to the resolved value; when no value is resolved the
    message is left untouched. Any other message is walked recursively.

    Args:
      msg: The original message to change. Only messages defined under
        pipeline_pb2 will be supported. Other types will result in no-op.
      parameter_bindings: A dict of parameter keys to parameter values that
        will be used to substitute the runtime parameter placeholder.

    Returns:
      None
    """
    if not isinstance(msg, message.Message):
        return

    # A pipeline_pb2.Value is a leaf: substitute it (if applicable) and stop.
    if isinstance(msg, pipeline_pb2.Value):
        placeholder = cast(pipeline_pb2.Value, msg)
        which = placeholder.WhichOneof('value')
        if which == 'runtime_parameter':
            resolved = _get_runtime_parameter_value(
                placeholder.runtime_parameter, parameter_bindings)
        elif which == 'structural_runtime_parameter':
            resolved = _get_structural_runtime_parameter_value(
                placeholder.structural_runtime_parameter, parameter_bindings)
        else:
            return
        if resolved is not None:
            placeholder.Clear()
            common_utils.set_metadata_value(
                metadata_value=placeholder.field_value, value=resolved)
        return

    # For other cases, recursively call into sub-messages if any.
    for field, field_content in msg.ListFields():
        # Only message-typed fields can contain placeholders.
        if field.type != descriptor.FieldDescriptor.TYPE_MESSAGE:
            continue

        entry_type = field.message_type
        if entry_type.has_options and entry_type.GetOptions().map_entry:
            # Map field: visit every value, looked up lazily by key.
            children = (field_content[key] for key in field_content)
        elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
            # Repeated field: visit every element.
            children = field_content
        else:
            # Singular sub-message.
            children = (field_content,)

        for child in children:
            substitute_runtime_parameter(child, parameter_bindings)
Пример #4
0
 def testSetMetadataValueWithTfxValue(self):
     """A pipeline Value carrying a field_value is copied into the MLMD value."""
     source_value = text_format.Parse(
         """
     field_value {
         int_value: 1
     }""", pipeline_pb2.Value())
     target_value = metadata_store_pb2.Value()

     common_utils.set_metadata_value(metadata_value=target_value,
                                     value=source_value)

     self.assertProtoEquals('int_value: 1', target_value)
Пример #5
0
 def testSetMetadataValueWithTfxValueFailed(self):
     """A pipeline Value holding a runtime_parameter is rejected."""
     unresolved_value = text_format.Parse(
         """
     runtime_parameter {
       name: 'rp'
     }""", pipeline_pb2.Value())
     target_value = metadata_store_pb2.Value()

     expected_error = self.assertRaisesRegex(RuntimeError,
                                             'Expecting field_value but got')
     with expected_error:
         common_utils.set_metadata_value(metadata_value=target_value,
                                         value=unresolved_value)
Пример #6
0
def _generate_context_proto(
        metadata_handler: metadata.Metadata,
        context_spec: pipeline_pb2.ContextSpec) -> metadata_store_pb2.Context:
    """Builds a metadata_store_pb2.Context from a ContextSpec message.

    The context type in the spec is registered through
    `common_utils.register_type_if_not_exist`. Spec properties whose key is
    declared on the registered type become typed properties; all others become
    custom properties.

    Args:
      metadata_handler: A handler to access MLMD store.
      context_spec: A pipeline_pb2.ContextSpec message that instructs
        registering of a context.

    Returns:
      A metadata_store_pb2.Context message.

    Raises:
      RuntimeError: When actual property type does not match provided metadata
        type schema.
    """
    context_type = common_utils.register_type_if_not_exist(
        metadata_handler, context_spec.type)

    context_name = common_utils.get_value(context_spec.name)
    assert isinstance(context_name, Text), 'context name should be string.'

    context = metadata_store_pb2.Context(type_id=context_type.id,
                                         name=context_name)
    declared_properties = context_type.properties

    for key, value in context_spec.properties.items():
        # Keys unknown to the type schema are stored as custom properties.
        if key not in declared_properties:
            common_utils.set_metadata_value(context.custom_properties[key],
                                            value)
            continue

        # Declared keys must carry a value of the declared type.
        actual_type = common_utils.get_metadata_value_type(value)
        declared_type = declared_properties.get(key)
        if declared_type != actual_type:
            raise RuntimeError(
                'Property type %s different from provided metadata type property type %s for key %s'
                % (actual_type, declared_type, key))
        common_utils.set_metadata_value(context.properties[key], value)

    return context
Пример #7
0
 def testSetMetadataValueWithPrimitiveValue(self):
     """A plain Python int is stored as int_value on the MLMD value."""
     target_value = metadata_store_pb2.Value()

     common_utils.set_metadata_value(metadata_value=target_value, value=1)

     self.assertProtoEquals('int_value: 1', target_value)
Пример #8
0
def _set_stop_initiated_property(
        execution: metadata_store_pb2.Execution) -> None:
    """Sets the `_STOP_INITIATED` custom property of `execution` to 1.

    Args:
      execution: The execution proto to modify in place.
    """
    stop_flag = execution.custom_properties[_STOP_INITIATED]
    common_utils.set_metadata_value(stop_flag, 1)
Пример #9
0
 def initiate_stop(self):
     """Updates pipeline state to signal stopping pipeline execution.

     Sets the `_STOP_INITIATED` custom property on `self.execution` to 1 and
     sets `self._commit` to True.
     """
     stop_flag = self.execution.custom_properties[_STOP_INITIATED]
     common_utils.set_metadata_value(stop_flag, 1)
     self._commit = True