Example #1
0
    def test_get_parameter_type(self, given_type, expected_type):
        """Checks `get_parameter_type` for one supplied case and for `int`.

        Args:
          given_type: The type specification handed to
            `type_utils.get_parameter_type`.
          expected_type: The value that lookup is expected to return.
        """
        actual_type = type_utils.get_parameter_type(given_type)
        self.assertEqual(expected_type, actual_type)

        # Lookup by Python type: the builtin `int` should resolve to INT.
        int_type = type_utils.get_parameter_type(int)
        self.assertEqual(pb.PrimitiveType.INT, int_type)
Example #2
0
def build_component_spec_from_structure(
    component_spec: structures.ComponentSpec,
) -> pipeline_spec_pb2.ComponentSpec:
  """Converts a structures.ComponentSpec into an IR ComponentSpec.

  The executor label is the sanitized component name. Every declared input and
  output is registered either as a parameter (with its primitive type) or as
  an artifact (with its instance schema), depending on what
  `type_utils.is_parameter_type` reports for the declared type.

  Args:
    component_spec: The structure component spec.

  Returns:
    An instance of IR ComponentSpec.
  """
  ir_spec = pipeline_spec_pb2.ComponentSpec()
  ir_spec.executor_label = dsl_utils.sanitize_executor_label(component_spec.name)

  # Inputs and outputs are handled identically, only the target differs.
  io_groups = (
      (component_spec.inputs, ir_spec.input_definitions),
      (component_spec.outputs, ir_spec.output_definitions),
  )
  for declared_specs, definitions in io_groups:
    # A missing (None) list of inputs/outputs is treated as empty.
    for spec in declared_specs or []:
      if type_utils.is_parameter_type(spec.type):
        definitions.parameters[spec.name].type = (
            type_utils.get_parameter_type(spec.type))
        continue
      definitions.artifacts[spec.name].artifact_type.instance_schema = (
          type_utils.get_artifact_type_schema(spec.type))

  return ir_spec
Example #3
0
def build_component_inputs_spec(
    component_spec: pipeline_spec_pb2.ComponentSpec,
    pipeline_params: List[_pipeline_param.PipelineParam],
    is_root_component: bool,
) -> None:
    """Fills in the inputs spec of a component from pipeline params.

    The component spec is modified in place.

    Args:
      component_spec: The component spec to fill in its inputs spec.
      pipeline_params: The list of pipeline params.
      is_root_component: Whether the component is the root. The root uses each
        param's `full_name` as the input name; any other component uses the
        name from `additional_input_name_for_pipelineparam`.
    """
    for pipeline_param in pipeline_params:
        if is_root_component:
            input_name = pipeline_param.full_name
        else:
            input_name = additional_input_name_for_pipelineparam(
                pipeline_param)

        declared_type = pipeline_param.param_type
        definitions = component_spec.input_definitions

        if type_utils.is_parameter_type(declared_type):
            definitions.parameters[input_name].type = (
                type_utils.get_parameter_type(declared_type))
            continue

        # Do not register an artifact under a name that is already present
        # among the parameter inputs.
        if input_name in getattr(definitions, 'parameters', []):
            continue

        definitions.artifacts[input_name].artifact_type.CopyFrom(
            type_utils.get_artifact_type_schema_message(declared_type))
Example #4
0
def build_component_spec_from_structure(
    component_spec: structures.ComponentSpec,
    executor_label: str,
    actual_inputs: List[str],
) -> pipeline_spec_pb2.ComponentSpec:
    """Converts a structures.ComponentSpec into an IR ComponentSpec.

    Args:
      component_spec: The structure component spec.
      executor_label: The executor label.
      actual_inputs: The actual arguments passed to the task. This is used as
        a short term workaround to support optional inputs in component spec
        IR: declared inputs whose name is not listed here are left out.

    Returns:
      An instance of IR ComponentSpec.
    """
    ir_spec = pipeline_spec_pb2.ComponentSpec()
    ir_spec.executor_label = executor_label

    for spec in component_spec.inputs or []:
        # Only inputs that were actually supplied make it into the IR.
        if spec.name in actual_inputs:
            if type_utils.is_parameter_type(spec.type):
                parameter = ir_spec.input_definitions.parameters[spec.name]
                parameter.type = type_utils.get_parameter_type(spec.type)
            else:
                artifact_input = ir_spec.input_definitions.artifacts[spec.name]
                artifact_input.artifact_type.CopyFrom(
                    type_utils.get_artifact_type_schema_message(spec.type))

    # Outputs are always included, whatever `actual_inputs` contains.
    for spec in component_spec.outputs or []:
        if type_utils.is_parameter_type(spec.type):
            parameter = ir_spec.output_definitions.parameters[spec.name]
            parameter.type = type_utils.get_parameter_type(spec.type)
        else:
            artifact_output = ir_spec.output_definitions.artifacts[spec.name]
            artifact_output.artifact_type.CopyFrom(
                type_utils.get_artifact_type_schema_message(spec.type))

    return ir_spec
