Esempio n. 1
0
 def test_get_ir_value(self):
   """Checks dsl_utils.get_value against hand-built pipeline_spec_pb2.Value's.

   Supported Python constants (int, float, str) must convert to the matching
   Value field; any other type must raise a TypeError.
   """
   # Each entry pairs the Value fields we expect with the raw Python constant
   # fed to get_value.
   cases = (
       ({'int_value': 42}, 42),
       ({'double_value': 12.2}, 12.2),
       ({'string_value': 'hello world'}, 'hello world'),
   )
   for expected_fields, raw_value in cases:
     expected_message = pipeline_spec_pb2.Value(**expected_fields)
     actual_message = dsl_utils.get_value(raw_value)
     self.assertDictEqual(
         json_format.MessageToDict(expected_message),
         json_format.MessageToDict(actual_message))

   # An unsupported type is rejected.
   with self.assertRaisesRegex(TypeError, 'Got unexpected type'):
     dsl_utils.get_value(_DummyClass())
Esempio n. 2
0
def _get_custom_job_op(
    task_name: str,
    job_spec: Dict[str, Any],
    input_artifacts: Optional[Dict[str, dsl.PipelineParam]] = None,
    input_parameters: Optional[Dict[str, _ValueOrPipelineParam]] = None,
    output_artifacts: Optional[Dict[str, Type[artifact.Artifact]]] = None,
    output_parameters: Optional[Dict[str, Any]] = None,
) -> AiPlatformCustomJobOp:
  """Gets an AiPlatformCustomJobOp from job spec and I/O definition.

  Args:
    task_name: Name of the task. A sanitized form of it is used as the IR task
      name, executor label and component reference name.
    job_spec: The custom job spec. It is passed through unchanged to the op.
    input_artifacts: Optional mapping from input name to the
      `dsl.PipelineParam` that supplies the artifact. Every value must be a
      `dsl.PipelineParam`. An input with an `op_name` is wired to that
      upstream task's output. Otherwise it is wired to a component (subdag)
      input.
    input_parameters: Optional mapping from input name to either a
      `dsl.PipelineParam` or a constant value. Constants are embedded as
      runtime constant values via `dsl_utils.get_value`.
    output_artifacts: Optional mapping from output name to the artifact class
      describing its type (its `get_ir_type()` is recorded in the spec).
    output_parameters: Optional mapping from output name to the parameter's
      type, presumably anything accepted by `type_utils.get_parameter_type`.

  Returns:
    An AiPlatformCustomJobOp whose component spec and task spec are derived
    from the I/O declarations above.

  Raises:
    RuntimeError: If a value in `input_artifacts` is not a
      `dsl.PipelineParam`.
    KeyError: If an output parameter shares a name with an output artifact,
      or an input parameter shares a name with an input artifact.
  """
  # The I/O mappings are optional. Treat a missing mapping as empty instead of
  # failing on `None.items()`.
  input_artifacts = input_artifacts or {}
  input_parameters = input_parameters or {}
  output_artifacts = output_artifacts or {}
  output_parameters = output_parameters or {}

  # Local aliases for deeply nested proto message classes.
  input_parameter_spec_cls = pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec
  input_artifact_spec_cls = pipeline_spec_pb2.TaskInputsSpec.InputArtifactSpec

  pipeline_task_spec = pipeline_spec_pb2.PipelineTaskSpec()
  pipeline_component_spec = pipeline_spec_pb2.ComponentSpec()

  pipeline_task_spec.task_info.CopyFrom(
      pipeline_spec_pb2.PipelineTaskInfo(
          name=dsl_utils.sanitize_task_name(task_name)))

  # Iterate through the inputs/outputs declaration to get pipeline component
  # spec.
  for input_name, param in input_parameters.items():
    # A PipelineParam carries its declared type. A constant's type is taken
    # from the Python value itself.
    if isinstance(param, dsl.PipelineParam):
      param_type = param.param_type
    else:
      param_type = type(param)
    pipeline_component_spec.input_definitions.parameters[
        input_name].type = type_utils.get_parameter_type(param_type)

  for input_name, art in input_artifacts.items():
    if not isinstance(art, dsl.PipelineParam):
      raise RuntimeError(
          'Get unresolved input artifact for input %s. Input '
          'artifacts must be connected to a producer task.' % input_name)
    pipeline_component_spec.input_definitions.artifacts[
        input_name].artifact_type.CopyFrom(
            type_utils.get_artifact_type_schema_message(art.param_type))

  for output_name, param_type in output_parameters.items():
    pipeline_component_spec.output_definitions.parameters[
        output_name].type = type_utils.get_parameter_type(param_type)

  for output_name, artifact_type in output_artifacts.items():
    pipeline_component_spec.output_definitions.artifacts[
        output_name].artifact_type.CopyFrom(artifact_type.get_ir_type())

  pipeline_component_spec.executor_label = dsl_utils.sanitize_executor_label(
      task_name)

  # Iterate through the inputs/outputs specs to get pipeline task spec.
  for input_name, param in input_parameters.items():
    if isinstance(param, dsl.PipelineParam) and param.op_name:
      # If the param has a valid op_name, this should be a pipeline parameter
      # produced by an upstream task.
      pipeline_task_spec.inputs.parameters[input_name].CopyFrom(
          input_parameter_spec_cls(
              task_output_parameter=(
                  input_parameter_spec_cls.TaskOutputParameterSpec(
                      producer_task=dsl_utils.sanitize_task_name(
                          param.op_name),
                      output_parameter_key=param.name))))
    elif isinstance(param, dsl.PipelineParam):
      # If a valid op_name is missing, this should be a pipeline parameter.
      pipeline_task_spec.inputs.parameters[input_name].CopyFrom(
          input_parameter_spec_cls(component_input_parameter=param.name))
    else:
      # If this is not a pipeline param, then it should be a value.
      pipeline_task_spec.inputs.parameters[input_name].CopyFrom(
          input_parameter_spec_cls(
              runtime_value=pipeline_spec_pb2.ValueOrRuntimeParameter(
                  constant_value=dsl_utils.get_value(param))))

  # Every artifact here was already checked to be a PipelineParam above.
  for input_name, art in input_artifacts.items():
    if art.op_name:
      # If the param has a valid op_name, this should be an artifact produced
      # by an upstream task.
      pipeline_task_spec.inputs.artifacts[input_name].CopyFrom(
          input_artifact_spec_cls(
              task_output_artifact=(
                  input_artifact_spec_cls.TaskOutputArtifactSpec(
                      producer_task=dsl_utils.sanitize_task_name(art.op_name),
                      output_artifact_key=art.name))))
    else:
      # Otherwise, this should be from the input of the subdag.
      pipeline_task_spec.inputs.artifacts[input_name].CopyFrom(
          input_artifact_spec_cls(component_input_artifact=art.name))

  # TODO: Add task dependencies/trigger policies/caching/iterator
  pipeline_task_spec.component_ref.name = dsl_utils.sanitize_component_name(
      task_name)

  # Construct dummy I/O declaration for the op: every declared input/output is
  # mapped to the `_DUMMY_PATH` placeholder. Artifacts and parameters share a
  # single namespace here, so a name used by both is rejected.
  # TODO: resolve name conflict instead of raising errors.
  dummy_outputs = collections.OrderedDict()
  for output_name in output_artifacts:
    dummy_outputs[output_name] = _DUMMY_PATH

  for output_name in output_parameters:
    if output_name in dummy_outputs:
      raise KeyError('Got name collision for output key %s. Consider renaming '
                     'either output parameters or output '
                     'artifacts.' % output_name)
    dummy_outputs[output_name] = _DUMMY_PATH

  dummy_inputs = collections.OrderedDict()
  for input_name in input_artifacts:
    dummy_inputs[input_name] = _DUMMY_PATH
  for input_name in input_parameters:
    if input_name in dummy_inputs:
      raise KeyError('Got name collision for input key %s. Consider renaming '
                     'either input parameters or input '
                     'artifacts.' % input_name)
    dummy_inputs[input_name] = _DUMMY_PATH

  # Construct the AIP (Unified) custom job op.
  return AiPlatformCustomJobOp(
      name=task_name,
      custom_job_spec=job_spec,
      component_spec=pipeline_component_spec,
      task_spec=pipeline_task_spec,
      task_inputs=[
          dsl.InputArgumentPath(
              argument=path,
              input=input_name,
              path=path,
          ) for input_name, path in dummy_inputs.items()
      ],
      task_outputs=dummy_outputs
  )