예제 #1
0
    def get_schema(self):
        """Read the schema for the pipeline described by the DSL.

        The DSL is executed through a ``KubeflowDagRunnerPatcher`` created
        with ``call_real_run=False``. The pipeline name and pipeline root
        recorded in the resulting context are then handed to
        ``_read_schema_from_pipeline_root``. This method itself returns
        nothing.
        """
        dag_runner_patcher = kubeflow_dag_runner_patcher.KubeflowDagRunnerPatcher(
            call_real_run=False)
        dsl_context = self.execute_dsl(dag_runner_patcher)

        name = dsl_context[dag_runner_patcher.PIPELINE_NAME]
        root = dsl_context[dag_runner_patcher.PIPELINE_ROOT]
        self._read_schema_from_pipeline_root(name, root)
예제 #2
0
    def compile_pipeline(self) -> None:
        """Compile the pipeline for Kubeflow and report the package location.

        The DSL is executed through a ``KubeflowDagRunnerPatcher`` created
        with ``call_real_run=True`` and ``use_temporary_output_file=False``.
        A success message and the output file path recorded in the patcher
        context are then echoed via ``click``. Nothing is returned.
        """
        dag_runner_patcher = kubeflow_dag_runner_patcher.KubeflowDagRunnerPatcher(
            call_real_run=True, use_temporary_output_file=False)
        dsl_context = self.execute_dsl(dag_runner_patcher)

        click.echo('Pipeline compiled successfully.')
        # Looked up only after the success message, as in the original flow.
        package_path = dsl_context[dag_runner_patcher.OUTPUT_FILE_PATH]
        click.echo(f'Pipeline package path: {package_path}')
예제 #3
0
 def testPatcherWithOutputFile(self):
     """An explicit runner output filename wins over the temporary file.

     The patcher is created asking for a temporary output file, but the
     runner is constructed with its own ``output_filename``. After running
     under the patch, the context must report that no temporary file is
     used, the output path must end in the given filename, and the
     runner's ``_output_filename`` must still be the given filename.
     """
     expected_name = 'foo.tar.gz'
     dag_runner_patcher = kubeflow_dag_runner_patcher.KubeflowDagRunnerPatcher(
         call_real_run=False,
         build_image_fn=None,
         use_temporary_output_file=True)
     dag_runner = kubeflow_dag_runner.KubeflowDagRunner(
         output_filename=expected_name)
     dummy_pipeline = tfx_pipeline.Pipeline('dummy', 'dummy_root')

     with dag_runner_patcher.patch() as patch_context:
         dag_runner.run(dummy_pipeline)

     self.assertFalse(
         patch_context[dag_runner_patcher.USE_TEMPORARY_OUTPUT_FILE])
     actual_name = os.path.basename(
         patch_context[dag_runner_patcher.OUTPUT_FILE_PATH])
     self.assertEqual(actual_name, expected_name)
     self.assertEqual(dag_runner._output_filename, expected_name)
예제 #4
0
    def testPatcher(self):
        """The patcher builds the image and records a temporary output file.

        A mocked image-build function is passed to the patcher. After the
        runner executes under the patch, the context must flag a temporary
        output file and contain an output path, the build function must
        have been called exactly once with the configured image name, and
        the runner config's ``tfx_image`` must hold the built image name.
        """
        requested_image = 'foo/bar'
        resolved_image = 'foo/bar@sha256:1234567890'

        fake_build_image = mock.MagicMock(return_value=resolved_image)
        dag_runner_patcher = kubeflow_dag_runner_patcher.KubeflowDagRunnerPatcher(
            call_real_run=True,
            build_image_fn=fake_build_image,
            use_temporary_output_file=True)
        config = kubeflow_dag_runner.KubeflowDagRunnerConfig(
            tfx_image=requested_image)
        dag_runner = kubeflow_dag_runner.KubeflowDagRunner(config=config)
        dummy_pipeline = tfx_pipeline.Pipeline('dummy', 'dummy_root')

        with dag_runner_patcher.patch() as patch_context:
            dag_runner.run(dummy_pipeline)

        self.assertTrue(
            patch_context[dag_runner_patcher.USE_TEMPORARY_OUTPUT_FILE])
        self.assertIn(dag_runner_patcher.OUTPUT_FILE_PATH, patch_context)

        fake_build_image.assert_called_once_with(requested_image)
        self.assertEqual(config.tfx_image, resolved_image)
예제 #5
0
    def create_pipeline(self, update: bool = False) -> None:
        """Create or update a pipeline in Kubeflow.

        The DSL is executed through a ``KubeflowDagRunnerPatcher`` created
        with ``call_real_run=True`` and ``use_temporary_output_file=True``.
        The resulting package is passed to ``_save_pipeline``. If the patcher
        context reports that a temporary output file was used, that file is
        removed afterwards, even when saving fails.

        Args:
            update: When true, update the pipeline instead of creating it.
                The flag is forwarded to ``_save_pipeline`` and selects the
                success message.
        """
        build_image_fn = None
        if self.flags_dict.get(labels.BUILD_IMAGE):
            build_image_fn = functools.partial(
                create_container_image,
                base_image=self.flags_dict.get(labels.BASE_IMAGE))
        elif os.path.exists('build.yaml'):
            # No image build was requested; warn that a leftover build.yaml
            # is ignored.
            click.echo(
                '[Warning] TFX doesn\'t depend on skaffold anymore and you can '
                'delete the auto-genrated build.yaml file. TFX will NOT build a '
                'container even if build.yaml file exists. Use --build-image flag '
                'to trigger an image build when creating or updating a pipeline.',
                err=True)

        patcher = kubeflow_dag_runner_patcher.KubeflowDagRunnerPatcher(
            call_real_run=True,
            use_temporary_output_file=True,
            build_image_fn=build_image_fn)
        context = self.execute_dsl(patcher)
        pipeline_name = context[patcher.PIPELINE_NAME]
        pipeline_package_path = context[patcher.OUTPUT_FILE_PATH]

        try:
            self._save_pipeline(pipeline_name,
                                pipeline_package_path,
                                update=update)
        finally:
            # Clean up in ``finally`` so a failed save does not leave the
            # temporary package file behind.
            if context[patcher.USE_TEMPORARY_OUTPUT_FILE]:
                os.remove(pipeline_package_path)

        if update:
            click.echo(f'Pipeline "{pipeline_name}" updated successfully.')
        else:
            click.echo(f'Pipeline "{pipeline_name}" created successfully.')