def build_component_outputs_spec(
    component_spec: pipeline_spec_pb2.ComponentSpec,
    pipeline_params: List[_pipeline_param.PipelineParam],
) -> None:
  """Fills in the outputs spec of a component from pipeline params.

  The component spec is modified in place. Each param is registered under its
  `full_name`.

  Args:
    component_spec: The component spec to fill in its outputs spec.
    pipeline_params: The list of pipeline params. None is treated as empty.
  """
  for pipeline_param in pipeline_params or []:
    output_name = pipeline_param.full_name
    declared_type = pipeline_param.param_type
    definitions = component_spec.output_definitions

    if type_utils.is_parameter_type(declared_type):
      definitions.parameters[output_name].type = (
          type_utils.get_parameter_type(declared_type))
      continue

    # Do not register an artifact under a name that is already present among
    # the parameter outputs.
    if output_name in getattr(definitions, 'parameters', []):
      continue

    definitions.artifacts[output_name].artifact_type.CopyFrom(
        type_utils.get_artifact_type_schema(declared_type))
Example #6
0
def build_component_inputs_spec(
    component_spec: pipeline_spec_pb2.ComponentSpec,
    pipeline_params: List[_pipeline_param.PipelineParam],
) -> None:
    """Fills in the inputs spec of a component from pipeline params.

    The component spec is modified in place. Each param is registered under
    its `full_name`, as a parameter when `type_utils.is_parameter_type`
    accepts its type and as an artifact otherwise.

    Args:
      component_spec: The component spec to fill in its inputs spec.
      pipeline_params: The list of pipeline params.
    """
    for pipeline_param in pipeline_params:
        input_name = pipeline_param.full_name
        declared_type = pipeline_param.param_type
        definitions = component_spec.input_definitions

        if not type_utils.is_parameter_type(declared_type):
            artifact_type = definitions.artifacts[input_name].artifact_type
            artifact_type.instance_schema = (
                type_utils.get_artifact_type_schema(declared_type))
            continue

        definitions.parameters[input_name].type = (
            type_utils.get_parameter_type(declared_type))
Example #7
0
    def _get_value(
            param: _pipeline_param.PipelineParam) -> pipeline_spec_pb2.Value:
        """Wraps the value of a pipeline param in a `pipeline_spec_pb2.Value`.

        Args:
          param: The pipeline param; its `value` must not be None.

        Returns:
          A Value with `int_value` set for INT params, `double_value` set for
          DOUBLE params, and `string_value` set in every other case.
        """
        assert param.value is not None, 'None values should be filterd out.'

        # TODO(chensun): remove defaulting to 'String' for None param_type once we
        # fix importer behavior.
        primitive_type = type_utils.get_parameter_type(param.param_type
                                                       or 'String')
        primitives = pipeline_spec_pb2.PrimitiveType

        value = pipeline_spec_pb2.Value()
        if primitive_type == primitives.INT:
            value.int_value = int(param.value)
        elif primitive_type == primitives.DOUBLE:
            value.double_value = float(param.value)
        else:
            # STRING, and every other type as well, is stored as a string.
            # TODO(chensun): remove this default behavior once we migrate from
            # `pipeline_spec_pb2.Value` to `protobuf.Value`.
            value.string_value = str(param.value)

        return value
