Example #1
0
def _run_resolver_strategy(
    input_dict: typing_utils.ArtifactMultiMap,
    *,
    strategy: resolver.ResolverStrategy,
    input_keys: Iterable[str],
    store: mlmd.MetadataStore,
) -> typing_utils.ArtifactMultiMap:
    """Runs a single ResolverStrategy with MLMD store.

    Only the entries of `input_dict` whose key appears in `input_keys` are
    handed to the strategy. The remaining entries bypass the strategy and are
    merged back into its result unchanged. A falsy `input_keys` (for example
    an empty list) selects every key of `input_dict`.

    Args:
      input_dict: Mapping from input key to a sequence of artifacts.
      strategy: Strategy whose `resolve_artifacts(store, inputs)` is invoked.
      input_keys: Keys of `input_dict` to pass to the strategy.
      store: MLMD store forwarded to the strategy.

    Returns:
      The mapping returned by the strategy, updated in place with the
      bypassed entries.

    Raises:
      TypeError: If `input_dict` fails `typing_utils.is_artifact_multimap`.
      InputResolutionError: If the strategy returns None.
    """
    if not typing_utils.is_artifact_multimap(input_dict):
        raise TypeError(f'Invalid argument type: {input_dict!r}. Must be '
                        'Mapping[str, Sequence[Artifact]].')
    # Materialize the selected keys exactly once. `input_keys` is only promised
    # to be an Iterable: a one-shot iterator would be consumed by repeated
    # membership tests, and a list would make every lookup a linear scan.
    valid_keys = set(input_keys) if input_keys else set(input_dict)
    valid_inputs = {}
    bypassed_inputs = {}
    # Single pass: each entry is copied into exactly one of the two buckets.
    for key, value in input_dict.items():
        bucket = valid_inputs if key in valid_keys else bypassed_inputs
        bucket[key] = list(value)
    result = strategy.resolve_artifacts(store, valid_inputs)
    if result is None:
        raise exceptions.InputResolutionError(f'{strategy} returned None.')
    result.update(bypassed_inputs)
    return result
Example #2
0
 def is_acceptable(self, value: Any) -> bool:
     """Report whether `value` is acceptable for this data type.

     Args:
       value: The object to check.

     Returns:
       The result of the `typing_utils` predicate matching this data type.

     Raises:
       NotImplementedError: If no predicate is defined for this data type.
     """
     # Pick the predicate first; it is only invoked once a match is found.
     checker = None
     if self == self.ARTIFACT_LIST:
         checker = typing_utils.is_homogeneous_artifact_list
     elif self == self.ARTIFACT_MULTIMAP:
         checker = typing_utils.is_artifact_multimap
     elif self == self.ARTIFACT_MULTIMAP_LIST:
         checker = typing_utils.is_list_of_artifact_multimap
     if checker is None:
         raise NotImplementedError(f'Cannot check type for {self}.')
     return checker(value)
