def testBuildImporterWithRuntimeParam(self):
  """An Importer whose source_uri is a RuntimeParameter compiles correctly.

  The parameter must surface in the component def, task spec, executor spec,
  and be captured by the surrounding ParameterContext.
  """
  runtime_param = data_types.RuntimeParameter(name='runtime_flag', ptype=str)
  importer_node = importer.Importer(
      source_uri=runtime_param,
      artifact_type=standard_artifacts.Examples).with_id('my_importer')
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  component_defs = {}
  with parameter_utils.ParameterContext() as pc:
    step_spec = self._sole(
        step_builder.StepBuilder(
            node=importer_node,
            deployment_config=deployment_config,
            component_defs=component_defs).build())
  component_def = self._sole(component_defs)

  # Compare each produced proto against its golden pbtxt.
  golden_cases = [
      ('expected_importer_component_with_runtime_param.pbtxt',
       pipeline_pb2.ComponentSpec(), component_def),
      ('expected_importer_task_with_runtime_param.pbtxt',
       pipeline_pb2.PipelineTaskSpec(), step_spec),
      ('expected_importer_executor_with_runtime_param.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig(), deployment_config),
  ]
  for golden_file, empty_proto, actual in golden_cases:
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(golden_file, empty_proto), actual)
  self.assertListEqual([runtime_param], pc.parameters)
def _testAttachParametersInSingleThread(self, suffix: Text):
  """Attaches two suffixed parameters and checks they are both recorded.

  Helper shared by the multi-thread attachment tests; `suffix` disambiguates
  parameter names per caller/thread.
  """
  expected_names = [
      'param1_in_{}'.format(suffix),
      'param2_in_{}'.format(suffix),
  ]
  with parameter_utils.ParameterContext() as pc:
    for name in expected_names:
      parameter_utils.attach_parameter(
          data_types.RuntimeParameter(name=name, ptype=int))
  self.assertLen(pc.parameters, 2)
  self.assertEqual(pc.parameters[0].name, expected_names[0])
  self.assertEqual(pc.parameters[1].name, expected_names[1])
def testAttachParameters(self):
  """Parameters attached inside a ParameterContext are collected in order."""
  specs = [
      ('test_param_1', int),
      ('test_param_2', Text),
      ('test_param_3', float),
  ]
  attached = []
  with parameter_utils.ParameterContext() as pc:
    for name, ptype in specs:
      param = data_types.RuntimeParameter(name=name, ptype=ptype)
      parameter_utils.attach_parameter(param)
      attached.append(param)
  self.assertListEqual(attached, pc.parameters)
def build(self) -> pipeline_pb2.PipelineSpec:
  """Compiles the TFX pipeline into a PipelineSpec proto.

  Returns:
    The populated PipelineSpec, with component defs, root DAG tasks, the
    deployment spec, and any collected runtime parameters attached.
  """
  _check_name(self._pipeline_info.pipeline_name)
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  pipeline_info = pipeline_pb2.PipelineInfo(
      name=self._pipeline_info.pipeline_name)
  task_specs = {}
  component_defs = {}
  # Maps (producer component id, output key) -> (new producer component id,
  # output key) when an upstream rewrite forces channel redirection.
  redirect_map = {}
  with parameter_utils.ParameterContext() as pc:
    # Topological order is required here: a redirect entry only exists once
    # the upstream node that caused it has already been built, so downstream
    # consumers must be processed afterwards.
    for node in self._pipeline.components:
      task_specs.update(
          step_builder.StepBuilder(
              node=node,
              deployment_config=deployment_config,
              component_defs=component_defs,
              image=self._default_image,
              image_cmds=self._default_commands,
              beam_pipeline_args=self._pipeline.beam_pipeline_args,
              enable_cache=self._pipeline.enable_cache,
              pipeline_info=self._pipeline_info,
              channel_redirect_map=redirect_map).build())

  result = pipeline_pb2.PipelineSpec(pipeline_info=pipeline_info)
  result.deployment_spec.update(json_format.MessageToDict(deployment_config))
  for name, component_def in component_defs.items():
    result.components[name].CopyFrom(component_def)
  for name, task_spec in task_specs.items():
    result.root.dag.tasks[name].CopyFrom(task_spec)
  # Every runtime parameter collected during the build becomes an input
  # parameter on the pipeline root.
  for param in pc.parameters:
    result.root.input_definitions.parameters[param.name].CopyFrom(
        compiler_utils.build_parameter_type_spec(param))
  return result