Example #8
0
def _get_custom_job_op(
    task_name: str,
    job_spec: Dict[str, Any],
    input_artifacts: Optional[Dict[str, dsl.PipelineParam]] = None,
    input_parameters: Optional[Dict[str, _ValueOrPipelineParam]] = None,
    output_artifacts: Optional[Dict[str, Type[artifact.Artifact]]] = None,
    output_parameters: Optional[Dict[str, Any]] = None,
) -> AiPlatformCustomJobOp:
  """Gets an AiPlatformCustomJobOp from job spec and I/O definition.

  Args:
    task_name: The name of the task. It is used as the op name, and its
      sanitized forms are used as the task name, executor label and component
      reference name.
    job_spec: The custom job spec, passed through to the op unchanged.
    input_artifacts: Optional mapping from input name to the PipelineParam
      representing the artifact. None means no input artifacts.
    input_parameters: Optional mapping from input name to either a
      PipelineParam or a constant value. None means no input parameters.
    output_artifacts: Optional mapping from output name to artifact class.
      None means no output artifacts.
    output_parameters: Optional mapping from output name to parameter type.
      None means no output parameters.

  Returns:
    The AiPlatformCustomJobOp wrapping the job spec together with the
    generated component spec and task spec.

  Raises:
    RuntimeError: If an input artifact is not a PipelineParam.
    KeyError: If an output (or input) parameter shares its name with an output
      (or input) artifact.
  """
  # Every I/O mapping defaults to None; normalize to an empty dict so that the
  # loops below do not fail with AttributeError when a mapping is omitted.
  input_artifacts = {} if input_artifacts is None else input_artifacts
  input_parameters = {} if input_parameters is None else input_parameters
  output_artifacts = {} if output_artifacts is None else output_artifacts
  output_parameters = {} if output_parameters is None else output_parameters

  # Short aliases for the nested IR message classes used below.
  parameter_spec_cls = pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec
  artifact_spec_cls = pipeline_spec_pb2.TaskInputsSpec.InputArtifactSpec

  pipeline_task_spec = pipeline_spec_pb2.PipelineTaskSpec()
  pipeline_component_spec = pipeline_spec_pb2.ComponentSpec()

  pipeline_task_spec.task_info.CopyFrom(
      pipeline_spec_pb2.PipelineTaskInfo(
          name=dsl_utils.sanitize_task_name(task_name)))

  # Iterate through the inputs/outputs declaration to get pipeline component
  # spec.
  for input_name, param in input_parameters.items():
    # A PipelineParam carries its declared type; a constant is typed by its
    # Python class.
    if isinstance(param, dsl.PipelineParam):
      declared_type = param.param_type
    else:
      declared_type = type(param)
    pipeline_component_spec.input_definitions.parameters[
        input_name].type = type_utils.get_parameter_type(declared_type)

  for input_name, art in input_artifacts.items():
    if not isinstance(art, dsl.PipelineParam):
      raise RuntimeError(
          'Get unresolved input artifact for input %s. Input '
          'artifacts must be connected to a producer task.' % input_name)
    pipeline_component_spec.input_definitions.artifacts[
        input_name].artifact_type.CopyFrom(
            type_utils.get_artifact_type_schema_message(art.param_type))

  for output_name, param_type in output_parameters.items():
    pipeline_component_spec.output_definitions.parameters[
        output_name].type = type_utils.get_parameter_type(param_type)

  for output_name, artifact_type in output_artifacts.items():
    pipeline_component_spec.output_definitions.artifacts[
        output_name].artifact_type.CopyFrom(artifact_type.get_ir_type())

  pipeline_component_spec.executor_label = dsl_utils.sanitize_executor_label(
      task_name)

  # Iterate through the inputs/outputs specs to get pipeline task spec.
  for input_name, param in input_parameters.items():
    if not isinstance(param, dsl.PipelineParam):
      # If this is not a pipeline param, then it should be a value.
      input_spec = parameter_spec_cls(
          runtime_value=pipeline_spec_pb2.ValueOrRuntimeParameter(
              constant_value=dsl_utils.get_value(param)))
    elif param.op_name:
      # If the param has a valid op_name, this should be a pipeline parameter
      # produced by an upstream task.
      input_spec = parameter_spec_cls(
          task_output_parameter=parameter_spec_cls.TaskOutputParameterSpec(
              producer_task=dsl_utils.sanitize_task_name(param.op_name),
              output_parameter_key=param.name))
    else:
      # If a valid op_name is missing, this should be a pipeline parameter.
      input_spec = parameter_spec_cls(component_input_parameter=param.name)
    pipeline_task_spec.inputs.parameters[input_name].CopyFrom(input_spec)

  for input_name, art in input_artifacts.items():
    if art.op_name:
      # If the param has a valid op_name, this should be an artifact produced
      # by an upstream task.
      input_spec = artifact_spec_cls(
          task_output_artifact=artifact_spec_cls.TaskOutputArtifactSpec(
              producer_task=dsl_utils.sanitize_task_name(art.op_name),
              output_artifact_key=art.name))
    else:
      # Otherwise, this should be from the input of the subdag.
      input_spec = artifact_spec_cls(component_input_artifact=art.name)
    pipeline_task_spec.inputs.artifacts[input_name].CopyFrom(input_spec)

  # TODO: Add task dependencies/trigger policies/caching/iterator
  pipeline_task_spec.component_ref.name = dsl_utils.sanitize_component_name(
      task_name)

  # Construct dummy I/O declaration for the op.
  # TODO: resolve name conflict instead of raising errors.
  dummy_outputs = collections.OrderedDict()
  for output_name in output_artifacts:
    dummy_outputs[output_name] = _DUMMY_PATH

  for output_name in output_parameters:
    if output_name in dummy_outputs:
      raise KeyError('Got name collision for output key %s. Consider renaming '
                     'either output parameters or output '
                     'artifacts.' % output_name)
    dummy_outputs[output_name] = _DUMMY_PATH

  dummy_inputs = collections.OrderedDict()
  for input_name in input_artifacts:
    dummy_inputs[input_name] = _DUMMY_PATH
  for input_name in input_parameters:
    if input_name in dummy_inputs:
      raise KeyError('Got name collision for input key %s. Consider renaming '
                     'either input parameters or input '
                     'artifacts.' % input_name)
    dummy_inputs[input_name] = _DUMMY_PATH

  # Construct the AIP (Unified) custom job op.
  return AiPlatformCustomJobOp(
      name=task_name,
      custom_job_spec=job_spec,
      component_spec=pipeline_component_spec,
      task_spec=pipeline_task_spec,
      task_inputs=[
          # Argument and path are both the dummy value stored for the input.
          dsl.InputArgumentPath(
              argument=path,
              input=input_name,
              path=path,
          ) for input_name, path in dummy_inputs.items()
      ],
      task_outputs=dummy_outputs
  )
