示例#1
0
    def testCreateRun(self, mock_pipeline_job, mock_init):
        """create_run initializes the client and builds one pipeline job.

        `mock_pipeline_job` and `mock_init` are mocks injected by patch
        decorators outside this view (presumably aiplatform's PipelineJob and
        init — confirm against the decorators).
        """
        run_flags = {
            labels.ENGINE_FLAG: self.engine,
            labels.PIPELINE_NAME: self.pipeline_name,
            labels.GCP_PROJECT_ID: _TEST_PROJECT_1,
            labels.GCP_REGION: _TEST_REGION,
            labels.RUNTIME_PARAMETER: self.runtime_parameter,
        }
        # TODO(b/198114641): Delete following override after upgrading source code
        # to aiplatform>=1.3.
        job_instance = mock_pipeline_job.return_value
        job_instance.wait_for_resource_creation = mock.MagicMock()

        handler = vertex_handler.VertexHandler(run_flags)
        handler.create_run()

        mock_init.assert_called_once_with(
            project=_TEST_PROJECT_1, location=_TEST_REGION)
        expected_template = handler._get_pipeline_definition_path(
            _TEST_PIPELINE_NAME)
        # The runtime parameters are expected to arrive as string values.
        mock_pipeline_job.assert_called_once_with(
            display_name=_TEST_PIPELINE_NAME,
            template_path=expected_template,
            parameter_values={'a': '1', 'b': '2'})
示例#2
0
 def testCreatePipeline(self):
     """create_pipeline leaves a pipeline definition at the expected path."""
     handler = vertex_handler.VertexHandler({
         labels.ENGINE_FLAG: self.engine,
         labels.PIPELINE_DSL_PATH: self.pipeline_path,
     })
     handler.create_pipeline()
     definition_path = handler._get_pipeline_definition_path(
         self.pipeline_name)
     self.assertTrue(fileio.exists(definition_path))
示例#3
0
    def testDeletePipeline(self):
        """delete_pipeline removes the pipeline's folder from the handler home."""
        # First create a pipeline.
        flags_dict = {
            labels.ENGINE_FLAG: self.engine,
            labels.PIPELINE_DSL_PATH: self.pipeline_path
        }
        handler = vertex_handler.VertexHandler(flags_dict)
        handler.create_pipeline()

        # Now delete the pipeline created and check if pipeline folder is deleted.
        flags_dict = {
            labels.ENGINE_FLAG: self.engine,
            labels.PIPELINE_NAME: self.pipeline_name
        }
        handler = vertex_handler.VertexHandler(flags_dict)
        handler.delete_pipeline()
        handler_pipeline_path = os.path.join(handler._handler_home_dir,
                                             self.pipeline_name)
        self.assertFalse(fileio.exists(handler_pipeline_path))
示例#4
0
 def testCompilePipeline(self):
     """compile_pipeline prints a success message on stdout."""
     compile_flags = {
         labels.ENGINE_FLAG: self.engine,
         labels.PIPELINE_DSL_PATH: self.pipeline_path,
     }
     handler = vertex_handler.VertexHandler(compile_flags)
     with self.captureWritesToStream(sys.stdout) as captured:
         handler.compile_pipeline()
     expected = f'Pipeline {self.pipeline_name} compiled successfully'
     self.assertIn(expected, captured.contents())
示例#5
0
 def testDeletePipelineNonExistentPipeline(self):
     """delete_pipeline exits with an error for a pipeline never created."""
     pipeline_name = self.pipeline_name
     handler = vertex_handler.VertexHandler({
         labels.ENGINE_FLAG: self.engine,
         labels.PIPELINE_NAME: pipeline_name,
     })
     with self.assertRaises(SystemExit) as err:
         handler.delete_pipeline()
     expected_message = 'Pipeline "{}" does not exist.'.format(pipeline_name)
     self.assertEqual(str(err.exception), expected_message)
示例#6
0
 def testUpdatePipelineNoPipeline(self):
     """update_pipeline exits with an error when nothing was created first."""
     # Deliberately skip create_pipeline before attempting the update.
     handler = vertex_handler.VertexHandler({
         labels.ENGINE_FLAG: self.engine,
         labels.PIPELINE_DSL_PATH: self.pipeline_path,
     })
     with self.assertRaises(SystemExit) as err:
         handler.update_pipeline()
     expected_message = 'Pipeline "{}" does not exist.'.format(
         self.pipeline_name)
     self.assertEqual(str(err.exception), expected_message)
示例#7
0
 def testCreatePipelineExistentPipeline(self):
     """A second create_pipeline for the same DSL exits with an error."""
     handler = vertex_handler.VertexHandler({
         labels.ENGINE_FLAG: self.engine,
         labels.PIPELINE_DSL_PATH: self.pipeline_path,
     })
     handler.create_pipeline()
     # The pipeline now exists, so creating it again must be rejected.
     with self.assertRaises(SystemExit) as err:
         handler.create_pipeline()
     expected_message = 'Pipeline "{}" already exists.'.format(
         self.pipeline_name)
     self.assertEqual(str(err.exception), expected_message)
