Exemplo n.º 1
0
    def run(self,
            pipeline: tfx_pipeline.Pipeline,
            parameter_values: Optional[Dict[Text, Any]] = None,
            write_out: Optional[bool] = True) -> Dict[Text, Any]:
        """Compiles a pipeline DSL object into a Kubeflow V2 pipeline job spec.

        Args:
          pipeline: TFX pipeline object.
          parameter_values: mapping from runtime parameter names to their
            values.
          write_out: if True, additionally write the JSON-serialized job spec
            (keys sorted) to the file designated by output_dir and
            output_filename, creating output_dir if it does not exist. The
            job spec is returned as a dict regardless of this flag.

        Returns:
          The pipeline job spec as a JSON-compatible dict, i.e. the result of
          `json_format.MessageToDict` on the assembled `PipelineJob` proto.

        Raises:
          RuntimeError: if write_out is True and output_dir refers to an
            existing path that is not a directory.
        """
        # TODO(b/166343606): Support user-provided labels.
        # TODO(b/169095387): Deprecate .run() method in favor of the unified API
        # client.
        # Falls back to the pipeline name when no display name is configured.
        display_name = (self._config.display_name
                        or pipeline.pipeline_info.pipeline_name)
        pipeline_spec = pipeline_builder.PipelineBuilder(
            tfx_pipeline=pipeline,
            default_image=self._config.default_image,
            default_commands=self._config.default_commands).build()
        pipeline_spec.sdk_version = 'tfx-{}'.format(version.__version__)
        pipeline_spec.schema_version = _SCHEMA_VERSION
        runtime_config = pipeline_builder.RuntimeConfigBuilder(
            pipeline_info=pipeline.pipeline_info,
            parameter_values=parameter_values).build()
        # get_labels_dict() is called inside the scope, presumably so the
        # 'kubeflow_v2' runner label is part of the job labels.
        with telemetry_utils.scoped_labels(
            {telemetry_utils.LABEL_TFX_RUNNER: 'kubeflow_v2'}):
            result = pipeline_spec_pb2.PipelineJob(
                display_name=display_name,
                labels=telemetry_utils.get_labels_dict(),
                runtime_config=runtime_config)
        result.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))
        pipeline_json_dict = json_format.MessageToDict(result)

        if write_out:
            if fileio.exists(self._output_dir):
                if not fileio.isdir(self._output_dir):
                    raise RuntimeError('Output path: %s is pointed to a file.' %
                                       self._output_dir)
            else:
                fileio.makedirs(self._output_dir)

            output_path = os.path.join(self._output_dir, self._output_filename)
            serialized = json.dumps(pipeline_json_dict, sort_keys=True)
            # The file is opened in binary mode, so hand it bytes: binary file
            # objects (e.g. builtin open(..., 'wb')) reject str with TypeError.
            with fileio.open(output_path, 'wb') as f:
                f.write(serialized.encode('utf-8'))

        return pipeline_json_dict
Exemplo n.º 2
0
 def testBuildTwoStepPipeline(self):
   """Compiles the two-step test pipeline and compares it to test data.

   The expected PipelineSpec is loaded from
   'expected_two_step_pipeline.pbtxt'.
   """
   builder = pipeline_builder.PipelineBuilder(
       tfx_pipeline=test_utils.two_step_pipeline(),
       default_image='gcr.io/my-tfx:latest')
   compiled_spec = builder.build()
   expected_spec = test_utils.get_proto_from_test_data(
       'expected_two_step_pipeline.pbtxt', pipeline_pb2.PipelineSpec())
   self.assertProtoEquals(expected_spec, compiled_spec)
Exemplo n.º 3
0
 def testBuildPipelineWithPrimitiveValuePassing(self):
   """Compiles a pipeline that consumes primitive artifacts by value.

   The compiled spec must match
   'expected_consume_primitive_artifacts_by_value_pipeline.pbtxt'.
   """
   source_pipeline = test_utils.consume_primitive_artifacts_by_value_pipeline()
   builder = pipeline_builder.PipelineBuilder(
       tfx_pipeline=source_pipeline, default_image='gcr.io/my-tfx:latest')
   compiled_spec = builder.build()
   expected_spec = test_utils.get_proto_from_test_data(
       'expected_consume_primitive_artifacts_by_value_pipeline.pbtxt',
       pipeline_pb2.PipelineSpec())
   self.assertProtoEquals(expected_spec, compiled_spec)