def build(self) -> pipeline_pb2.PipelineSpec:
  """Compiles the TFX pipeline into a PipelineSpec proto.

  Returns:
    The populated PipelineSpec, carrying the task list, runtime parameter
    specs, and the deployment spec.
  """
  _check_name(self._pipeline_info.pipeline_name)
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  pipeline_info = pipeline_pb2.PipelineInfo(
      name=self._pipeline_info.pipeline_name)
  # Maps (producer component id, output key) -> (new producer component id,
  # output key) when an upstream rewrite forces channel redirection.
  redirect_map = {}
  task_specs = []
  with parameter_utils.ParameterContext() as pc:
    # Topological order is required here: a redirect entry only exists once
    # the upstream node that caused it has already been built, so downstream
    # consumers must be processed afterwards.
    for node in self._pipeline.components:
      task_specs.extend(
          step_builder.StepBuilder(
              node=node,
              deployment_config=deployment_config,
              image=self._default_image,
              image_cmds=self._default_commands,
              beam_pipeline_args=self._pipeline.beam_pipeline_args,
              enable_cache=self._pipeline.enable_cache,
              pipeline_info=self._pipeline_info,
              channel_redirect_map=redirect_map).build())

  result = pipeline_pb2.PipelineSpec(
      pipeline_info=pipeline_info,
      tasks=task_specs,
      runtime_parameters=compiler_utils.build_runtime_parameter_spec(
          pc.parameters))
  result.deployment_spec.update(json_format.MessageToDict(deployment_config))
  return result
def build(self) -> pipeline_pb2.PipelineSpec:
  """Build a pipeline PipelineSpec.

  Returns:
    The populated PipelineSpec. When an exit handler is configured, all TFX
    tasks are nested under a sub-DAG component named
    compiler_utils.TFX_DAG_NAME and the exit handler runs as a separate task
    under root; otherwise the tasks go directly under root.
  """
  _check_name(self._pipeline_info.pipeline_name)
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  pipeline_info = pipeline_pb2.PipelineInfo(
      name=self._pipeline_info.pipeline_name)
  tfx_tasks = {}
  component_defs = {}
  # Map from (producer component id, output key) to (new producer component
  # id, output key)
  channel_redirect_map = {}
  with parameter_utils.ParameterContext() as pc:
    for component in self._pipeline.components:
      # The sub-DAG name is reserved when an exit handler is present; rename
      # any user component that collides with it.
      if self._exit_handler and component.id == compiler_utils.TFX_DAG_NAME:
        component.with_id(component.id + _generate_component_name_suffix())
        # BUG FIX: the implicitly-concatenated string literals lacked a
        # separating space, logging "...pipeline withexit handler...".
        logging.warning(
            '_tfx_dag is system reserved name for pipeline with '
            'exit handler, added suffix to your component name: %s',
            component.id)
      # Here the topological order of components is required.
      # If a channel redirection is needed, redirect mapping is expected to be
      # available because the upstream node (which is the cause for
      # redirecting) is processed before the downstream consumer nodes.
      built_tasks = step_builder.StepBuilder(
          node=component,
          deployment_config=deployment_config,
          component_defs=component_defs,
          image=self._default_image,
          image_cmds=self._default_commands,
          beam_pipeline_args=self._pipeline.beam_pipeline_args,
          enable_cache=self._pipeline.enable_cache,
          pipeline_info=self._pipeline_info,
          channel_redirect_map=channel_redirect_map).build()
      tfx_tasks.update(built_tasks)

  result = pipeline_pb2.PipelineSpec(pipeline_info=pipeline_info)
  # If exit handler is defined, put all the TFX tasks under tfx_dag;
  # the exit handler is a separate component triggered by tfx_dag.
  if self._exit_handler:
    for name, task_spec in tfx_tasks.items():
      result.components[compiler_utils.TFX_DAG_NAME].dag.tasks[
          name].CopyFrom(task_spec)
    # Construct root with the exit handler. Caching is disabled so the
    # handler always runs after the sub-DAG finishes.
    exit_handler_task = step_builder.StepBuilder(
        node=self._exit_handler,
        deployment_config=deployment_config,
        component_defs=component_defs,
        image=self._default_image,
        image_cmds=self._default_commands,
        beam_pipeline_args=self._pipeline.beam_pipeline_args,
        enable_cache=False,
        pipeline_info=self._pipeline_info,
        channel_redirect_map=channel_redirect_map,
        is_exit_handler=True).build()
    result.root.dag.tasks[
        compiler_utils.
        TFX_DAG_NAME].component_ref.name = compiler_utils.TFX_DAG_NAME
    result.root.dag.tasks[
        compiler_utils.
        TFX_DAG_NAME].task_info.name = compiler_utils.TFX_DAG_NAME
    result.root.dag.tasks[self._exit_handler.id].CopyFrom(
        exit_handler_task[self._exit_handler.id])
  else:
    for name, task_spec in tfx_tasks.items():
      result.root.dag.tasks[name].CopyFrom(task_spec)

  result.deployment_spec.update(json_format.MessageToDict(deployment_config))
  for name, component_def in component_defs.items():
    result.components[name].CopyFrom(component_def)
  # Attach runtime parameter to root's input parameter
  for param in pc.parameters:
    result.root.input_definitions.parameters[param.name].CopyFrom(
        compiler_utils.build_parameter_type_spec(param))
  return result