示例#8
0
    def testUpdatePipeline(self):
        """update_pipeline succeeds on an existing pipeline.

        Both DSL files are presumably defining a pipeline named
        `self.pipeline_name` (confirm against the testdata files).
        """
        # First create the pipeline from test_pipeline_kubeflow_v2_1.py.
        pipeline_path_1 = os.path.join(self.chicago_taxi_pipeline_dir,
                                       'test_pipeline_kubeflow_v2_1.py')
        flags_dict_1 = {
            labels.ENGINE_FLAG: self.engine,
            labels.PIPELINE_DSL_PATH: pipeline_path_1
        }
        handler = vertex_handler.VertexHandler(flags_dict_1)
        handler.create_pipeline()

        # Then update it from test_pipeline_kubeflow_v2_2.py via update_pipeline.
        pipeline_path_2 = os.path.join(self.chicago_taxi_pipeline_dir,
                                       'test_pipeline_kubeflow_v2_2.py')
        flags_dict_2 = {
            labels.ENGINE_FLAG: self.engine,
            labels.PIPELINE_DSL_PATH: pipeline_path_2
        }
        handler = vertex_handler.VertexHandler(flags_dict_2)
        handler.update_pipeline()
        # NOTE(review): this only checks that a definition still exists, not
        # that its contents reflect the second DSL file.
        self.assertTrue(
            fileio.exists(
                handler._get_pipeline_definition_path(self.pipeline_name)))
示例#9
0
    def testListPipelinesNonEmpty(self):
        """list_pipelines output names each pipeline directory in $VERTEX_HOME."""
        # First create two pipeline directories directly under VERTEX_HOME.
        # Vertex has no Airflow-style dags folder.
        pipeline_names = ('pipeline_1', 'pipeline_2')
        for name in pipeline_names:
            fileio.makedirs(os.path.join(os.environ['VERTEX_HOME'], name))

        # Now, list the pipelines.
        flags_dict = {labels.ENGINE_FLAG: labels.VERTEX_ENGINE}
        handler = vertex_handler.VertexHandler(flags_dict)

        with self.captureWritesToStream(sys.stdout) as captured:
            handler.list_pipelines()
        output = captured.contents()
        for name in pipeline_names:
            self.assertIn(name, output)
示例#10
0
def create_handler(flags_dict: Dict[Text, Any]) -> base_handler.BaseHandler:
    """Retrieve handler from the environment using the --engine flag.

    Args:
      flags_dict: A dictionary containing the flags of a command.

    Raises:
      RuntimeError: When engine is not supported by TFX.
      SystemExit: When the package required by the selected engine (Airflow,
        Kubeflow, or `kfp` for Vertex) is not found in `pip freeze` output.

    Returns:
      Corresponding Handler object.
    """
    engine = flags_dict[labels.ENGINE_FLAG]

    def _installed_packages() -> Text:
        # `pip freeze` spawns a subprocess, so it is only run for engines that
        # check for an optional package. The beam, local and auto branches, and
        # unsupported engines, never need it and no longer depend on `pip`
        # being available.
        return str(subprocess.check_output(['pip', 'freeze', '--local']))

    if engine == 'airflow':
        if labels.AIRFLOW_PACKAGE_NAME not in _installed_packages():
            sys.exit('Airflow not found.')
        from tfx.tools.cli.handler import airflow_handler  # pylint: disable=g-import-not-at-top
        return airflow_handler.AirflowHandler(flags_dict)
    elif engine == 'kubeflow':
        if labels.KUBEFLOW_PACKAGE_NAME not in _installed_packages():
            sys.exit('Kubeflow not found.')
        from tfx.tools.cli.handler import kubeflow_handler  # pylint: disable=g-import-not-at-top
        return kubeflow_handler.KubeflowHandler(flags_dict)
    elif engine == 'beam':
        from tfx.tools.cli.handler import beam_handler  # pylint: disable=g-import-not-at-top
        return beam_handler.BeamHandler(flags_dict)
    elif engine == 'local':
        from tfx.tools.cli.handler import local_handler  # pylint: disable=g-import-not-at-top
        return local_handler.LocalHandler(flags_dict)
    elif engine == 'vertex':
        # Vertex reuses the Kubeflow package check; presumably
        # KUBEFLOW_PACKAGE_NAME is the `kfp` distribution name (confirm in labels).
        if labels.KUBEFLOW_PACKAGE_NAME not in _installed_packages():
            sys.exit('`kfp` python package is required for Vertex.')
        from tfx.tools.cli.handler import vertex_handler  # pylint: disable=g-import-not-at-top
        return vertex_handler.VertexHandler(flags_dict)
    elif engine == 'auto':
        return detect_handler(flags_dict)
    else:
        raise RuntimeError('Engine {} is not supported.'.format(engine))
示例#11
0
 def testListPipelinesEmpty(self):
     """list_pipelines reports that there is nothing to display."""
     handler = vertex_handler.VertexHandler(
         {labels.ENGINE_FLAG: labels.VERTEX_ENGINE})
     with self.captureWritesToStream(sys.stdout) as captured:
         handler.list_pipelines()
     output = captured.contents()
     self.assertIn('No pipelines to display.', output)