예제 #1
0
 def testGetValueFailed(self):
     # A Value that only carries a runtime_parameter (and no field_value)
     # cannot be unwrapped by get_value, which must raise RuntimeError.
     value_proto = text_format.Parse(
         """
     runtime_parameter {
       name: 'rp'
     }""", pipeline_pb2.Value())
     expected_error = 'Expecting field_value but got'
     with self.assertRaisesRegex(RuntimeError, expected_error):
         common_utils.get_value(value_proto)
예제 #2
0
def get_executions(
        metadata_handler: metadata.Metadata,
        node: pipeline_pb2.PipelineNode) -> List[metadata_store_pb2.Execution]:
    """Returns all executions for the given pipeline node.

    Every context declared on ``node`` is looked up in MLMD by type and name,
    and the executions associated with all of those contexts are returned.

    Args:
      metadata_handler: A handler to access MLMD db.
      node: The pipeline node for which to obtain executions.

    Returns:
      List of executions for the given node in MLMD db. An empty list is
      returned immediately if any of the node's contexts is not registered.
    """
    found_contexts = []
    for spec in node.contexts.contexts:
        found = metadata_handler.store.get_context_by_type_and_name(
            spec.type.name, common_utils.get_value(spec.name))
        if found is None:
            # A missing context means no execution can be attached to every
            # one of the node's contexts, so there is nothing to query.
            return []
        found_contexts.append(found)
    return execution_lib.get_executions_associated_with_all_contexts(
        metadata_handler, found_contexts)
예제 #3
0
def _extract_properties(
        execution: metadata_store_pb2.Execution) -> Dict[Text, types.Property]:
    """Flattens an execution's properties into a plain dict.

    Both ``properties`` and ``custom_properties`` are included, and each MLMD
    value is unwrapped with ``common_utils.get_value``. Regular properties are
    visited first, so if a key appeared in both maps the custom property's
    value would be the one kept.
    """
    all_items = itertools.chain(execution.properties.items(),
                                execution.custom_properties.items())
    return {name: common_utils.get_value(value) for name, value in all_items}
예제 #4
0
def _register_context_if_not_exist(
    metadata_handler: metadata.Metadata,
    context_spec: pipeline_pb2.ContextSpec,
) -> metadata_store_pb2.Context:
    """Registers a context if not exist, otherwise returns the existing one.

    Args:
      metadata_handler: A handler to access MLMD store.
      context_spec: A pipeline_pb2.ContextSpec message that instructs
        registering of a context.

    Returns:
      An MLMD context. If ``put_contexts`` succeeds, this is the freshly built
      context with its ``id`` filled in. If MLMD reports the context already
      exists, the stored context is fetched by type name and context name.
    """
    context = _generate_context_proto(metadata_handler=metadata_handler,
                                      context_spec=context_spec)
    try:
        [context_id] = metadata_handler.store.put_contexts([context])
    except mlmd.errors.AlreadyExistsError:
        # put_contexts rejected the context as a duplicate; look up the one
        # that is already stored instead.
        context_name = common_utils.get_value(context_spec.name)
        logging.debug('Context %s already exists.', context_name)
        context = metadata_handler.store.get_context_by_type_and_name(
            type_name=context_spec.type.name, context_name=context_name)
        # Report the resolved name rather than the raw Value proto so the
        # failure message is readable.
        assert context is not None, (
            'Context is missing for %s while put_contexts '
            'reports that it existed.') % (context_name,)
    else:
        # Only the successful-insert path needs the newly assigned id.
        context.id = context_id

    logging.debug('ID of context %s is %s.', context_spec, context.id)
    return context
예제 #5
0
 def testGetValue(self):
     # A Value whose field_value holds an int should unwrap to that int.
     value_proto = text_format.Parse(
         """
     field_value {
       int_value: 1
     }""", pipeline_pb2.Value())
     unwrapped = common_utils.get_value(value_proto)
     self.assertEqual(unwrapped, 1)
