def testUpdatePipeline(self):
  """End-to-end check that update_pipeline refreshes an existing pipeline."""
  # First create pipeline with test_pipeline.py
  pipeline_path_1 = os.path.join(self.chicago_taxi_pipeline_dir,
                                 'test_pipeline_kubeflow_1.py')
  flags_dict_1 = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_DSL_PATH: pipeline_path_1,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
      labels.PIPELINE_PACKAGE_PATH: self.pipeline_package_path
  }
  handler = kubeflow_handler.KubeflowHandler(flags_dict_1)
  handler.create_pipeline()
  # Update test_pipeline and run update_pipeline
  pipeline_path_2 = os.path.join(self.chicago_taxi_pipeline_dir,
                                 'test_pipeline_kubeflow_2.py')
  flags_dict_2 = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_DSL_PATH: pipeline_path_2,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
      labels.PIPELINE_PACKAGE_PATH: self.pipeline_package_path
  }
  handler = kubeflow_handler.KubeflowHandler(flags_dict_2)
  # Handler home keeps one sub-directory per pipeline name; it must exist
  # from the create_pipeline() call above before an update is attempted.
  handler_pipeline_path = os.path.join(
      handler._handler_home_dir, self.pipeline_args[labels.PIPELINE_NAME], '')
  self.assertTrue(tf.io.gfile.exists(handler_pipeline_path))
  handler.update_pipeline()
  # pipeline_args.json should still be present after the update.
  self.assertTrue(
      tf.io.gfile.exists(
          os.path.join(handler_pipeline_path, 'pipeline_args.json')))
def testDeletePipeline(self):
  """Deleting a pipeline removes its handler home sub-directory."""
  # Set up a pipeline first so there is something to delete.
  create_flags = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_DSL_PATH: self.pipeline_path,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
      labels.PIPELINE_PACKAGE_PATH: self.pipeline_package_path
  }
  kubeflow_handler.KubeflowHandler(create_flags).create_pipeline()
  # Now delete it, addressing the pipeline by name.
  delete_flags = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_NAME: self.pipeline_name,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
  }
  deleter = kubeflow_handler.KubeflowHandler(delete_flags)
  deleter.delete_pipeline()
  pipeline_dir = os.path.join(deleter._handler_home_dir,
                              self.pipeline_args[labels.PIPELINE_NAME], '')
  self.assertFalse(tf.io.gfile.exists(pipeline_dir))
def testListRuns(self):
  """list_runs prints a run table for an existing pipeline."""
  # Create the pipeline whose runs will be listed.
  create_flags = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_DSL_PATH: self.pipeline_path,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
      labels.PIPELINE_PACKAGE_PATH: self.pipeline_package_path
  }
  kubeflow_handler.KubeflowHandler(create_flags).create_pipeline()
  # List runs by pipeline name and verify the printed header.
  list_flags = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_NAME: self.pipeline_name,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
  }
  lister = kubeflow_handler.KubeflowHandler(list_flags)
  with self.captureWritesToStream(sys.stdout) as captured:
    lister.list_runs()
  self.assertIn('pipeline_name', captured.contents())
