示例#1
0
 def testRunResolverSteps_NoneRaisesSignal(self):
     """Running a single NoneStrategy resolver step raises InputResolutionError."""
     resolver_config = pipeline_pb2.ResolverConfig()
     step_spec = """
 resolver_steps {
   class_path: "tfx.orchestration.portable.input_resolution.processor_test.NoneStrategy"
 }
 """
     text_format.Parse(step_spec, resolver_config)
     with self.assertRaises(exceptions.InputResolutionError):
         processor.run_resolver_steps(
             self._input_dict,
             resolver_steps=resolver_config.resolver_steps,
             store=self._store,
         )
示例#2
0
def resolve_input_artifacts_v2(
    *,
    pipeline_node: pipeline_pb2.PipelineNode,
    metadata_handler: metadata.Metadata,
) -> Union[Trigger, Skip]:
  """Resolves input artifacts for a node from its pipeline IR definition.

  Resolution proceeds in three stages:

  1. The initial input dict (Mapping[str, Sequence[Artifact]]) is built from
     the input channel definitions in NodeInputs.inputs.
  2. The resolver steps in NodeInputs.resolver_config, if any, are run on
     that dict.
  3. Each resulting input dict is kept only if it holds enough artifacts for
     every input, as required by NodeInputs.inputs[...].min_count.

  Args:
    pipeline_node: The PipelineNode whose inputs are being resolved.
    metadata_handler: Metadata handler used to access the MLMD store.

  Raises:
    InputResolutionError: If a resolver step raises InputResolutionError, or
        raises any other Exception (which is wrapped in InputResolutionError
        with the original exception chained as the cause).
    FailedPreconditionError: If the resolver steps produce something that is
        neither an ArtifactMultiMap nor a sequence of them, or if no resolved
        input dict satisfies the min_count requirements.

  Returns:
    Trigger: a non-empty list of input dicts; each of them should be
        executed.
    Skip: an empty result, returned when a resolver step raises SkipSignal;
        the current component execution should be skipped.
  """
  node_inputs = pipeline_node.inputs
  # Errors while building the initial dict are not wrapped; only the resolver
  # steps below are guarded by the try/except.
  initial_dict = _resolve_initial_dict(metadata_handler, node_inputs)
  try:
    result = processor.run_resolver_steps(
        initial_dict,
        resolver_steps=node_inputs.resolver_config.resolver_steps,
        store=metadata_handler.store)
  except exceptions.SkipSignal:
    # A resolver step explicitly requested that this execution be skipped.
    return Skip()
  except exceptions.InputResolutionError:
    # Already the expected error type; propagate as-is instead of re-wrapping.
    raise
  except Exception as e:
    raise exceptions.InputResolutionError(
        f'Error occurred during input resolution: {str(e)}.') from e

  # Normalize a single input dict into a one-element list of input dicts.
  candidates = [result] if typing_utils.is_artifact_multimap(result) else result
  if not typing_utils.is_list_of_artifact_multimap(candidates):
    raise exceptions.FailedPreconditionError(
        'Invalid input resolution result; expected Sequence[ArtifactMultiMap] '
        f'type but got {candidates}.')

  valid_inputs = [
      input_dict for input_dict in candidates
      if _is_sufficient(input_dict, node_inputs)
  ]
  if valid_inputs:
    return Trigger(valid_inputs)
  raise exceptions.FailedPreconditionError('No valid inputs.')
