示例#1
0
 def testGetValueFailed(self):
     """Checks that get_value rejects a Value holding a runtime parameter.

     The parsed proto sets `runtime_parameter` instead of `field_value`; the
     call is expected to raise a RuntimeError whose message matches
     'Expecting field_value but got'.
     """
     value_text = """
     runtime_parameter {
       name: 'rp'
     }"""
     runtime_param_value = pipeline_pb2.Value()
     text_format.Parse(value_text, runtime_param_value)
     expected_error = 'Expecting field_value but got'
     with self.assertRaisesRegex(RuntimeError, expected_error):
         data_types_utils.get_value(runtime_param_value)
示例#2
0
def get_executions(
        metadata_handler: metadata.Metadata,
        node: pipeline_pb2.PipelineNode) -> List[metadata_store_pb2.Execution]:
    """Looks up the executions that belong to a pipeline node.

    One filter clause is built per context declared on the node (matching on
    context type and context name). The clauses are AND-ed into a single
    filter query and the store is queried once with that filter.

    Args:
      metadata_handler: A handler to access MLMD db.
      node: The pipeline node for which to obtain executions.

    Returns:
      The executions the store returns for the combined filter, or an empty
      list when the node declares no contexts.
    """
    context_specs = node.contexts.contexts
    # Without contexts there is nothing to filter on; skip the store query.
    if not context_specs:
        return []

    def _context_clause(i, context_spec):
        # Each clause refers to its own indexed `contexts_{i}` name.
        context_type = context_spec.type.name
        context_name = data_types_utils.get_value(context_spec.name)
        return f"(contexts_{i}.type = '{context_type}' AND contexts_{i}.name = '{context_name}')"

    clauses = [
        _context_clause(i, context_spec)
        for i, context_spec in enumerate(context_specs)
    ]
    filter_query = ' AND '.join(clauses)
    return metadata_handler.store.get_executions(
        list_options=mlmd.ListOptions(filter_query=filter_query))
示例#3
0
def get_executions(
    metadata_handler: metadata.Metadata,
    node: pipeline_pb2.PipelineNode) -> List[metadata_store_pb2.Execution]:
  """Looks up the executions that belong to a pipeline node.

  Every context declared on the node is fetched from the store by type and
  name. The contexts found are then handed to
  `execution_lib.get_executions_associated_with_all_contexts`.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to obtain executions.

  Returns:
    The executions returned for the node's contexts, or an empty list as soon
    as one of the node's contexts is not found in the store.
  """
  resolved_contexts = []
  for spec in node.contexts.contexts:
    fetch_context = metadata_handler.store.get_context_by_type_and_name
    resolved = fetch_context(spec.type.name,
                             data_types_utils.get_value(spec.name))
    if resolved is None:
      # An unregistered context means no execution can be associated with
      # the node, so stop looking up the remaining contexts.
      return []
    resolved_contexts.append(resolved)
  return execution_lib.get_executions_associated_with_all_contexts(
      metadata_handler, resolved_contexts)
示例#4
0
 def testGetValue(self):
     """Checks that get_value returns the value stored in `field_value`."""
     value_text = """
     field_value {
       int_value: 1
     }"""
     int_value_proto = pipeline_pb2.Value()
     text_format.Parse(value_text, int_value_proto)
     actual = data_types_utils.get_value(int_value_proto)
     self.assertEqual(actual, 1)
示例#5
0
def _resolve_single_channel(
    metadata_handler: metadata.Metadata,
    channel: pipeline_pb2.InputSpec.Channel) -> List[types.Artifact]:
  """Resolves the input artifacts described by a single channel.

  The contexts named by the channel's context queries are looked up in the
  store by type and name; queries whose lookup yields a falsy result are
  dropped. The contexts found, the channel's artifact type and its output key
  are then passed to `get_qualified_artifacts`.

  Args:
    metadata_handler: A handler to access the metadata store.
    channel: The channel whose artifacts should be resolved.

  Returns:
    Whatever `get_qualified_artifacts` returns for the collected arguments.
  """
  artifact_type = channel.artifact_query.type
  # An unset (empty) output key is forwarded as None.
  output_key = channel.output_key or None

  matched_contexts = []
  for query in channel.context_queries:
    lookup = metadata_handler.store.get_context_by_type_and_name
    match = lookup(query.type.name, data_types_utils.get_value(query.name))
    if not match:
      continue
    matched_contexts.append(match)

  return get_qualified_artifacts(
      metadata_handler=metadata_handler,
      contexts=matched_contexts,
      artifact_type=artifact_type,
      output_key=output_key)
