Example #1
import os
import sys
from typing import Mapping, Optional, Sequence, Text

from google.protobuf import json_format
from tfx.orchestration.kubeflow.v2.proto import pipeline_pb2
from tfx.utils import io_utils


def _mock_subprocess_call(cmd: Sequence[Optional[Text]],
                          env: Mapping[Text, Text]) -> int:
  """Mocks the subprocess call."""
  assert len(cmd) == 2, 'Unexpected number of commands: {}'.format(cmd)
  del env
  dsl_path = cmd[1]

  if dsl_path.endswith('test_pipeline_bad.py'):
    sys.exit(1)
  if not dsl_path.endswith(
      'test_pipeline_1.py') and not dsl_path.endswith(
          'test_pipeline_2.py'):
    raise ValueError('Unexpected dsl path: {}'.format(dsl_path))

  spec_pb = pipeline_pb2.PipelineSpec(
      pipeline_info=pipeline_pb2.PipelineInfo(name='chicago_taxi_kubeflow'))
  runtime_pb = pipeline_pb2.PipelineJob.RuntimeConfig(
      gcs_output_directory=os.path.join(os.environ['HOME'], 'tfx', 'pipelines',
                                        'chicago_taxi_kubeflow'))
  job_pb = pipeline_pb2.PipelineJob(runtime_config=runtime_pb)
  job_pb.pipeline_spec.update(json_format.MessageToDict(spec_pb))
  io_utils.write_string_file(
      file_name='pipeline.json',
      string_value=json_format.MessageToJson(message=job_pb, sort_keys=True))
  return 0
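The mock above stands in for launching a Python interpreter on the pipeline DSL. A minimal sketch of how a test might wire it up, assuming the code under test routes its launch through subprocess.check_call (the patch target and the handler call are assumptions, not shown on this page):

from unittest import mock

# Hypothetical test body: every subprocess.check_call in the code under
# test now runs _mock_subprocess_call, which writes 'pipeline.json'
# instead of spawning a real interpreter.
with mock.patch('subprocess.check_call', new=_mock_subprocess_call):
  handler._extract_pipeline_args()  # 'handler' is a hypothetical instance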
Example #2
  def testBuildTwoStepPipeline(self):
    my_builder = pipeline_builder.PipelineBuilder(
        tfx_pipeline=test_utils.two_step_pipeline(),
        default_image='gcr.io/my-tfx:latest')
    actual_pipeline_spec = my_builder.build()
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data('expected_two_step_pipeline.pbtxt',
                                            pipeline_pb2.PipelineSpec()),
        actual_pipeline_spec)
Example #3
  def testBuildPipelineWithPrimitiveValuePassing(self):
    my_builder = pipeline_builder.PipelineBuilder(
        tfx_pipeline=test_utils.consume_primitive_artifacts_by_value_pipeline(),
        default_image='gcr.io/my-tfx:latest')
    actual_pipeline_spec = my_builder.build()
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(
            'expected_consume_primitive_artifacts_by_value_pipeline.pbtxt',
            pipeline_pb2.PipelineSpec()), actual_pipeline_spec)
Example #4
  def testBuildPipelineWithRuntimeParameter(self):
    my_builder = pipeline_builder.PipelineBuilder(
        tfx_pipeline=test_utils.pipeline_with_runtime_parameter(),
        default_image='gcr.io/my-tfx:latest')
    actual_pipeline_spec = my_builder.build()
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(
            'expected_pipeline_with_runtime_parameter.pbtxt',
            pipeline_pb2.PipelineSpec()), actual_pipeline_spec)
Example #5
  def testBuildPipelineWithTwoContainerSpecComponents(self):
    my_builder = pipeline_builder.PipelineBuilder(
        tfx_pipeline=test_utils.pipeline_with_two_container_spec_components(),
        default_image='gcr.io/my-tfx:latest')
    actual_pipeline_spec = my_builder.build()

    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(
            'expected_pipeline_with_two_container_spec_components.pbtxt',
            pipeline_pb2.PipelineSpec()), actual_pipeline_spec)
Example #6
  def testTwoStepPipelineWithTaskOnlyDependency(self):
    builder = pipeline_builder.PipelineBuilder(
        tfx_pipeline=test_utils.two_step_pipeline_with_task_only_dependency(),
        default_image='unused-image')

    pipeline_spec = builder.build()
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(
            'expected_two_step_pipeline_with_task_only_dependency.pbtxt',
            pipeline_pb2.PipelineSpec()), pipeline_spec)
Example #7
  def testBuildTwoStepPipelineWithCacheEnabled(self):
    pipeline = test_utils.two_step_pipeline()
    pipeline.enable_cache = True

    builder = pipeline_builder.PipelineBuilder(
        tfx_pipeline=pipeline, default_image='gcr.io/my-tfx:latest')
    pipeline_spec = builder.build()
    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(
            'expected_two_step_pipeline_with_cache_enabled.pbtxt',
            pipeline_pb2.PipelineSpec()), pipeline_spec)
Example #8
  def _extract_pipeline_args(self) -> Dict[Text, Any]:
    """Get pipeline args from the DSL by compiling the pipeline.

    Returns:
      Python dictionary with pipeline details extracted from the DSL.

    Raises:
      RuntimeError: when the given pipeline arg file location is occupied.
    """
    pipeline_dsl_path = self.flags_dict[labels.PIPELINE_DSL_PATH]

    if os.path.isdir(pipeline_dsl_path):
      sys.exit('Provide a valid dsl file path.')

    # Create an environment for the subprocess.
    temp_env = os.environ.copy()

    # The image name and project ID are not needed for extracting pipeline
    # info, so they can stay optional.
    runner_env = {
        kubeflow_labels.TFX_IMAGE_ENV:
            self.flags_dict.get(kubeflow_labels.TFX_IMAGE_ENV, ''),
        kubeflow_labels.GCP_PROJECT_ID_ENV:
            self.flags_dict.get(kubeflow_labels.GCP_PROJECT_ID_ENV, ''),
    }

    temp_env.update(runner_env)

    # Run the pipeline DSL. Because RUN_FLAG_ENV is not set, the actual
    # execution is not triggered; instead, the DSL outputs a compiled
    # pipeline spec.
    self._subprocess_call(
        command=[sys.executable, pipeline_dsl_path], env=temp_env)

    # Only import pipeline_pb2 when needed, to guard the CLI dependency.
    from tfx.orchestration.kubeflow.v2.proto import pipeline_pb2  # pylint: disable=g-import-not-at-top

    # Extract the needed information from the compiled pipeline spec.
    job_message = pipeline_pb2.PipelineJob()
    io_utils.parse_json_file(
        file_name=os.path.join(os.getcwd(), _PIPELINE_SPEC_FILE),
        message=job_message)

    pipeline_spec_pb = json_format.ParseDict(job_message.pipeline_spec,
                                             pipeline_pb2.PipelineSpec())

    pipeline_name = pipeline_spec_pb.pipeline_info.name
    pipeline_args = {
        'pipeline_name': pipeline_name,
        'pipeline_root': job_message.runtime_config.gcs_output_directory
    }

    return pipeline_args
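For reference, a hypothetical invocation of the handler method above; labels.PIPELINE_DSL_PATH comes from the snippet itself, while the KubeflowV2Handler name and the flag value are assumptions:

from tfx.tools.cli import labels

# Hypothetical: only the DSL path is set; as the comments above note, the
# image name and project ID may stay empty for extraction.
handler = KubeflowV2Handler(
    flags_dict={labels.PIPELINE_DSL_PATH: 'my_pipeline.py'})
pipeline_args = handler._extract_pipeline_args()
print(pipeline_args['pipeline_name'], pipeline_args['pipeline_root'])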
Example #9
  def build(self) -> pipeline_pb2.PipelineSpec:
    """Build a pipeline PipelineSpec."""

    _check_name(self._pipeline_info.pipeline_name)

    deployment_config = pipeline_pb2.PipelineDeploymentConfig()
    pipeline_info = pipeline_pb2.PipelineInfo(
        name=self._pipeline_info.pipeline_name)

    tasks = {}
    component_defs = {}
    # Map from (producer component id, output key) to (new producer
    # component id, output key).
    channel_redirect_map = {}
    with parameter_utils.ParameterContext() as pc:
      for component in self._pipeline.components:
        # The topological order of components is required here. If a channel
        # redirection is needed, the redirect mapping is expected to be
        # available, because the upstream node (the cause of the redirection)
        # is processed before its downstream consumer nodes.
        built_tasks = step_builder.StepBuilder(
            node=component,
            deployment_config=deployment_config,
            component_defs=component_defs,
            image=self._default_image,
            image_cmds=self._default_commands,
            beam_pipeline_args=self._pipeline.beam_pipeline_args,
            enable_cache=self._pipeline.enable_cache,
            pipeline_info=self._pipeline_info,
            channel_redirect_map=channel_redirect_map).build()
        tasks.update(built_tasks)

    result = pipeline_pb2.PipelineSpec(pipeline_info=pipeline_info)
    result.deployment_spec.update(json_format.MessageToDict(deployment_config))
    for name, component_def in component_defs.items():
      result.components[name].CopyFrom(component_def)
    for name, task_spec in tasks.items():
      result.root.dag.tasks[name].CopyFrom(task_spec)

    # Attach runtime parameters to the root's input parameters.
    for param in pc.parameters:
      result.root.input_definitions.parameters[param.name].CopyFrom(
          compiler_utils.build_parameter_type_spec(param))

    return result
Example #10
  def build(self) -> pipeline_pb2.PipelineSpec:
    """Build a pipeline PipelineSpec."""

    _check_name(self._pipeline_info.pipeline_name)

    deployment_config = pipeline_pb2.PipelineDeploymentConfig()
    pipeline_info = pipeline_pb2.PipelineInfo(
        name=self._pipeline_info.pipeline_name)

    tasks = []
    # Map from (producer component id, output key) to (new producer
    # component id, output key).
    channel_redirect_map = {}
    with parameter_utils.ParameterContext() as pc:
      for component in self._pipeline.components:
        # The topological order of components is required here. If a channel
        # redirection is needed, the redirect mapping is expected to be
        # available, because the upstream node (the cause of the redirection)
        # is processed before its downstream consumer nodes.
        built_tasks = step_builder.StepBuilder(
            node=component,
            deployment_config=deployment_config,
            image=self._default_image,
            image_cmds=self._default_commands,
            beam_pipeline_args=self._pipeline.beam_pipeline_args,
            enable_cache=self._pipeline.enable_cache,
            pipeline_info=self._pipeline_info,
            channel_redirect_map=channel_redirect_map).build()
        tasks.extend(built_tasks)

    result = pipeline_pb2.PipelineSpec(
        pipeline_info=pipeline_info,
        tasks=tasks,
        runtime_parameters=compiler_utils.build_runtime_parameter_spec(
            pc.parameters))
    result.deployment_spec.update(json_format.MessageToDict(deployment_config))

    return result
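Examples #9 and #10 appear to target two revisions of the PipelineSpec schema: #9 fills result.components and result.root.dag.tasks, while #10 sets top-level tasks and runtime_parameters fields. In both cases the returned proto can be serialized the same way; a minimal sketch, assuming a builder instance like those in Examples #2 through #7:

from google.protobuf import json_format

pipeline_spec = my_builder.build()
# Serialize the built PipelineSpec to JSON, mirroring what Example #1 does
# for a full PipelineJob message.
with open('pipeline_spec.json', 'w') as f:
  f.write(json_format.MessageToJson(pipeline_spec, sort_keys=True))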