def testGetSchema(self):
  """get_schema error paths plus the successful schema-fetch path."""
  flags_dict = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_DSL_PATH: self.pipeline_path,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
      labels.PIPELINE_PACKAGE_PATH: self.pipeline_package_path
  }
  handler = kubeflow_handler.KubeflowHandler(flags_dict)
  handler.create_pipeline()
  flags_dict = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_NAME: self.pipeline_name,
  }
  # No pipeline root
  handler = kubeflow_handler.KubeflowHandler(flags_dict)
  with self.assertRaises(SystemExit) as err:
    handler.get_schema()
  self.assertEqual(
      str(err.exception),
      'Create a run before inferring schema. If pipeline is already running, then wait for it to successfully finish.'
  )
  # No SchemaGen output.
  fileio.makedirs(self.pipeline_root)
  with self.assertRaises(SystemExit) as err:
    handler.get_schema()
  self.assertEqual(
      str(err.exception),
      'Either SchemaGen component does not exist or pipeline is still running. If pipeline is running, then wait for it to successfully finish.'
  )
  # Successful pipeline run.
  # Create fake schema in pipeline root.
  component_output_dir = os.path.join(self.pipeline_root, 'SchemaGen')
  schema_path = base_driver._generate_output_uri(  # pylint: disable=protected-access
      component_output_dir, 'schema', 3)
  fileio.makedirs(schema_path)
  with open(os.path.join(schema_path, 'schema.pbtxt'), 'w') as f:
    f.write('SCHEMA')
  with self.captureWritesToStream(sys.stdout) as captured:
    handler.get_schema()
    # get_schema copies schema.pbtxt into the current working directory.
    curr_dir_path = os.path.join(os.getcwd(), 'schema.pbtxt')
    self.assertIn('Path to schema: {}'.format(curr_dir_path),
                  captured.contents())
    self.assertIn(
        '*********SCHEMA FOR {}**********'.format(
            self.pipeline_name.upper()), captured.contents())
    self.assertTrue(fileio.exists(curr_dir_path))
def detect_handler(flags_dict: Dict[Text, Any]) -> base_handler.BaseHandler:
  """Detect handler from the environment.

  Details: When the engine flag is set to 'auto', this method first finds all
  the packages in the local environment. The environment is first checked
  for multiple orchestrators and if true the user must rerun the command
  with required engine. If only one orchestrator is present, the engine is
  set to that.

  Args:
    flags_dict: A dictionary containing the flags of a command.

  Returns:
    Corresponding Handler object.
  """
  # Inspect locally installed packages to decide which orchestrator exists.
  packages_list = str(subprocess.check_output(['pip', 'freeze', '--local']))
  if labels.AIRFLOW_PACKAGE_NAME in packages_list and labels.KUBEFLOW_PACKAGE_NAME in packages_list:
    sys.exit('Multiple orchestrators found. Choose one using --engine flag.')
  if labels.AIRFLOW_PACKAGE_NAME in packages_list:
    click.echo('Detected Airflow.')
    flags_dict[labels.ENGINE_FLAG] = 'airflow'
    # Imported lazily so the CLI does not require every orchestrator package.
    from tfx.tools.cli.handler import airflow_handler  # pylint: disable=g-import-not-at-top
    return airflow_handler.AirflowHandler(flags_dict)
  elif labels.KUBEFLOW_PACKAGE_NAME in packages_list:
    click.echo('Detected Kubeflow.')
    flags_dict[labels.ENGINE_FLAG] = 'kubeflow'
    from tfx.tools.cli.handler import kubeflow_handler  # pylint: disable=g-import-not-at-top
    return kubeflow_handler.KubeflowHandler(flags_dict)
  # TODO(b/132286477):Update to beam runner later.
  else:
    # Beam needs no extra orchestrator package, so it is the fallback.
    click.echo('Detected Beam.')
    flags_dict[labels.ENGINE_FLAG] = 'beam'
    from tfx.tools.cli.handler import beam_handler  # pylint: disable=g-import-not-at-top
    return beam_handler.BeamHandler(flags_dict)