예제 #6
0
def _resolve_single_channel(
        metadata_handler: metadata.Metadata,
        channel: pipeline_pb2.InputSpec.Channel) -> List[types.Artifact]:
    """Resolves input artifacts from a single channel.

    Args:
      metadata_handler: A handler to access MLMD store.
      channel: The input channel whose context queries and artifact query
        select the artifacts to resolve.

    Returns:
      The result of ``get_qualified_artifacts`` for the channel's artifact
      type and output key, restricted to the channel's contexts that exist
      in MLMD.
    """
    artifact_type = channel.artifact_query.type
    # An empty output_key is normalized to None.
    output_key = channel.output_key or None
    # Lookups that find nothing (None) are dropped. The result must be a real
    # list: a bare ``filter`` object is always truthy, even when it yields
    # nothing, and can only be iterated once.
    contexts = list(filter(None, [
        metadata_handler.store.get_context_by_type_and_name(
            context_query.type.name, common_utils.get_value(
                context_query.name))
        for context_query in channel.context_queries
    ]))
    return get_qualified_artifacts(metadata_handler=metadata_handler,
                                   contexts=contexts,
                                   artifact_type=artifact_type,
                                   output_key=output_key)
예제 #7
0
def _resolve_single_channel(
    metadata_handler: metadata.Metadata,
    channel: pipeline_pb2.InputSpec.Channel) -> List[types.Artifact]:
  """Resolves input artifacts from a single channel.

  Args:
    metadata_handler: A handler to access MLMD store.
    channel: The input channel whose context queries and artifact query
      select the artifacts to resolve.

  Returns:
    The result of ``get_qualified_artifacts`` for the channel's artifact
    type and output key, restricted to the channel's contexts that are
    found in MLMD.
  """
  artifact_type = channel.artifact_query.type
  # An empty output_key is normalized to None.
  output_key = channel.output_key or None

  # Keep only the lookups that returned a truthy context. Collecting into a
  # concrete list (not a lazy iterator) keeps an empty result falsy.
  contexts = []
  for context_query in channel.context_queries:
    context = metadata_handler.store.get_context_by_type_and_name(
        context_query.type.name, common_utils.get_value(context_query.name))
    if context:
      contexts.append(context)

  return get_qualified_artifacts(
      metadata_handler=metadata_handler,
      contexts=contexts,
      artifact_type=artifact_type,
      output_key=output_key)
예제 #8
0
def _generate_context_proto(
        metadata_handler: metadata.Metadata,
        context_spec: pipeline_pb2.ContextSpec) -> metadata_store_pb2.Context:
    """Builds a metadata_store_pb2.Context from a ContextSpec message.

    The context type is registered through
    ``common_utils.register_type_if_not_exist``. The context itself is only
    constructed here; this function does not write it to MLMD.

    Each spec property whose key is declared on the context type goes into
    ``properties``, and its value type must match the declared type. All
    other keys go into ``custom_properties``.

    Args:
      metadata_handler: A handler to access MLMD store.
      context_spec: A pipeline_pb2.ContextSpec message that instructs
        registering of a context.

    Returns:
      A metadata_store_pb2.Context message.

    Raises:
      RuntimeError: When actual property type does not match provided metadata
        type schema.
      AssertionError: When the resolved context name is not a string (only
        while assertions are enabled).
    """
    context_type = common_utils.register_type_if_not_exist(
        metadata_handler, context_spec.type)
    context_name = common_utils.get_value(context_spec.name)
    assert isinstance(context_name, Text), 'context name should be string.'
    context = metadata_store_pb2.Context(type_id=context_type.id,
                                         name=context_name)

    for key, value in context_spec.properties.items():
        if key not in context_type.properties:
            # Keys not declared on the type are stored as custom properties.
            common_utils.set_metadata_value(context.custom_properties[key],
                                            value)
            continue

        actual_type = common_utils.get_metadata_value_type(value)
        expected_type = context_type.properties.get(key)
        if expected_type != actual_type:
            raise RuntimeError(
                'Property type %s different from provided metadata type '
                'property type %s for key %s'
                % (actual_type, expected_type, key))
        common_utils.set_metadata_value(context.properties[key], value)

    return context