Example #3
0
def resolve_input_artifacts_v2(
    *,
    pipeline_node: pipeline_pb2.PipelineNode,
    metadata_handler: metadata.Metadata,
) -> Union[Trigger, Skip]:
  """Resolves the input artifacts of a pipeline node from its IR definition.

  The resolution proceeds in three stages:

  1. Build the initial input dict (Mapping[str, Sequence[Artifact]]) from the
     node's input definitions via `_resolve_initial_dict`.
  2. Run the resolver steps configured in `NodeInputs.resolver_config` on that
     dict.
  3. Keep only the resolved input dicts that `_is_sufficient` accepts for the
     node inputs.

  Args:
    pipeline_node: The PipelineNode whose inputs are being resolved.
    metadata_handler: Metadata handler providing access to the MLMD store.

  Raises:
    InputResolutionError: If running the resolver steps fails. Unexpected
        exceptions are wrapped in this type; an InputResolutionError raised
        by the steps propagates unchanged.
    FailedPreconditionError: If the resolver output has an unexpected type, or
        if no resolved input dict is sufficient.

  Returns:
    Trigger: wraps the non-empty list of resolved input dicts that passed the
        sufficiency check.
    Skip: returned when the resolver steps raise a SkipSignal.
  """
  inputs_spec = pipeline_node.inputs
  seed_dict = _resolve_initial_dict(metadata_handler, inputs_spec)
  try:
    outcome = processor.run_resolver_steps(
        seed_dict,
        resolver_steps=inputs_spec.resolver_config.resolver_steps,
        store=metadata_handler.store)
  except exceptions.SkipSignal:
    return Skip()
  except exceptions.InputResolutionError:
    # Already the right error type; let it propagate untouched.
    raise
  except Exception as exc:
    raise exceptions.InputResolutionError(
        f'Error occurred during input resolution: {exc!s}.') from exc

  # Normalize a single multimap into a one-element list of multimaps.
  if typing_utils.is_artifact_multimap(outcome):
    candidates = [outcome]
  else:
    candidates = outcome
  if not typing_utils.is_list_of_artifact_multimap(candidates):
    raise exceptions.FailedPreconditionError(
        'Invalid input resolution result; expected Sequence[ArtifactMultiMap] '
        f'type but got {candidates}.')

  sufficient = []
  for candidate in candidates:
    if _is_sufficient(candidate, inputs_spec):
      sufficient.append(candidate)
  if not sufficient:
    raise exceptions.FailedPreconditionError('No valid inputs.')
  return Trigger(sufficient)
Example #4
0
def resolve_input_artifacts(
    metadata_handler: metadata.Metadata, node_inputs: pipeline_pb2.NodeInputs
) -> Optional[typing_utils.ArtifactMultiMap]:
  """Resolves the input artifacts of a pipeline node (legacy entry point).

  Args:
    metadata_handler: Metadata handler used to access the MLMD store.
    node_inputs: The pipeline_pb2.NodeInputs message describing how the
      node's artifacts should be resolved.

  Returns:
    The single resolved input dict on success. None if the initial inputs are
    not sufficient for `node_inputs`, or if the resolver steps raise
    InputResolutionError or SkipSignal.

  Raises:
    ValueError: If the resolver steps produce a list of input dicts whose
      length is not exactly 1.
    TypeError: If the resolver output is neither an artifact multimap nor a
      list of them.
  """
  seed_inputs = _resolve_initial_dict(metadata_handler, node_inputs)
  if not _is_sufficient(seed_inputs, node_inputs):
    required_counts = {}
    for name, spec in node_inputs.inputs.items():
      required_counts[name] = spec.min_count
    logging.warning('Resolved inputs should have %r artifacts, but got %r.',
                    required_counts, seed_inputs)
    return None

  try:
    outcome = processor.run_resolver_steps(
        seed_inputs,
        resolver_steps=node_inputs.resolver_config.resolver_steps,
        store=metadata_handler.store)
  except (exceptions.InputResolutionError, exceptions.SkipSignal):
    # InputResolutionError: raised when a ResolverStrategy returned None in
    # the middle; legacy input resolution returned None in that case.
    # SkipSignal: has no exact counterpart in the legacy return value, so None
    # is the best-effort answer.
    return None

  if typing_utils.is_list_of_artifact_multimap(outcome):
    candidates = cast(Sequence[typing_utils.ArtifactMultiMap], outcome)
    count = len(candidates)
    if count == 1:
      return candidates[0]
    raise ValueError(
        'Invalid number of resolved inputs; expected 1 but got '
        f'{count}: {candidates}')
  if typing_utils.is_artifact_multimap(outcome):
    return cast(typing_utils.ArtifactMultiMap, outcome)
  raise TypeError(f'Invalid input resolution result: {outcome}. Should be '
                  'Mapping[str, Sequence[Artifact]].')
Example #5
0
 def no(value: Any):
     """Assert that `typing_utils.is_artifact_multimap(value)` is falsy."""
     verdict = typing_utils.is_artifact_multimap(value)
     self.assertFalse(verdict)
Example #6
0
 def yes(value: Any):
     """Assert that `typing_utils.is_artifact_multimap(value)` is truthy."""
     verdict = typing_utils.is_artifact_multimap(value)
     self.assertTrue(verdict)