def create_handler(flags_dict: Dict[Text, Any]) -> base_handler.BaseHandler:
  """Retrieve handler from the environment using the --engine flag.

  Args:
    flags_dict: A dictionary containing the flags of a command.

  Raises:
    RuntimeError: When engine is not supported by TFX.

  Returns:
    Corresponding Handler object.
  """
  engine = flags_dict[labels.ENGINE_FLAG]
  # Used to verify that the requested orchestrator is actually installed.
  packages_list = str(subprocess.check_output(['pip', 'freeze', '--local']))
  if engine == 'airflow':
    if labels.AIRFLOW_PACKAGE_NAME not in packages_list:
      sys.exit('Airflow not found.')
    # Handler modules are imported lazily so the CLI does not require every
    # orchestrator package to be installed.
    from tfx.tools.cli.handler import airflow_handler  # pylint: disable=g-import-not-at-top
    return airflow_handler.AirflowHandler(flags_dict)
  elif engine == 'kubeflow':
    if labels.KUBEFLOW_PACKAGE_NAME not in packages_list:
      sys.exit('Kubeflow not found.')
    from tfx.tools.cli.handler import kubeflow_handler  # pylint: disable=g-import-not-at-top
    return kubeflow_handler.KubeflowHandler(flags_dict)
  elif engine == 'beam':
    from tfx.tools.cli.handler import beam_handler  # pylint: disable=g-import-not-at-top
    return beam_handler.BeamHandler(flags_dict)
  elif engine == 'local':
    from tfx.tools.cli.handler import local_handler  # pylint: disable=g-import-not-at-top
    return local_handler.LocalHandler(flags_dict)
  elif engine == 'auto':
    # Fall back to environment detection when no engine was chosen.
    return detect_handler(flags_dict)
  else:
    raise RuntimeError('Engine {} is not supported.'.format(engine))
def testCheckPipelinePackagePathWrongPath(self):
  """A non-existent --package_path makes _check_pipeline_package_path exit."""
  # Point the package path at a file that was never built.
  missing_package = os.path.join(self.chicago_taxi_pipeline_dir,
                                 '{}.tar.gz'.format(self.pipeline_name))
  flags = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_DSL_PATH: self.pipeline_path,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
      labels.PIPELINE_PACKAGE_PATH: missing_package
  }
  handler = kubeflow_handler.KubeflowHandler(flags)
  pipeline_args = handler._extract_pipeline_args()
  with self.assertRaises(SystemExit) as err:
    handler._check_pipeline_package_path(pipeline_args[labels.PIPELINE_NAME])
  self.assertEqual(
      str(err.exception),
      'Pipeline package not found at {}. When --package_path is unset, it will try to find the workflow file, "<pipeline_name>.tar.gz" in the current directory.'
      .format(flags[labels.PIPELINE_PACKAGE_PATH]))
def testCompilePipeline(self):
  """Compiling a valid DSL reports success and the package path."""
  compiler = kubeflow_handler.KubeflowHandler(self.flags_with_package_path)
  with self.captureWritesToStream(sys.stdout) as captured:
    compiler.compile_pipeline()
  output = captured.contents()
  self.assertIn('Pipeline compiled successfully', output)
  self.assertIn('Pipeline package path', output)
def create_handler(flags_dict: Dict[Text, Any]) -> base_handler.BaseHandler:
  """Retrieve handler from the environment using the --engine flag.

  Args:
    flags_dict: A dictionary containing the flags of a command.

  Raises:
    RuntimeError: When engine is not supported by TFX.

  Returns:
    Corresponding Handler object.
  """
  engine = flags_dict[labels.ENGINE_FLAG]
  # Handler modules are imported lazily so the CLI does not require every
  # orchestrator package to be installed.
  if engine == 'airflow':
    from tfx.tools.cli.handler import airflow_handler  # pylint: disable=g-import-not-at-top
    return airflow_handler.AirflowHandler(flags_dict)
  elif engine == 'kubeflow':
    from tfx.tools.cli.handler import kubeflow_handler  # pylint: disable=g-import-not-at-top
    return kubeflow_handler.KubeflowHandler(flags_dict)
  elif engine == 'beam':
    from tfx.tools.cli.handler import beam_handler  # pylint: disable=g-import-not-at-top
    return beam_handler.BeamHandler(flags_dict)
  elif engine == 'auto':
    # Fall back to environment detection when no engine was chosen.
    return detect_handler(flags_dict)
  else:
    raise RuntimeError('Engine {} is not supported.'.format(engine))