示例#3
0
def resolve_input_artifacts(
    metadata_handler: metadata.Metadata, node_inputs: pipeline_pb2.NodeInputs
) -> Optional[typing_utils.ArtifactMultiMap]:
  """Resolves input artifacts of a pipeline node (legacy single-dict API).

  The `min_count` requirements are checked against the initial input dict
  only, before any resolver step runs.

  Args:
    metadata_handler: A metadata handler used to access the MLMD store.
    node_inputs: A pipeline_pb2.NodeInputs message describing how the
      artifacts for a pipeline node are resolved.

  Returns:
    The resolved Dict[str, List[Artifact]]. Returns None if the initial input
    dict does not meet `min_count` for every input, or if a resolver step
    raises InputResolutionError or SkipSignal.

  Raises:
    ValueError: If the resolver steps yield a list of input dicts whose
      length is not exactly 1.
    TypeError: If the resolver steps yield neither an input dict nor a list
      of input dicts.
  """
  initial_dict = _resolve_initial_dict(metadata_handler, node_inputs)
  if not _is_sufficient(initial_dict, node_inputs):
    expected_counts = {
        name: spec.min_count for name, spec in node_inputs.inputs.items()
    }
    logging.warning('Resolved inputs should have %r artifacts, but got %r.',
                    expected_counts, initial_dict)
    return None

  try:
    resolved = processor.run_resolver_steps(
        initial_dict,
        resolver_steps=node_inputs.resolver_config.resolver_steps,
        store=metadata_handler.store)
  except (exceptions.InputResolutionError, exceptions.SkipSignal):
    # Legacy input resolution returned None when a ResolverStrategy returned
    # None midway (surfaced as InputResolutionError). SkipSignal has no exact
    # legacy equivalent, so None is the best-effort mapping for it too.
    return None

  # The list form is checked before the single-dict form.
  if typing_utils.is_list_of_artifact_multimap(resolved):
    resolved_list = cast(Sequence[typing_utils.ArtifactMultiMap], resolved)
    if len(resolved_list) == 1:
      return resolved_list[0]
    raise ValueError(
        'Invalid number of resolved inputs; expected 1 but got '
        f'{len(resolved_list)}: {resolved_list}')
  if typing_utils.is_artifact_multimap(resolved):
    return cast(typing_utils.ArtifactMultiMap, resolved)
  raise TypeError(f'Invalid input resolution result: {resolved}. Should be '
                  'Mapping[str, Sequence[Artifact]].')
示例#4
0
    def testRunResolverSteps_ResolverOp_IgnoresInputKeys(self):
        """RepeatOp with input_keys=["examples"] still yields 2 of each key."""
        resolver_config = pipeline_pb2.ResolverConfig()
        step_spec = r"""
    resolver_steps {
      class_path: "tfx.orchestration.portable.input_resolution.processor_test.RepeatOp"
      config_json: "{\"num\": 2}"
      input_keys: ["examples"]
    }
    """
        text_format.Parse(step_spec, resolver_config)
        output = processor.run_resolver_steps(
            self._input_dict,
            resolver_steps=resolver_config.resolver_steps,
            store=self._store,
        )

        # Both keys are affected even though only "examples" is listed.
        for key in ('examples', 'model'):
            self.assertLen(output[key], 2)
示例#5
0
    def testRunResolverSteps_ResolverStrategy(self):
        """RepeatStrategy(num=2) is called with the store and yields 2 per key."""
        resolver_config = pipeline_pb2.ResolverConfig()
        step_spec = r"""
    resolver_steps {
      class_path: "tfx.orchestration.portable.input_resolution.processor_test.RepeatStrategy"
      config_json: "{\"num\": 2}"
    }
    """
        text_format.Parse(step_spec, resolver_config)
        output = processor.run_resolver_steps(
            self._input_dict,
            resolver_steps=resolver_config.resolver_steps,
            store=self._store,
        )

        # The first recorded call of the most recently created strategy must
        # have received self._store as its first argument.
        first_call_args = RepeatStrategy.last_created.call_history[0]
        self.assertIs(first_call_args[0], self._store)
        for key in ('examples', 'model'):
            self.assertLen(output[key], 2)
示例#6
0
    def testRunResolverSteps_ResolverOp(self):
        """RepeatOp gets the store and num=2 from config_json; yields 2 per key."""
        resolver_config = pipeline_pb2.ResolverConfig()
        step_spec = r"""
    resolver_steps {
      class_path: "tfx.orchestration.portable.input_resolution.processor_test.RepeatOp"
      config_json: "{\"num\": 2}"
    }
    """
        text_format.Parse(step_spec, resolver_config)
        output = processor.run_resolver_steps(
            self._input_dict,
            resolver_steps=resolver_config.resolver_steps,
            store=self._store,
        )

        created_op = RepeatOp.last_created
        # The op's context must expose the same store that was passed in.
        self.assertIs(created_op.context.store, self._store)
        # The "num" property must be populated from config_json.
        self.assertEqual(created_op.num, 2)
        for key in ('examples', 'model'):
            self.assertLen(output[key], 2)