コード例 #1
0
ファイル: step_builder_test.py プロジェクト: jay90099/tfx
    def testBuildImporterWithRuntimeParam(self):
        """Builds an Importer whose source_uri is a RuntimeParameter.

        The generated component def, task spec and deployment config are
        compared against golden pbtxt test data, and the runtime parameter must
        be the only one recorded by the surrounding ParameterContext.
        """
        runtime_flag = data_types.RuntimeParameter(name='runtime_flag',
                                                   ptype=str)
        importer_node = importer.Importer(
            source_uri=runtime_flag,
            artifact_type=standard_artifacts.Examples,
        ).with_id('my_importer')

        deployment = pipeline_pb2.PipelineDeploymentConfig()
        defs_by_name = {}
        with parameter_utils.ParameterContext() as param_ctx:
            builder = step_builder.StepBuilder(
                node=importer_node,
                deployment_config=deployment,
                component_defs=defs_by_name)
            built_step = self._sole(builder.build())
        built_component = self._sole(defs_by_name)

        # (golden test-data file, empty proto to parse it into, actual proto)
        expectations = (
            ('expected_importer_component_with_runtime_param.pbtxt',
             pipeline_pb2.ComponentSpec(), built_component),
            ('expected_importer_task_with_runtime_param.pbtxt',
             pipeline_pb2.PipelineTaskSpec(), built_step),
            ('expected_importer_executor_with_runtime_param.pbtxt',
             pipeline_pb2.PipelineDeploymentConfig(), deployment),
        )
        for golden_file, proto_template, actual_proto in expectations:
            expected_proto = test_utils.get_proto_from_test_data(
                golden_file, proto_template)
            self.assertProtoEquals(expected_proto, actual_proto)

        self.assertListEqual([runtime_flag], param_ctx.parameters)
コード例 #2
0
 def _testAttachParametersInSingleThread(self, suffix: Text):
     """Attaches two int RuntimeParameters in a fresh context and checks them.

     Args:
       suffix: Appended to each parameter name so the names produced by
         different callers can be told apart.
     """
     first_name = 'param1_in_{}'.format(suffix)
     second_name = 'param2_in_{}'.format(suffix)
     with parameter_utils.ParameterContext() as ctx:
         for param_name in (first_name, second_name):
             parameter_utils.attach_parameter(
                 data_types.RuntimeParameter(name=param_name, ptype=int))
     # Exactly the two parameters, in attach order.
     self.assertLen(ctx.parameters, 2)
     self.assertEqual(ctx.parameters[0].name, first_name)
     self.assertEqual(ctx.parameters[1].name, second_name)
コード例 #3
0
    def testAttachParameters(self):
        """Parameters attached inside a context are recorded in attach order."""
        with parameter_utils.ParameterContext() as ctx:
            expected = []
            for param_name, param_type in (('test_param_1', int),
                                           ('test_param_2', Text),
                                           ('test_param_3', float)):
                runtime_param = data_types.RuntimeParameter(name=param_name,
                                                            ptype=param_type)
                parameter_utils.attach_parameter(runtime_param)
                expected.append(runtime_param)

        self.assertListEqual(expected, ctx.parameters)
コード例 #4
0
ファイル: pipeline_builder.py プロジェクト: konny0311/tfx
    def build(self) -> pipeline_pb2.PipelineSpec:
        """Compiles the wrapped pipeline into a PipelineSpec proto.

        Each component is handed to a StepBuilder, which produces task specs and
        fills in component definitions and the deployment config. Runtime
        parameters attached while the components are built become input
        parameters of the root component.

        Returns:
          The assembled pipeline_pb2.PipelineSpec.
        """
        _check_name(self._pipeline_info.pipeline_name)

        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        info_proto = pipeline_pb2.PipelineInfo(
            name=self._pipeline_info.pipeline_name)

        task_specs = {}
        component_specs = {}
        # (producer component id, output key) ->
        # (redirected producer component id, output key)
        redirects = {}
        with parameter_utils.ParameterContext() as param_ctx:
            # Components must arrive in topological order: an upstream node that
            # causes a channel redirection registers it in `redirects` before any
            # downstream consumer of that channel is built.
            for node in self._pipeline.components:
                builder = step_builder.StepBuilder(
                    node=node,
                    deployment_config=deployment_config,
                    component_defs=component_specs,
                    image=self._default_image,
                    image_cmds=self._default_commands,
                    beam_pipeline_args=self._pipeline.beam_pipeline_args,
                    enable_cache=self._pipeline.enable_cache,
                    pipeline_info=self._pipeline_info,
                    channel_redirect_map=redirects)
                task_specs.update(builder.build())

        spec = pipeline_pb2.PipelineSpec(pipeline_info=info_proto)
        spec.deployment_spec.update(
            json_format.MessageToDict(deployment_config))
        for comp_name, comp_def in component_specs.items():
            spec.components[comp_name].CopyFrom(comp_def)
        for task_name, task in task_specs.items():
            spec.root.dag.tasks[task_name].CopyFrom(task)

        # Expose every collected runtime parameter as a root input parameter.
        for runtime_param in param_ctx.parameters:
            spec.root.input_definitions.parameters[runtime_param.name].CopyFrom(
                compiler_utils.build_parameter_type_spec(runtime_param))

        return spec