def testListRunsNoPipeline(self):
  """list_runs exits with an error when the pipeline is unknown."""
  runs_handler = kubeflow_handler.KubeflowHandler(self.flags_with_name)
  # Simulate a pipeline that does not exist on the KFP server.
  self.mock_client.get_pipeline_id.return_value = None
  with self.assertRaises(SystemExit) as err:
    runs_handler.list_runs()
  self.assertIn(f'Cannot find pipeline "{self.pipeline_name}".',
                str(err.exception))
def testCreateRun(self):
  """create_run triggers run_pipeline with the stored version id."""
  run_handler = kubeflow_handler.KubeflowHandler(self.flags_with_name)
  with self.captureWritesToStream(sys.stdout) as captured:
    run_handler.create_run()
  self.assertIn('Run created for pipeline: ', captured.contents())
  # The run must be started in the pipeline's experiment, named after the
  # pipeline, and pinned to the expected pipeline version.
  self.mock_client.run_pipeline.assert_called_once_with(
      experiment_id=self.experiment_id,
      job_name=self.pipeline_name,
      version_id=self.pipeline_version_id)
def testDeletePipeline(self):
  """delete_pipeline removes both the pipeline and its experiment."""
  deleter = kubeflow_handler.KubeflowHandler(self.flags_with_name)
  deleter.delete_pipeline()
  self.mock_client.delete_pipeline.assert_called_once_with(self.pipeline_id)
  self.mock_client._experiment_api.delete_experiment.assert_called_once_with(
      self.experiment_id)
def testCreatePipelineExistentPipeline(self):
  """create_pipeline refuses to overwrite an already-registered pipeline."""
  creator = kubeflow_handler.KubeflowHandler(self.flags_with_dsl_path)
  # 'the_pipeline_id' will be returned.
  with self.assertRaises(SystemExit) as err:
    creator.create_pipeline()
  self.assertIn(
      f'Pipeline "{self.pipeline_args[labels.PIPELINE_NAME]}" already exists.',
      str(err.exception))
  # Nothing should have been uploaded to the server.
  self.mock_client.upload_pipeline.assert_not_called()
def testCreateRunNoPipeline(self):
  """create_run exits, and starts nothing, when the pipeline is unknown."""
  run_handler = kubeflow_handler.KubeflowHandler(self.flags_with_name)
  # Simulate a pipeline that does not exist on the KFP server.
  self.mock_client.get_pipeline_id.return_value = None
  with self.assertRaises(SystemExit) as err:
    run_handler.create_run()
  self.assertIn(f'Cannot find pipeline "{self.pipeline_name}".',
                str(err.exception))
  self.mock_client.run_pipeline.assert_not_called()
def testCheckPipelinePackagePathDefaultPath(self):
  """With no --package_path, the default <name>.tar.gz in cwd is used."""
  flags = {**self.flags_with_dsl_path, labels.PIPELINE_PACKAGE_PATH: None}
  handler = kubeflow_handler.KubeflowHandler(flags)
  pipeline_args = handler._extract_pipeline_args()
  pipeline_name = pipeline_args[labels.PIPELINE_NAME]
  handler._check_pipeline_package_path(pipeline_name)
  # The handler should have filled in the default package location.
  expected = os.path.abspath('{}.tar.gz'.format(pipeline_name))
  self.assertEqual(handler.flags_dict[labels.PIPELINE_PACKAGE_PATH], expected)
def testCompilePipeline(self):
  """Compiling a valid DSL reports success and the package path."""
  compile_flags = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_DSL_PATH: self.pipeline_path,
      labels.PIPELINE_PACKAGE_PATH: self.pipeline_package_path
  }
  compiler = kubeflow_handler.KubeflowHandler(compile_flags)
  with self.captureWritesToStream(sys.stdout) as captured:
    compiler.compile_pipeline()
  output = captured.contents()
  self.assertIn('Pipeline compiled successfully', output)
  self.assertIn('Pipeline package path', output)