示例#6
0
def _register_context_if_not_exist(
    metadata_handler: metadata.Metadata,
    context_spec: pipeline_pb2.ContextSpec,
) -> metadata_store_pb2.Context:
    """Returns the MLMD context for a spec, creating it when it is absent.

    The store is first queried by context type name and context name. If
    nothing is found, a context proto is generated from the spec and written
    with `put_contexts`. If that write reports the context already exists,
    the context is fetched again.

    Args:
      metadata_handler: A handler to access MLMD store.
      context_spec: A pipeline_pb2.ContextSpec message that instructs
        registering of a context.

    Returns:
      An MLMD context.
    """
    context_type_name = context_spec.type.name
    context_name = data_types_utils.get_value(context_spec.name)

    def _fetch_existing():
        # Single place for the by-type-and-name lookup used before and after
        # the registration attempt.
        return metadata_handler.store.get_context_by_type_and_name(
            type_name=context_type_name, context_name=context_name)

    existing = _fetch_existing()
    if existing is not None:
        return existing

    logging.debug('Failed to get context of type %s and name %s',
                  context_type_name, context_name)
    # The lookup came back empty, so attempt to register the context.
    result = _generate_context_proto(
        metadata_handler=metadata_handler, context_spec=context_spec)
    try:
        [new_id] = metadata_handler.store.put_contexts([result])
        result.id = new_id
    except mlmd.errors.AlreadyExistsError:
        # Can happen when nodes execute in parallel and another one
        # registered the same context first.
        logging.debug('Context %s already exists.', context_name)
        result = _fetch_existing()
        assert result is not None, (
            'Context is missing for %s while put_contexts '
            'reports that it existed.') % (context_name)

    logging.debug('ID of context %s is %s.', context_spec, result.id)
    return result
示例#7
0
def _generate_context_proto(
        metadata_handler: metadata.Metadata,
        context_spec: pipeline_pb2.ContextSpec) -> metadata_store_pb2.Context:
    """Builds a metadata_store_pb2.Context from a ContextSpec message.

    The context type is registered if needed, then each property of the spec
    is copied onto the new context. Keys declared on the registered type go
    into `properties`; all other keys go into `custom_properties`.

    Args:
      metadata_handler: A handler to access MLMD store.
      context_spec: A pipeline_pb2.ContextSpec message that instructs
        registering of a context.

    Returns:
      A metadata_store_pb2.Context message.

    Raises:
      RuntimeError: When actual property type does not match provided
        metadata type schema.
    """
    registered_type = common_utils.register_type_if_not_exist(
        metadata_handler, context_spec.type)
    name = data_types_utils.get_value(context_spec.name)
    assert isinstance(name, Text), 'context name should be string.'
    context_proto = metadata_store_pb2.Context(
        type_id=registered_type.id, name=name)

    for key, value in context_spec.properties.items():
        if key not in registered_type.properties:
            # Keys the type does not declare are kept as custom properties.
            data_types_utils.set_metadata_value(
                context_proto.custom_properties[key], value)
            continue

        actual_type = data_types_utils.get_metadata_value_type(value)
        declared_type = registered_type.properties.get(key)
        if declared_type == actual_type:
            data_types_utils.set_metadata_value(
                context_proto.properties[key], value)
            continue

        # Declared key whose value type disagrees with the type's schema.
        raise RuntimeError(
            'Property type %s different from provided metadata type property type %s for key %s'
            % (actual_type, declared_type, key))

    return context_proto
示例#8
0
 def _as_dict(proto_map) -> Dict[str, types.Property]:
     """Converts a proto map into a plain dict of extracted values.

     Each value of `proto_map` is passed through
     `data_types_utils.get_value`; keys are kept unchanged.
     """
     converted = {}
     for key, raw_value in proto_map.items():
         converted[key] = data_types_utils.get_value(raw_value)
     return converted