def testPipeline(self):
    """Compiles the template pipeline and checks the compiled task count."""
    self._copyTemplate()

    # Uncomment these variables in pipeline/configs.py.
    config_variables = [
        'GOOGLE_CLOUD_REGION',
        'BIG_QUERY_WITH_DIRECT_RUNNER_BEAM_PIPELINE_ARGS',
        'BIG_QUERY_QUERY',
        'DATAFLOW_BEAM_PIPELINE_ARGS',
        'GCP_AI_PLATFORM_TRAINING_ARGS',
        'GCP_AI_PLATFORM_SERVING_ARGS',
    ]
    self._uncommentMultiLineVariables(
        os.path.join('pipeline', 'configs.py'), config_variables)

    # Prepare the data, then rewrite the runner's _DATA_PATH assignment so it
    # uses this test's data directory name and pipeline name.
    self._prepare_data()
    template_data_path = (
        '_DATA_PATH = \'gs://{}/tfx-template/data/\'.'
        'format(configs.GCS_BUCKET_NAME)')
    test_data_path = (
        '_DATA_PATH = \'gs://{{}}/{}/{}\'.format(configs.GCS_BUCKET_NAME)'
        .format(self._DATA_DIRECTORY_NAME, self._pipeline_name))
    self._replaceFileContent('kubeflow_v2_dag_runner.py',
                             [(template_data_path, test_data_path)])

    # Create a pipeline with only one component.
    self._create_pipeline()

    # Load the compiled pipeline spec from the current working directory.
    compiled_job = pipeline_pb2.PipelineJob()
    spec_path = os.path.join(os.getcwd(), 'pipeline.json')
    io_utils.parse_json_file(file_name=spec_path, message=compiled_job)

    # There should be exactly one step in the compiled pipeline.
    self.assertLen(compiled_job.pipeline_spec['tasks'], 1)
def _mock_subprocess_call(cmd: Sequence[Optional[Text]], env: Mapping[Text, Text]) -> int: """Mocks the subprocess call.""" assert len(cmd) == 2, 'Unexpected number of commands: {}'.format(cmd) del env dsl_path = cmd[1] if dsl_path.endswith('test_pipeline_bad.py'): sys.exit(1) if not dsl_path.endswith( 'test_pipeline_1.py') and not dsl_path.endswith( 'test_pipeline_2.py'): raise ValueError('Unexpected dsl path: {}'.format(dsl_path)) spec_pb = pipeline_pb2.PipelineSpec( pipeline_info=pipeline_pb2.PipelineInfo(name='chicago_taxi_kubeflow')) runtime_pb = pipeline_pb2.PipelineJob.RuntimeConfig( gcs_output_directory=os.path.join(os.environ['HOME'], 'tfx', 'pipelines', 'chicago_taxi_kubeflow')) job_pb = pipeline_pb2.PipelineJob(runtime_config=runtime_pb) job_pb.pipeline_spec.update(json_format.MessageToDict(spec_pb)) io_utils.write_string_file( file_name='pipeline.json', string_value=json_format.MessageToJson(message=job_pb, sort_keys=True)) return 0
def run(self,
        pipeline: tfx_pipeline.Pipeline,
        parameter_values: Optional[Dict[Text, Any]] = None,
        write_out: Optional[bool] = True) -> Dict[Text, Any]:
    """Compiles a TFX pipeline DSL object into a pipeline job spec.

    Args:
      pipeline: TFX pipeline object to compile.
      parameter_values: mapping from runtime parameter names to their values.
      write_out: when True, the JSON spec is also written to the file named by
        the runner's output directory and output filename. The spec is
        returned in either case.

    Returns:
      The pipeline job spec as a JSON-compatible dictionary.

    Raises:
      RuntimeError: if the output directory path exists but is not a
        directory.
    """
    # TODO(b/166343606): Support user-provided labels.
    # TODO(b/169095387): Deprecate .run() method in favor of the unified API
    # client.
    job_display_name = (
        self._config.display_name or pipeline.pipeline_info.pipeline_name)

    spec_builder = pipeline_builder.PipelineBuilder(
        tfx_pipeline=pipeline,
        default_image=self._config.default_image,
        default_commands=self._config.default_commands)
    pipeline_spec = spec_builder.build()
    pipeline_spec.sdk_version = 'tfx-{}'.format(version.__version__)
    pipeline_spec.schema_version = _SCHEMA_VERSION

    runtime_config_builder = pipeline_builder.RuntimeConfigBuilder(
        pipeline_info=pipeline.pipeline_info,
        parameter_values=parameter_values)
    runtime_config = runtime_config_builder.build()

    # The job labels are collected while the runner label is in scope.
    runner_label = {telemetry_utils.LABEL_TFX_RUNNER: 'kubeflow_v2'}
    with telemetry_utils.scoped_labels(runner_label):
        job = pipeline_pb2.PipelineJob(
            display_name=(job_display_name or
                          pipeline.pipeline_info.pipeline_name),
            labels=telemetry_utils.get_labels_dict(),
            runtime_config=runtime_config)
    job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))
    pipeline_json_dict = json_format.MessageToDict(job)

    if not write_out:
        return pipeline_json_dict

    output_dir = self._output_dir
    if fileio.exists(output_dir) and not fileio.isdir(output_dir):
        raise RuntimeError(
            'Output path: %s is pointed to a file.' % output_dir)
    if not fileio.exists(output_dir):
        fileio.makedirs(output_dir)

    output_path = os.path.join(output_dir, self._output_filename)
    with fileio.open(output_path, 'wb') as f:
        f.write(json.dumps(pipeline_json_dict, sort_keys=True))

    return pipeline_json_dict
