Code Example #1
File: kubeflow_v2_e2e_test.py Project: vikrosj/tfx
    def testPipeline(self):
        self._copyTemplate()

        # Uncomment all variables in config.
        self._uncommentMultiLineVariables(
            os.path.join('pipeline', 'configs.py'), [
                'GOOGLE_CLOUD_REGION',
                'BIG_QUERY_WITH_DIRECT_RUNNER_BEAM_PIPELINE_ARGS',
                'BIG_QUERY_QUERY', 'DATAFLOW_BEAM_PIPELINE_ARGS',
                'GCP_AI_PLATFORM_TRAINING_ARGS', 'GCP_AI_PLATFORM_SERVING_ARGS'
            ])

        # Prepare data.
        self._prepare_data()
        self._replaceFileContent('kubeflow_v2_dag_runner.py', [
            ('_DATA_PATH = \'gs://{}/tfx-template/data/\'.'
             'format(configs.GCS_BUCKET_NAME)',
             '_DATA_PATH = \'gs://{{}}/{}/{}\'.format(configs.GCS_BUCKET_NAME)'
             .format(self._DATA_DIRECTORY_NAME, self._pipeline_name)),
        ])

        # Create a pipeline with only one component.
        self._create_pipeline()

        # Extract the compiled pipeline spec.
        kubeflow_v2_pb = pipeline_pb2.PipelineJob()
        io_utils.parse_json_file(
            file_name=os.path.join(os.getcwd(), 'pipeline.json'),
            message=kubeflow_v2_pb)
        # There should be one step in the compiled pipeline.
        self.assertLen(kubeflow_v2_pb.pipeline_spec['tasks'], 1)
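
The final assertion relies on the fact that PipelineJob.pipeline_spec is a google.protobuf.Struct, which supports dict-style access, and that its nested list values support len(). A minimal standalone sketch of that behavior (the task entry below is illustrative, not taken from a real compiled spec):

from google.protobuf import json_format
from tfx.orchestration.kubeflow.v2.proto import pipeline_pb2

# Parse a hand-written job spec; the Struct field keeps the raw JSON keys.
job = pipeline_pb2.PipelineJob()
json_format.Parse(
    '{"pipelineSpec": {"tasks": [{"taskInfo": {"name": "example-gen"}}]}}',
    job)
assert len(job.pipeline_spec['tasks']) == 1  # Same check as assertLen above.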
Code Example #2

# Imports needed by this snippet (proto path as used in Code Example #4).
import os
import sys
from typing import Mapping, Optional, Sequence, Text

from google.protobuf import json_format
from tfx.orchestration.kubeflow.v2.proto import pipeline_pb2
from tfx.utils import io_utils

def _mock_subprocess_call(cmd: Sequence[Optional[Text]],
                          env: Mapping[Text, Text]) -> int:
  """Mocks the subprocess call."""
  assert len(cmd) == 2, 'Unexpected number of commands: {}'.format(cmd)
  del env
  dsl_path = cmd[1]

  if dsl_path.endswith('test_pipeline_bad.py'):
    sys.exit(1)
  if not (dsl_path.endswith('test_pipeline_1.py') or
          dsl_path.endswith('test_pipeline_2.py')):
    raise ValueError('Unexpected dsl path: {}'.format(dsl_path))

  spec_pb = pipeline_pb2.PipelineSpec(
      pipeline_info=pipeline_pb2.PipelineInfo(name='chicago_taxi_kubeflow'))
  runtime_pb = pipeline_pb2.PipelineJob.RuntimeConfig(
      gcs_output_directory=os.path.join(os.environ['HOME'], 'tfx', 'pipelines',
                                        'chicago_taxi_kubeflow'))
  job_pb = pipeline_pb2.PipelineJob(runtime_config=runtime_pb)
  job_pb.pipeline_spec.update(json_format.MessageToDict(spec_pb))
  io_utils.write_string_file(
      file_name='pipeline.json',
      string_value=json_format.MessageToJson(message=job_pb, sort_keys=True))
  return 0
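
A minimal usage sketch of the mock's contract (the DSL path below is illustrative): paths ending in test_pipeline_1.py or test_pipeline_2.py return 0 and write pipeline.json into the current directory, test_pipeline_bad.py exits with status 1, and any other path raises ValueError.

# Simulate "compiling" a known-good DSL file; env is ignored by the mock.
ret = _mock_subprocess_call(cmd=[sys.executable, '/tmp/test_pipeline_1.py'],
                            env={})
assert ret == 0  # pipeline.json now exists in the current directory.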
Code Example #3
    def run(self,
            pipeline: tfx_pipeline.Pipeline,
            parameter_values: Optional[Dict[Text, Any]] = None,
            write_out: bool = True) -> Dict[Text, Any]:
        """Compiles a pipeline DSL object into pipeline file.

    Args:
      pipeline: TFX pipeline object.
      parameter_values: mapping from runtime parameter names to its values.
      write_out: set to True to actually write out the file to the place
        designated by output_dir and output_filename. Otherwise return the
        JSON-serialized pipeline job spec.

    Returns:
      Returns the JSON pipeline job spec.

    Raises:
      RuntimeError: if trying to write out to a place occupied by an existing
      file.
    """
        # TODO(b/166343606): Support user-provided labels.
        # TODO(b/169095387): Deprecate .run() method in favor of the unified API
        # client.
        display_name = (self._config.display_name
                        or pipeline.pipeline_info.pipeline_name)
        pipeline_spec = pipeline_builder.PipelineBuilder(
            tfx_pipeline=pipeline,
            default_image=self._config.default_image,
            default_commands=self._config.default_commands).build()
        pipeline_spec.sdk_version = 'tfx-{}'.format(version.__version__)
        pipeline_spec.schema_version = _SCHEMA_VERSION
        runtime_config = pipeline_builder.RuntimeConfigBuilder(
            pipeline_info=pipeline.pipeline_info,
            parameter_values=parameter_values).build()
        with telemetry_utils.scoped_labels(
            {telemetry_utils.LABEL_TFX_RUNNER: 'kubeflow_v2'}):
            result = pipeline_pb2.PipelineJob(
                display_name=display_name,
                labels=telemetry_utils.get_labels_dict(),
                runtime_config=runtime_config)
        result.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))
        pipeline_json_dict = json_format.MessageToDict(result)
        if write_out:
            if fileio.exists(
                    self._output_dir) and not fileio.isdir(self._output_dir):
                raise RuntimeError('Output path: %s is pointed to a file.' %
                                   self._output_dir)
            if not fileio.exists(self._output_dir):
                fileio.makedirs(self._output_dir)

            with fileio.open(
                    os.path.join(self._output_dir, self._output_filename),
                    'wb') as f:
                f.write(json.dumps(pipeline_json_dict, sort_keys=True))

        return pipeline_json_dict
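
A hedged usage sketch of this method: the runner class and constructor parameters below are assumptions inferred from the attributes the method uses (_config, _output_dir, _output_filename), and my_pipeline, the image, and the paths are illustrative. It compiles a TFX pipeline into a pipeline.json job spec without submitting it.

from tfx.orchestration.kubeflow.v2 import kubeflow_v2_dag_runner

runner = kubeflow_v2_dag_runner.KubeflowV2DagRunner(
    config=kubeflow_v2_dag_runner.KubeflowV2DagRunnerConfig(
        default_image='gcr.io/my-project/my-tfx-image'),  # illustrative image
    output_dir='/tmp/pipeline_output',
    output_filename='pipeline.json')
# Returns the job spec dict; write_out=True also writes it to output_dir.
job_spec_dict = runner.run(my_pipeline, write_out=True)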
Code Example #4
File: kubeflow_v2_handler.py Project: konny0311/tfx
    def _extract_pipeline_args(self) -> Dict[Text, Any]:
        """Get pipeline args from the DSL by compiling the pipeline.

    Returns:
      Python dictionary with pipeline details extracted from DSL.

    Raises:
      RuntimeError: when the given pipeline arg file location is occupied.
    """
        pipeline_dsl_path = self.flags_dict[labels.PIPELINE_DSL_PATH]

        if os.path.isdir(pipeline_dsl_path):
            sys.exit('Provide a valid dsl file path.')

        # Create an environment for subprocess.
        temp_env = os.environ.copy()

        # The image name and project ID are not needed for extracting pipeline
        # info, so they can be optional.
        runner_env = {
            kubeflow_labels.TFX_IMAGE_ENV:
            self.flags_dict.get(kubeflow_labels.TFX_IMAGE_ENV, ''),
            kubeflow_labels.GCP_PROJECT_ID_ENV:
            self.flags_dict.get(kubeflow_labels.GCP_PROJECT_ID_ENV, ''),
        }

        temp_env.update(runner_env)

        # Run the pipeline DSL. Because RUN_FLAG_ENV is not set in the
        # environment, actual execution is not triggered; instead, the DSL
        # outputs a compiled pipeline spec.
        self._subprocess_call(command=[sys.executable, pipeline_dsl_path],
                              env=temp_env)

        # Only import pipeline_pb2 when needed to guard CLI dependency.
        from tfx.orchestration.kubeflow.v2.proto import pipeline_pb2  # pylint: disable=g-import-not-at-top

        # Extract the needed information from compiled pipeline spec.
        job_message = pipeline_pb2.PipelineJob()
        io_utils.parse_json_file(
            file_name=os.path.join(os.getcwd(), _PIPELINE_SPEC_FILE),
            message=job_message)

        # `job_message.pipeline_spec` is a google.protobuf.Struct; convert it
        # to a plain dict before parsing it into a typed PipelineSpec message.
        pipeline_spec_pb = json_format.ParseDict(
            json_format.MessageToDict(job_message.pipeline_spec),
            pipeline_pb2.PipelineSpec())

        pipeline_name = pipeline_spec_pb.pipeline_info.name
        pipeline_args = {
            'pipeline_name': pipeline_name,
            'pipeline_root': job_message.runtime_config.gcs_output_directory
        }

        return pipeline_args
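
Tying the examples together, a minimal sketch of the same extraction steps outside the handler class; it assumes the mock from Code Example #2 has already written pipeline.json into the current directory.

import os

from google.protobuf import json_format
from tfx.orchestration.kubeflow.v2.proto import pipeline_pb2
from tfx.utils import io_utils

# Load the compiled job spec and pull out the typed PipelineSpec message.
job = pipeline_pb2.PipelineJob()
io_utils.parse_json_file(
    file_name=os.path.join(os.getcwd(), 'pipeline.json'), message=job)
spec = json_format.ParseDict(
    json_format.MessageToDict(job.pipeline_spec), pipeline_pb2.PipelineSpec())
print(spec.pipeline_info.name)                  # chicago_taxi_kubeflow
print(job.runtime_config.gcs_output_directory)  # $HOME/tfx/pipelines/chicago_taxi_kubeflow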