def testDeletePipelineNonExistentPipeline(self):
  """delete_pipeline exits, deleting nothing, for an unknown pipeline."""
  deleter = kubeflow_handler.KubeflowHandler(self.flags_with_name)
  # Simulate a pipeline that does not exist on the KFP server.
  self.mock_client.get_pipeline_id.return_value = None
  with self.assertRaises(SystemExit) as err:
    deleter.delete_pipeline()
  self.assertIn(f'Cannot find pipeline "{self.pipeline_name}".',
                str(err.exception))
  # Neither the pipeline nor its experiment may be touched.
  self.mock_client.delete_pipeline.assert_not_called()
  self.mock_client._experiment_api.delete_experiment.assert_not_called()
def testCreateRun(self):
  """create_run starts a run and echoes a confirmation message."""
  # Make the client return a successful run so the handler can print it.
  self.mock_client.run_pipeline.return_value = _MockRunResponse(
      self.pipeline_name, '1', 'Success', datetime.datetime.now())
  run_handler = kubeflow_handler.KubeflowHandler(self.flags_with_name)
  with self.captureWritesToStream(sys.stdout) as captured:
    run_handler.create_run()
  self.assertIn('Run created for pipeline: ', captured.contents())
  self.mock_client.run_pipeline.assert_called_once_with(
      experiment_id=self.experiment_id,
      job_name=self.pipeline_name,
      version_id=self.pipeline_version_id)
def testUpdatePipelineNoPipeline(self):
  """update_pipeline exits, uploading nothing, for an unknown pipeline."""
  updater = kubeflow_handler.KubeflowHandler(self.flags_with_dsl_path)
  # Simulate a pipeline that does not exist on the KFP server.
  self.mock_client.get_pipeline_id.return_value = None
  with self.assertRaises(SystemExit) as err:
    updater.update_pipeline()
  self.assertIn(f'Cannot find pipeline "{self.pipeline_name}".',
                str(err.exception))
  self.mock_client.upload_pipeline.assert_not_called()
  self.mock_client.upload_pipeline_version.assert_not_called()
def testUpdatePipeline(self):
  """update_pipeline uploads a new version, not a new pipeline."""
  updater = kubeflow_handler.KubeflowHandler(self.flags_with_dsl_path)
  # Update test_pipeline and run update_pipeline
  updater.update_pipeline()
  # No new pipeline or experiment is created on update ...
  self.mock_client.upload_pipeline.assert_not_called()
  self.mock_client.create_experiment.assert_not_called()
  # ... only a new version attached to the existing pipeline id.
  self.mock_client.upload_pipeline_version.assert_called_once_with(
      pipeline_package_path=mock.ANY,
      pipeline_version_name=mock.ANY,
      pipeline_id=self.pipeline_id)
def testCreatePipeline(self):
  """create_pipeline uploads the package and creates an experiment."""
  creator = kubeflow_handler.KubeflowHandler(self.flags_with_dsl_path)
  # No pre-existing pipeline on the server; upload yields a fresh id.
  self.mock_client.get_pipeline_id.return_value = None
  self.mock_client.upload_pipeline.return_value.id = 'new_pipeline_id'
  creator.create_pipeline()
  self.mock_client.upload_pipeline.assert_called_once_with(
      pipeline_package_path=mock.ANY, pipeline_name=self.pipeline_name)
  self.mock_client.create_experiment.assert_called_once_with(
      self.pipeline_name)
  # A brand-new pipeline must not go through the version-upload path.
  self.mock_client.upload_pipeline_version.assert_not_called()