Exemplo n.º 4
0
 def testBuildPipelineWithRuntimeParameter(self):
   """Compiles a pipeline declaring a runtime parameter.

   The compiled spec must match
   'expected_pipeline_with_runtime_parameter.pbtxt'.
   """
   source_pipeline = test_utils.pipeline_with_runtime_parameter()
   builder = pipeline_builder.PipelineBuilder(
       tfx_pipeline=source_pipeline, default_image='gcr.io/my-tfx:latest')
   compiled_spec = builder.build()
   expected_spec = test_utils.get_proto_from_test_data(
       'expected_pipeline_with_runtime_parameter.pbtxt',
       pipeline_pb2.PipelineSpec())
   self.assertProtoEquals(expected_spec, compiled_spec)
Exemplo n.º 5
0
  def testBuildPipelineWithTwoContainerSpecComponents(self):
    """Compiles a pipeline made of two container-spec components.

    The compiled spec must match
    'expected_pipeline_with_two_container_spec_components.pbtxt'.
    """
    source_pipeline = test_utils.pipeline_with_two_container_spec_components()
    builder = pipeline_builder.PipelineBuilder(
        tfx_pipeline=source_pipeline, default_image='gcr.io/my-tfx:latest')
    compiled_spec = builder.build()

    expected_spec = test_utils.get_proto_from_test_data(
        'expected_pipeline_with_two_container_spec_components.pbtxt',
        pipeline_pb2.PipelineSpec())
    self.assertProtoEquals(expected_spec, compiled_spec)
Exemplo n.º 6
0
  def testTwoStepPipelineWithTaskOnlyDependency(self):
    """Compiles a two-step pipeline linked by a task-only dependency.

    A placeholder default image is passed; the compiled spec must match
    'expected_two_step_pipeline_with_task_only_dependency.pbtxt'.
    """
    source_pipeline = test_utils.two_step_pipeline_with_task_only_dependency()
    compiled_spec = pipeline_builder.PipelineBuilder(
        tfx_pipeline=source_pipeline, default_image='unused-image').build()

    expected_spec = test_utils.get_proto_from_test_data(
        'expected_two_step_pipeline_with_task_only_dependency.pbtxt',
        pipeline_pb2.PipelineSpec())
    self.assertProtoEquals(expected_spec, compiled_spec)
Exemplo n.º 7
0
  def testBuildTwoStepPipelineWithCacheEnabled(self):
    """Compiles the two-step pipeline with `enable_cache` switched on.

    The compiled spec must match
    'expected_two_step_pipeline_with_cache_enabled.pbtxt'.
    """
    cached_pipeline = test_utils.two_step_pipeline()
    cached_pipeline.enable_cache = True

    compiled_spec = pipeline_builder.PipelineBuilder(
        tfx_pipeline=cached_pipeline,
        default_image='gcr.io/my-tfx:latest').build()
    expected_spec = test_utils.get_proto_from_test_data(
        'expected_two_step_pipeline_with_cache_enabled.pbtxt',
        pipeline_pb2.PipelineSpec())
    self.assertProtoEquals(expected_spec, compiled_spec)
Exemplo n.º 8
0
  def testPipelineWithExitHandler(self):
    """Compiles the two-step pipeline with an exit handler attached.

    The exit handler receives `decorators.FinalStatusStr()` as `param1`; the
    compiled spec must match
    'expected_two_step_pipeline_with_exit_handler.pbtxt'.
    """
    base_pipeline = test_utils.two_step_pipeline()
    final_status_handler = test_utils.dummy_exit_handler(
        param1=decorators.FinalStatusStr())

    compiled_spec = pipeline_builder.PipelineBuilder(
        tfx_pipeline=base_pipeline,
        default_image='gcr.io/my-tfx:latest',
        exit_handler=final_status_handler).build()
    expected_spec = test_utils.get_proto_from_test_data(
        'expected_two_step_pipeline_with_exit_handler.pbtxt',
        pipeline_pb2.PipelineSpec())
    self.assertProtoEquals(expected_spec, compiled_spec)