Example #9
0
    def test_get_parameter_type(self):
        """Checks `get_parameter_type` lookups by type name and Python type."""
        # Test get parameter type by name.
        expected_by_name = (
            ('Int', pb.PrimitiveType.INT),
            ('Integer', pb.PrimitiveType.INT),
            ('Double', pb.PrimitiveType.DOUBLE),
            ('Float', pb.PrimitiveType.DOUBLE),
            ('String', pb.PrimitiveType.STRING),
            ('Str', pb.PrimitiveType.STRING),
        )
        for type_name, expected in expected_by_name:
            self.assertEqual(expected,
                             type_utils.get_parameter_type(type_name))

        # Test get parameter by Python type.
        expected_by_python_type = (
            (int, pb.PrimitiveType.INT),
            (float, pb.PrimitiveType.DOUBLE),
            (str, pb.PrimitiveType.STRING),
        )
        for python_type, expected in expected_by_python_type:
            self.assertEqual(expected,
                             type_utils.get_parameter_type(python_type))

        # NOTE(review): this calls `get_parameter_type_schema`, unlike every
        # other call in this test, which uses `get_parameter_type` -- confirm
        # that this is the intended function.
        with self.assertRaises(AttributeError):
            type_utils.get_parameter_type_schema(None)

        with self.assertRaisesRegex(TypeError, 'Got illegal parameter type.'):
            type_utils.get_parameter_type(bool)