def test_compile_pipeline(self):
  """Compiling a valid DSL reports success and the package path."""
  compile_flags = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_DSL_PATH: self.pipeline_path,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
      labels.PIPELINE_PACKAGE_PATH: self.pipeline_package_path
  }
  compiler = kubeflow_handler.KubeflowHandler(compile_flags)
  with self.captureWritesToStream(sys.stdout) as captured:
    compiler.compile_pipeline()
  output = captured.contents()
  self.assertIn('Pipeline compiled successfully', output)
  self.assertIn('Pipeline package path', output)
def test_save_pipeline(self):
  """_save_pipeline writes pipeline_args.json under the handler dir."""
  flags_dict = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_DSL_PATH: self.pipeline_path,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
      labels.PIPELINE_PACKAGE_PATH: self.pipeline_package_path
  }
  handler = kubeflow_handler.KubeflowHandler(flags_dict)
  handler._save_pipeline(self.pipeline_args)
  handler_pipeline_path = handler._get_handler_pipeline_path(
      self.pipeline_args[labels.PIPELINE_NAME])
  # BUG FIX: the original asserted on the result of os.path.join(), which
  # is a non-empty string and therefore always truthy — the test could
  # never fail. Assert that the file actually exists instead.
  self.assertTrue(
      tf.io.gfile.exists(
          os.path.join(handler_pipeline_path, 'pipeline_args.json')))
def testListRunsNoPipeline(self):
  """list_runs exits with an error when no such pipeline was created."""
  list_flags = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_NAME: self.pipeline_name,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
  }
  lister = kubeflow_handler.KubeflowHandler(list_flags)
  with self.assertRaises(SystemExit) as err:
    lister.list_runs()
  self.assertEqual(
      str(err.exception),
      'Pipeline "{}" does not exist.'.format(list_flags[labels.PIPELINE_NAME]))
def testCheckPipelinePackagePathWrongPath(self):
  """A non-existent --package_path makes _check_pipeline_package_path exit."""
  # Point the package path at a file that was never built.
  flags = dict(self.flags_with_dsl_path)
  flags[labels.PIPELINE_PACKAGE_PATH] = os.path.join(
      self.chicago_taxi_pipeline_dir, '{}.tar.gz'.format(self.pipeline_name))
  handler = kubeflow_handler.KubeflowHandler(flags)
  pipeline_args = handler._extract_pipeline_args()
  with self.assertRaises(SystemExit) as err:
    handler._check_pipeline_package_path(pipeline_args[labels.PIPELINE_NAME])
  self.assertEqual(
      str(err.exception),
      'Pipeline package not found at {}. When --package_path is unset, it will try to find the workflow file, "<pipeline_name>.tar.gz" in the current directory.'
      .format(flags[labels.PIPELINE_PACKAGE_PATH]))
def test_create_pipeline(self):
  """create_pipeline materializes the handler directory for the pipeline."""
  create_flags = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_DSL_PATH: self.pipeline_path,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
      labels.PIPELINE_PACKAGE_PATH: self.pipeline_package_path
  }
  creator = kubeflow_handler.KubeflowHandler(create_flags)
  pipeline_dir = creator._get_handler_pipeline_path(
      self.pipeline_args[labels.PIPELINE_NAME])
  # The directory must not exist before creation and must exist after.
  self.assertFalse(tf.io.gfile.exists(pipeline_dir))
  creator.create_pipeline()
  self.assertTrue(tf.io.gfile.exists(pipeline_dir))
def testListRuns(self):
  """list_runs queries the pipeline's experiment and prints the runs."""
  lister = kubeflow_handler.KubeflowHandler(self.flags_with_name)
  # Two runs with different outcomes should both be listed.
  now = datetime.datetime.now()
  self.mock_client.list_runs.return_value.runs = [
      _MockRunResponse(self.pipeline_name, '1', 'Success', now),
      _MockRunResponse(self.pipeline_name, '2', 'Failed', now),
  ]
  with self.captureWritesToStream(sys.stdout) as captured:
    lister.list_runs()
  self.mock_client.list_runs.assert_called_once_with(
      experiment_id=self.experiment_id)
  self.assertIn('pipeline_name', captured.contents())
