def test_get_ir_value(self):
  self.assertDictEqual(
      json_format.MessageToDict(pipeline_spec_pb2.Value(int_value=42)),
      json_format.MessageToDict(dsl_utils.get_value(42)))
  self.assertDictEqual(
      json_format.MessageToDict(pipeline_spec_pb2.Value(double_value=12.2)),
      json_format.MessageToDict(dsl_utils.get_value(12.2)))
  self.assertDictEqual(
      json_format.MessageToDict(
          pipeline_spec_pb2.Value(string_value='hello world')),
      json_format.MessageToDict(dsl_utils.get_value('hello world')))
  with self.assertRaisesRegex(TypeError, 'Got unexpected type'):
    dsl_utils.get_value(_DummyClass())
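# For reference, a minimal sketch of the behavior the assertions above expect
# from dsl_utils.get_value. This is inferred from the test only; the real
# implementation lives in dsl_utils and may differ in detail.
def _get_value_sketch(value):
  """Maps a Python primitive onto the matching pipeline_spec_pb2.Value field."""
  if isinstance(value, bool):
    # bool is a subclass of int, so without this guard True/False would be
    # silently encoded as int_value; treat it as unsupported in this sketch.
    raise TypeError('Got unexpected type: %s.' % type(value))
  if isinstance(value, int):
    return pipeline_spec_pb2.Value(int_value=value)
  if isinstance(value, float):
    return pipeline_spec_pb2.Value(double_value=value)
  if isinstance(value, str):
    return pipeline_spec_pb2.Value(string_value=value)
  raise TypeError('Got unexpected type: %s.' % type(value))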
def _get_custom_job_op(
    task_name: str,
    job_spec: Dict[str, Any],
    input_artifacts: Optional[Dict[str, dsl.PipelineParam]] = None,
    input_parameters: Optional[Dict[str, _ValueOrPipelineParam]] = None,
    output_artifacts: Optional[Dict[str, Type[artifact.Artifact]]] = None,
    output_parameters: Optional[Dict[str, Any]] = None,
) -> AiPlatformCustomJobOp:
  """Gets an AiPlatformCustomJobOp from job spec and I/O definition."""
  # Normalize the optional I/O dicts so the iterations below do not fail on
  # the None defaults.
  input_artifacts = input_artifacts or {}
  input_parameters = input_parameters or {}
  output_artifacts = output_artifacts or {}
  output_parameters = output_parameters or {}

  pipeline_task_spec = pipeline_spec_pb2.PipelineTaskSpec()
  pipeline_component_spec = pipeline_spec_pb2.ComponentSpec()

  pipeline_task_spec.task_info.CopyFrom(
      pipeline_spec_pb2.PipelineTaskInfo(
          name=dsl_utils.sanitize_task_name(task_name)))

  # Iterate through the inputs/outputs declaration to get pipeline component
  # spec.
  for input_name, param in input_parameters.items():
    if isinstance(param, dsl.PipelineParam):
      pipeline_component_spec.input_definitions.parameters[
          input_name].type = type_utils.get_parameter_type(param.param_type)
    else:
      pipeline_component_spec.input_definitions.parameters[
          input_name].type = type_utils.get_parameter_type(type(param))

  for input_name, art in input_artifacts.items():
    if not isinstance(art, dsl.PipelineParam):
      raise RuntimeError(
          'Got unresolved input artifact for input %s. Input '
          'artifacts must be connected to a producer task.' % input_name)
    pipeline_component_spec.input_definitions.artifacts[
        input_name].artifact_type.CopyFrom(
            type_utils.get_artifact_type_schema_message(art.param_type))

  for output_name, param_type in output_parameters.items():
    pipeline_component_spec.output_definitions.parameters[
        output_name].type = type_utils.get_parameter_type(param_type)

  for output_name, artifact_type in output_artifacts.items():
    pipeline_component_spec.output_definitions.artifacts[
        output_name].artifact_type.CopyFrom(artifact_type.get_ir_type())

  pipeline_component_spec.executor_label = dsl_utils.sanitize_executor_label(
      task_name)

  # Iterate through the inputs/outputs specs to get pipeline task spec.
  for input_name, param in input_parameters.items():
    if isinstance(param, dsl.PipelineParam) and param.op_name:
      # If the param has a valid op_name, this should be a pipeline parameter
      # produced by an upstream task.
      pipeline_task_spec.inputs.parameters[input_name].CopyFrom(
          pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec(
              task_output_parameter=pipeline_spec_pb2.TaskInputsSpec
              .InputParameterSpec.TaskOutputParameterSpec(
                  producer_task=dsl_utils.sanitize_task_name(param.op_name),
                  output_parameter_key=param.name)))
    elif isinstance(param, dsl.PipelineParam) and not param.op_name:
      # If a valid op_name is missing, this should be a pipeline parameter.
      pipeline_task_spec.inputs.parameters[input_name].CopyFrom(
          pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec(
              component_input_parameter=param.name))
    else:
      # If this is not a pipeline param, then it should be a constant value.
      pipeline_task_spec.inputs.parameters[input_name].CopyFrom(
          pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec(
              runtime_value=pipeline_spec_pb2.ValueOrRuntimeParameter(
                  constant_value=dsl_utils.get_value(param))))

  for input_name, art in input_artifacts.items():
    if art.op_name:
      # If the artifact has a valid op_name, this should be an artifact
      # produced by an upstream task.
      pipeline_task_spec.inputs.artifacts[input_name].CopyFrom(
          pipeline_spec_pb2.TaskInputsSpec.InputArtifactSpec(
              task_output_artifact=pipeline_spec_pb2.TaskInputsSpec
              .InputArtifactSpec.TaskOutputArtifactSpec(
                  producer_task=dsl_utils.sanitize_task_name(art.op_name),
                  output_artifact_key=art.name)))
    else:
      # Otherwise, this should come from the input of the subdag.
      pipeline_task_spec.inputs.artifacts[input_name].CopyFrom(
          pipeline_spec_pb2.TaskInputsSpec.InputArtifactSpec(
              component_input_artifact=art.name))

  # TODO: Add task dependencies/trigger policies/caching/iterator.
  pipeline_task_spec.component_ref.name = dsl_utils.sanitize_component_name(
      task_name)

  # Construct dummy I/O declaration for the op.
  # TODO: resolve name conflict instead of raising errors.
  dummy_outputs = collections.OrderedDict()
  for output_name in output_artifacts:
    dummy_outputs[output_name] = _DUMMY_PATH

  for output_name in output_parameters:
    if output_name in dummy_outputs:
      raise KeyError(
          'Got name collision for output key %s. Consider renaming either '
          'output parameters or output artifacts.' % output_name)
    dummy_outputs[output_name] = _DUMMY_PATH

  dummy_inputs = collections.OrderedDict()
  for input_name in input_artifacts:
    dummy_inputs[input_name] = _DUMMY_PATH
  for input_name in input_parameters:
    if input_name in dummy_inputs:
      raise KeyError(
          'Got name collision for input key %s. Consider renaming either '
          'input parameters or input artifacts.' % input_name)
    dummy_inputs[input_name] = _DUMMY_PATH

  # Construct the AIP (Unified) custom job op.
  return AiPlatformCustomJobOp(
      name=task_name,
      custom_job_spec=job_spec,
      component_spec=pipeline_component_spec,
      task_spec=pipeline_task_spec,
      task_inputs=[
          dsl.InputArgumentPath(
              argument=dummy_inputs[input_name],
              input=input_name,
              path=path,
          ) for input_name, path in dummy_inputs.items()
      ],
      task_outputs=dummy_outputs,
  )
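# Usage sketch for _get_custom_job_op (illustrative only). The job spec is a
# hypothetical placeholder shaped like a CustomJob payload, and
# 'learning_rate'/'accuracy' are made-up names; only the keyword arguments
# come from the signature above. Artifact I/O is omitted here, which relies
# on the None defaults being normalized to empty dicts inside the function.
def _example_custom_job_usage():
  return _get_custom_job_op(
      task_name='train',
      job_spec={'display_name': 'train'},  # hypothetical CustomJob payload
      input_parameters={'learning_rate': 0.01},  # constant, not a PipelineParam
      output_parameters={'accuracy': float},
  )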