コード例 #5
0
ファイル: pipeline_builder.py プロジェクト: vikrosj/tfx
    def build(self) -> pipeline_pb2.PipelineSpec:
        """Compiles the wrapped pipeline into a PipelineSpec proto.

        StepBuilder output for every component is collected into a flat task
        list. Runtime parameters attached while building are turned into the
        spec's runtime_parameters via compiler_utils.

        Returns:
          The assembled pipeline_pb2.PipelineSpec.
        """
        _check_name(self._pipeline_info.pipeline_name)

        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        info_proto = pipeline_pb2.PipelineInfo(
            name=self._pipeline_info.pipeline_name)

        all_tasks = []
        # (producer component id, output key) ->
        # (redirected producer component id, output key)
        redirects = {}
        with parameter_utils.ParameterContext() as param_ctx:
            # Topological order is required so that an upstream node causing a
            # redirection is built before its downstream consumers look it up.
            for node in self._pipeline.components:
                builder = step_builder.StepBuilder(
                    node=node,
                    deployment_config=deployment_config,
                    image=self._default_image,
                    image_cmds=self._default_commands,
                    beam_pipeline_args=self._pipeline.beam_pipeline_args,
                    enable_cache=self._pipeline.enable_cache,
                    pipeline_info=self._pipeline_info,
                    channel_redirect_map=redirects)
                all_tasks.extend(builder.build())

        runtime_params_spec = compiler_utils.build_runtime_parameter_spec(
            param_ctx.parameters)
        spec = pipeline_pb2.PipelineSpec(
            pipeline_info=info_proto,
            tasks=all_tasks,
            runtime_parameters=runtime_params_spec)
        spec.deployment_spec.update(
            json_format.MessageToDict(deployment_config))
        return spec
コード例 #6
0
    def build(self) -> pipeline_pb2.PipelineSpec:
        """Build a pipeline PipelineSpec.

        Every component is compiled by a StepBuilder into task specs, component
        definitions and deployment-config entries. The resulting layout depends
        on whether an exit handler is configured:

        * With an exit handler, all TFX tasks are nested under the
          `compiler_utils.TFX_DAG_NAME` component. The root DAG then holds two
          tasks: a reference to that sub-DAG and the exit handler task.
        * Without one, the TFX tasks are placed directly in the root DAG.

        Runtime parameters attached while the components are built become input
        parameters of the root component.

        Returns:
          The assembled pipeline_pb2.PipelineSpec.
        """

        _check_name(self._pipeline_info.pipeline_name)

        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        pipeline_info = pipeline_pb2.PipelineInfo(
            name=self._pipeline_info.pipeline_name)

        tfx_tasks = {}
        component_defs = {}
        # Map from (producer component id, output key) to (new producer component
        # id, output key)
        channel_redirect_map = {}
        with parameter_utils.ParameterContext() as pc:
            for component in self._pipeline.components:
                # With an exit handler, TFX_DAG_NAME names the wrapping sub-DAG
                # below, so a user component with that id is renamed to avoid a
                # clash.
                if self._exit_handler and component.id == compiler_utils.TFX_DAG_NAME:
                    component.with_id(component.id +
                                      _generate_component_name_suffix())
                    # The trailing space on the first literal is required: the
                    # two literals are implicitly concatenated.
                    logging.warning(
                        '_tfx_dag is system reserved name for pipeline with '
                        'exit handler, added suffix to your component name: %s',
                        component.id)
                # Here the topological order of components is required.
                # If a channel redirection is needed, redirect mapping is expected to be
                # available because the upstream node (which is the cause for
                # redirecting) is processed before the downstream consumer nodes.
                built_tasks = step_builder.StepBuilder(
                    node=component,
                    deployment_config=deployment_config,
                    component_defs=component_defs,
                    image=self._default_image,
                    image_cmds=self._default_commands,
                    beam_pipeline_args=self._pipeline.beam_pipeline_args,
                    enable_cache=self._pipeline.enable_cache,
                    pipeline_info=self._pipeline_info,
                    channel_redirect_map=channel_redirect_map).build()
                tfx_tasks.update(built_tasks)

        result = pipeline_pb2.PipelineSpec(pipeline_info=pipeline_info)

        # if exit handler is defined, put all the TFX tasks under tfx_dag,
        # exit handler is a separate component triggered by tfx_dag.
        if self._exit_handler:
            for name, task_spec in tfx_tasks.items():
                result.components[compiler_utils.TFX_DAG_NAME].dag.tasks[
                    name].CopyFrom(task_spec)
            # construct root with exit handler
            # NOTE(review): this StepBuilder runs outside the ParameterContext
            # above, so any runtime parameter the exit handler attaches is not
            # collected into `pc` and never reaches the root input definitions.
            # Confirm attach_parameter's no-context behavior before relying on it.
            exit_handler_task = step_builder.StepBuilder(
                node=self._exit_handler,
                deployment_config=deployment_config,
                component_defs=component_defs,
                image=self._default_image,
                image_cmds=self._default_commands,
                beam_pipeline_args=self._pipeline.beam_pipeline_args,
                enable_cache=False,
                pipeline_info=self._pipeline_info,
                channel_redirect_map=channel_redirect_map,
                is_exit_handler=True).build()
            # Root task that points at the TFX sub-DAG component.
            tfx_dag_task = result.root.dag.tasks[compiler_utils.TFX_DAG_NAME]
            tfx_dag_task.component_ref.name = compiler_utils.TFX_DAG_NAME
            tfx_dag_task.task_info.name = compiler_utils.TFX_DAG_NAME
            result.root.dag.tasks[self._exit_handler.id].CopyFrom(
                exit_handler_task[self._exit_handler.id])
        else:
            for name, task_spec in tfx_tasks.items():
                result.root.dag.tasks[name].CopyFrom(task_spec)

        result.deployment_spec.update(
            json_format.MessageToDict(deployment_config))
        for name, component_def in component_defs.items():
            result.components[name].CopyFrom(component_def)

        # Attach runtime parameter to root's input parameter
        for param in pc.parameters:
            result.root.input_definitions.parameters[param.name].CopyFrom(
                compiler_utils.build_parameter_type_spec(param))

        return result