def testCheckPipelinePackagePathDefaultPath(self):
  """With no --package_path, the default <name>.tar.gz in cwd is used."""
  flags = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_DSL_PATH: self.pipeline_path,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
      labels.PIPELINE_PACKAGE_PATH: None
  }
  handler = kubeflow_handler.KubeflowHandler(flags)
  pipeline_args = handler._extract_pipeline_args()
  pipeline_name = pipeline_args[labels.PIPELINE_NAME]
  handler._check_pipeline_package_path(pipeline_name)
  # The handler should have filled in the default package location.
  expected = os.path.join(os.getcwd(), '{}.tar.gz'.format(pipeline_name))
  self.assertEqual(handler.flags_dict[labels.PIPELINE_PACKAGE_PATH], expected)
def testUpdatePipelineNoPipeline(self):
  """update_pipeline exits when the pipeline was never created."""
  # Update pipeline without creating one.
  update_flags = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_DSL_PATH: self.pipeline_path,
      labels.ENDPOINT: self.endpoint,
      labels.IAP_CLIENT_ID: self.iap_client_id,
      labels.NAMESPACE: self.namespace,
      labels.PIPELINE_PACKAGE_PATH: self.pipeline_package_path
  }
  updater = kubeflow_handler.KubeflowHandler(update_flags)
  with self.assertRaises(SystemExit) as err:
    updater.update_pipeline()
  self.assertEqual(
      str(err.exception), 'Pipeline "{}" does not exist.'.format(
          self.pipeline_args[labels.PIPELINE_NAME]))
def testGetSchema(self, mock_extract_pipeline_args):
  """get_schema error paths plus the successful schema-fetch path."""
  temp_pipeline_root = os.path.join(self.tmp_dir, 'pipeline_root')
  # Stub the pipeline args so get_schema looks in a temp pipeline root.
  mock_extract_pipeline_args.return_value = {
      labels.PIPELINE_NAME: self.pipeline_name,
      labels.PIPELINE_ROOT: temp_pipeline_root
  }
  flags_dict = {
      labels.ENGINE_FLAG: self.engine,
      labels.PIPELINE_NAME: self.pipeline_name,
  }
  handler = kubeflow_handler.KubeflowHandler(flags_dict)
  # No pipeline root
  with self.assertRaises(SystemExit) as err:
    handler.get_schema()
  self.assertEqual(
      str(err.exception),
      'Create a run before inferring schema. If pipeline is already running, then wait for it to successfully finish.'
  )
  # No SchemaGen output.
  fileio.makedirs(temp_pipeline_root)
  with self.assertRaises(SystemExit) as err:
    handler.get_schema()
  self.assertEqual(
      str(err.exception),
      'Either SchemaGen component does not exist or pipeline is still running. If pipeline is running, then wait for it to successfully finish.'
  )
  # Successful pipeline run.
  # Create fake schema in pipeline root.
  component_output_dir = os.path.join(temp_pipeline_root, 'SchemaGen')
  schema_path = base_driver._generate_output_uri(  # pylint: disable=protected-access
      component_output_dir, 'schema', 3)
  fileio.makedirs(schema_path)
  with open(os.path.join(schema_path, 'schema.pbtxt'), 'w') as f:
    f.write('SCHEMA')
  with self.captureWritesToStream(sys.stdout) as captured:
    handler.get_schema()
    # get_schema copies schema.pbtxt into the current working directory.
    curr_dir_path = os.path.join(os.getcwd(), 'schema.pbtxt')
    self.assertIn('Path to schema: {}'.format(curr_dir_path),
                  captured.contents())
    self.assertIn(
        '*********SCHEMA FOR {}**********'.format(
            self.pipeline_name.upper()), captured.contents())
    self.assertTrue(fileio.exists(curr_dir_path))