def _extract_pipeline_args(self) -> Dict[Text, Any]:
    """Gets pipeline args from the DSL by compiling the pipeline.

    The DSL file is run in a subprocess, after which the pipeline spec file in
    the current working directory is parsed. The process exits via sys.exit if
    the DSL path points to a directory.

    NOTE(review): the earlier docstring listed a RuntimeError for an occupied
    pipeline arg file location; nothing in this method raises it directly, so
    confirm whether a callee does.

    Returns:
      Dictionary with the keys 'pipeline_name' and 'pipeline_root' extracted
      from the compiled pipeline spec.
    """
    dsl_path = self.flags_dict[labels.PIPELINE_DSL_PATH]
    if os.path.isdir(dsl_path):
        sys.exit('Provide a valid dsl file path.')

    # Build the subprocess environment. Image name and project ID are not
    # needed for extracting pipeline info, so each defaults to ''.
    subprocess_env = os.environ.copy()
    optional_env_keys = (
        kubeflow_labels.TFX_IMAGE_ENV,
        kubeflow_labels.GCP_PROJECT_ID_ENV,
    )
    for env_key in optional_env_keys:
        subprocess_env[env_key] = self.flags_dict.get(env_key, '')

    # Run pipeline dsl. Because RUN_FLAG_ENV is not set here, the actual
    # execution won't be triggered; instead the DSL will output a compiled
    # pipeline spec.
    self._subprocess_call(
        command=[sys.executable, dsl_path], env=subprocess_env)

    # Only import pipeline_pb2 when needed to guard CLI dependency.
    from tfx.orchestration.kubeflow.v2.proto import pipeline_pb2  # pylint: disable=g-import-not-at-top

    # Extract the needed information from the compiled pipeline spec.
    job_message = pipeline_pb2.PipelineJob()
    spec_file = os.path.join(os.getcwd(), _PIPELINE_SPEC_FILE)
    io_utils.parse_json_file(file_name=spec_file, message=job_message)
    pipeline_spec_pb = json_format.ParseDict(job_message.pipeline_spec,
                                             pipeline_pb2.PipelineSpec())

    return {
        'pipeline_name': pipeline_spec_pb.pipeline_info.name,
        'pipeline_root': job_message.runtime_config.gcs_output_directory,
    }