def _verify_output(self):
    """Checks the best-hyperparameters file exists and holds expected values.

    Reads `best_hyperparameters.txt` under the `self._best_hparams` artifact
    URI, rebuilds a `HyperParameters` object from its JSON config, and asserts
    on the `learning_rate` and `num_layers` entries.
    """
    hparams_file = os.path.join(self._best_hparams.uri,
                                'best_hyperparameters.txt')
    self.assertTrue(fileio.exists(hparams_file))

    raw_config = file_io.read_file_to_string(hparams_file)
    hparams = HyperParameters.from_config(json.loads(raw_config))

    # learning_rate must be one of the two listed values; num_layers must lie
    # in the inclusive range [1, 5].
    learning_rate = hparams.get('learning_rate')
    self.assertIn(learning_rate, (1e-1, 1e-3))
    num_layers = hparams.get('num_layers')
    self.assertBetween(num_layers, 1, 5)
def testSavePipeline(self):
    """`_save_pipeline` creates a per-pipeline path under the handler home."""
    handler = local_handler.LocalHandler({
        labels.ENGINE_FLAG: self.engine,
        labels.PIPELINE_DSL_PATH: self.pipeline_path,
    })
    handler._save_pipeline({labels.PIPELINE_NAME: self.pipeline_name})

    saved_path = os.path.join(handler._handler_home_dir, self.pipeline_name)
    self.assertTrue(fileio.exists(saved_path))
def _assert_infra_validator_passed(self, pipeline_name: Text):
    """Asserts every InfraBlessing artifact of the pipeline is blessed.

    Args:
      pipeline_name: Name of the pipeline whose `InfraBlessing` artifacts are
        looked up.
    """
    blessings = self._get_artifacts_with_type_and_pipeline(
        type_name='InfraBlessing', pipeline_name=pipeline_name)
    # At least one blessing artifact must have been produced.
    self.assertGreaterEqual(len(blessings), 1)

    for artifact in blessings:
        blessed = os.path.join(artifact.uri, 'INFRA_BLESSED')
        failure_message = (
            'Expected InfraBlessing results cannot be found under path %s for '
            'artifact %s' % (blessed, artifact))
        self.assertTrue(fileio.exists(blessed), failure_message)
def testDoWithBlessedModel(self):
    """Runs the model validator executor and checks the blessing marker file."""
    # Execution properties handed to the executor.
    blessed_model_dir = os.path.join(self._source_data_dir, 'trainer/blessed')
    exec_properties = {
        'blessed_model': blessed_model_dir,
        'blessed_model_id': 123,
        'current_component_id': self.component_id,
    }

    # Run the executor under test.
    model_validator = executor.Executor(self._context)
    model_validator.Do(self._input_dict, self._output_dict, exec_properties)

    # The temp dir must exist and the blessing artifact must contain the
    # blessed marker file.
    tmp_dir = os.path.join(self._tmp_dir)
    self.assertTrue(fileio.exists(tmp_dir))
    blessed_marker = os.path.join(self._blessing.uri,
                                  constants.BLESSED_FILE_NAME)
    self.assertTrue(fileio.exists(blessed_marker))
def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
    """Rewrites a saved model with assets and checks the destination contents.

    Args:
      converter: Mock standing in for the converter factory; its return value
        is replaced with a `self.ConverterMock()` instance.
    """

    def write_marker(path, content):
        # Writes `content` (as bytes) to `path`.
        with fileio.open(path, 'wb') as out:
            out.write(six.ensure_binary(content))

    def first_line(path):
        # Returns the first line of `path` as text.
        with fileio.open(path, 'rb') as src:
            return six.ensure_text(src.readline())

    converter.return_value = self.ConverterMock()

    src_model_path = tempfile.mkdtemp()
    dst_model_path = tempfile.mkdtemp()

    # Populate the source model: saved model file, assets and extra assets.
    saved_model_path = os.path.join(
        src_model_path, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
    write_marker(saved_model_path, 'saved_model')

    assets_dir = os.path.join(src_model_path,
                              tf.saved_model.ASSETS_DIRECTORY)
    fileio.mkdir(assets_dir)
    write_marker(os.path.join(assets_dir, 'assets_file'), 'assets_file')

    assets_extra_dir = os.path.join(src_model_path, EXTRA_ASSETS_DIRECTORY)
    fileio.mkdir(assets_extra_dir)
    write_marker(os.path.join(assets_extra_dir, 'assets_extra_file'),
                 'assets_extra_file')

    src_model = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                          src_model_path)
    dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                          dst_model_path)

    tfrw = tflite_rewriter.TFLiteRewriter(
        name='myrw', filename='fname', enable_quantization=True)
    tfrw.perform_rewrite(src_model, dst_model)

    converter.assert_called_once_with(
        saved_model_path=mock.ANY, enable_quantization=True)

    # The rewritten model and both asset files must exist at the destination.
    expected_model = os.path.join(dst_model_path, 'fname')
    self.assertTrue(fileio.exists(expected_model))
    self.assertEqual(first_line(expected_model), 'model')

    expected_assets_file = os.path.join(
        dst_model_path, tf.saved_model.ASSETS_DIRECTORY, 'assets_file')
    self.assertEqual(first_line(expected_assets_file), 'assets_file')

    expected_assets_extra_file = os.path.join(
        dst_model_path, EXTRA_ASSETS_DIRECTORY, 'assets_extra_file')
    self.assertEqual(first_line(expected_assets_extra_file),
                     'assets_extra_file')
def testPipelineSavePipelineArgs(self):
    """Constructing a Pipeline with the export env var set writes that file."""
    os.environ['TFX_JSON_EXPORT_PIPELINE_ARGS_PATH'] = self._tmp_file

    fake_components = [
        _make_fake_component_instance('component_a', _OutputTypeA, {}, {}),
    ]
    pipeline.Pipeline(
        pipeline_name='a',
        pipeline_root='b',
        log_root='c',
        components=fake_components,
        metadata_connection_config=self._metadata_connection_config)

    self.assertTrue(fileio.exists(self._tmp_file))
def _valid_create_and_check(self, pipeline_path: Text,
                            pipeline_name: Text) -> None:
    """Invokes `pipeline create` for Kubeflow and checks the CLI output.

    Args:
      pipeline_path: Value passed as `--pipeline_path`.
      pipeline_name: Pipeline name expected in the success message.
    """
    cli_args = [
        'pipeline', 'create', '--engine', 'kubeflow', '--pipeline_path',
        pipeline_path, '--endpoint', self._endpoint
    ]
    result = self.runner.invoke(cli_group, cli_args)
    absl.logging.info('[CLI] %s', result.output)

    self.assertIn('Creating pipeline', result.output)
    self.assertTrue(fileio.exists(self._pipeline_package_path))
    success_message = 'Pipeline "{}" created successfully.'.format(
        pipeline_name)
    self.assertIn(success_message, result.output)
def _check_pipeline_package_path(self, pipeline_name: Text) -> None:
    """Resolves the pipeline package path flag and exits if it is missing.

    When the package path flag is falsy, it is set to
    `<cwd>/<pipeline_name>.tar.gz` before the existence check.

    Args:
      pipeline_name: Name of the pipeline, used to build the default package
        file name.
    """
    flags = self.flags_dict
    if not flags[labels.PIPELINE_PACKAGE_PATH]:
        # Fall back to the workflow file in the current directory.
        flags[labels.PIPELINE_PACKAGE_PATH] = os.path.join(
            os.getcwd(), '{}.tar.gz'.format(pipeline_name))

    package_path = flags[labels.PIPELINE_PACKAGE_PATH]
    if fileio.exists(package_path):
        return
    sys.exit(
        'Pipeline package not found at {}. When --package_path is unset, it will try to find the workflow file, "<pipeline_name>.tar.gz" in the current directory.'
        .format(package_path))
def list_pipelines(self) -> None:
    """Prints the entries of the handler home directory, one per line."""
    home_dir = self._handler_home_dir
    if not fileio.exists(home_dir):
        click.echo('No pipelines to display.')
        return

    # Each directory entry is treated as a pipeline name; the listing is
    # framed by separator lines.
    separator = '-' * 30
    pipeline_names = fileio.listdir(home_dir)
    click.echo(separator)
    click.echo('\n'.join(pipeline_names))
    click.echo(separator)
def assertInfraValidatorPassed(self) -> None:
    """Asserts every InfraValidator blessing output has `INFRA_BLESSED`."""
    infra_validator_path = os.path.join(self._pipeline_root, 'InfraValidator')
    blessing_path = os.path.join(self._pipeline_root, 'InfraValidator',
                                 'blessing')

    # Each entry under the blessing directory is used as an execution id.
    exec_ids = fileio.listdir(blessing_path)
    self.assertGreaterEqual(len(exec_ids), 1)

    for exec_id in exec_ids:
        blessing_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
            infra_validator_path, 'blessing', exec_id)
        blessed_marker = os.path.join(blessing_uri, 'INFRA_BLESSED')
        self.assertTrue(fileio.exists(blessed_marker))
def _verify_example_split(self, split_name):
    """Checks one output split exists and carries classification features.

    Args:
      split_name: Split suffix; the directory checked is `Split-<split_name>`
        under `self._output_examples_dir`.
    """
    split_dir_exists = fileio.exists(
        os.path.join(self._output_examples_dir, f'Split-{split_name}'))
    self.assertTrue(split_dir_exists)

    split_dir = os.path.join(self._output_examples_dir, f'Split-{split_name}')
    results = self._get_results(split_dir, executor._EXAMPLES_FILE_NAME,
                                tf.train.Example)
    self.assertTrue(results)

    # Only the first result is inspected for the expected feature keys.
    first_features = results[0].features.feature
    self.assertIn('classify_label', first_features)
    self.assertIn('classify_score', results[0].features.feature)
def _check_pipeline_existence(self,
                              pipeline_name: Text,
                              required: bool = True) -> None:
    """Exits unless the pipeline folder's existence matches `required`.

    If the pipeline is required but missing from the current handler home, the
    deprecated handler home is checked; a pipeline found there is moved to the
    new location and a warning is printed to stderr.

    Args:
      pipeline_name: Name of the pipeline.
      required: True if the pipeline must exist, False if it must not exist.
    """
    new_location = os.path.join(self._handler_home_dir, pipeline_name)
    found = fileio.exists(new_location)

    if not required:
        if found:
            sys.exit('Pipeline "{}" already exists.'.format(pipeline_name))
        return
    if found:
        return

    # Required but missing: look in the pre-0.25 directory and migrate
    # automatically when the pipeline is found there.
    legacy_location = os.path.join(self._get_deprecated_handler_home(),
                                   pipeline_name)
    if fileio.exists(legacy_location):
        fileio.makedirs(os.path.dirname(new_location))
        fileio.rename(legacy_location, new_location)
        home_env_var = self.flags_dict[labels.ENGINE_FLAG].upper() + '_HOME'
        warning_template = (
            '[WARNING] Pipeline "{pipeline_name}" was found in "{old_path}", '
            'but the location that TFX stores pipeline information was moved '
            'since TFX 0.25.0.\n'
            '[WARNING] Your files in "{old_path}" was automatically moved to '
            'the new location, "{new_path}".\n'
            '[WARNING] If you want to keep the files at the old location, set '
            '`{handler_home}` environment variable to "{old_handler_home}".')
        click.echo(
            warning_template.format(
                pipeline_name=pipeline_name,
                old_path=legacy_location,
                new_path=new_location,
                handler_home=home_env_var,
                old_handler_home=self._get_deprecated_handler_home()),
            err=True)
    else:
        sys.exit('Pipeline "{}" does not exist.'.format(pipeline_name))
def _testDo(self, payload_format):
    """Runs the import example gen executor and checks its train/eval outputs.

    Args:
      payload_format: Value passed to the executor under
        `utils.OUTPUT_DATA_FORMAT_KEY`.
    """
    exec_properties = {
        utils.INPUT_BASE_KEY: self._input_data_dir,
        utils.INPUT_CONFIG_KEY: self._input_config,
        utils.OUTPUT_CONFIG_KEY: self._output_config,
        utils.OUTPUT_DATA_FORMAT_KEY: payload_format,
    }

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    self.examples = standard_artifacts.Examples()
    self.examples.uri = output_data_dir
    output_dict = {utils.EXAMPLES_KEY: [self.examples]}

    # Run executor.
    import_example_gen = executor.Executor()
    import_example_gen.Do({}, output_dict, exec_properties)

    self.assertEqual(
        artifact_utils.encode_split_names(['train', 'eval']),
        self.examples.split_names)

    # Check import_example_gen outputs.
    train_output_file = os.path.join(self.examples.uri, 'train',
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(self.examples.uri, 'eval',
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(fileio.exists(train_output_file))
    self.assertTrue(fileio.exists(eval_output_file))

    # Open both files under a context manager so the handles are closed even
    # when the size assertion fails.
    with fileio.open(train_output_file) as train_file, \
            fileio.open(eval_output_file) as eval_file:
        self.assertGreater(train_file.size(), eval_file.size())
def read(self):
    """Returns the decoded value stored at `self.uri`, loading it at most once.

    On the first successful call the file at `self.uri` is read in binary mode,
    passed to `self.decode`, and the result is cached in `self._value`. Later
    calls return the cached value without touching the file.

    Returns:
      The decoded value.

    Raises:
      RuntimeError: If nothing exists at `self.uri`.
    """
    if not self._has_value:
        file_path = self.uri
        # Fail early with a clear message if there is nothing to read.
        if not fileio.exists(file_path):
            raise RuntimeError(
                'Given path does not exist or is not a valid file: %s' %
                file_path)
        # Use a context manager so the handle is always closed.
        with fileio.open(file_path, 'rb') as f:
            serialized_value = f.read()
        # Mark the value as loaded only after decoding succeeds, so a failed
        # decode does not leave the object claiming to hold a value.
        self._value = self.decode(serialized_value)
        self._has_value = True
    return self._value
def _assertHyperparametersAreWritten(self, pipeline_name):
    """Asserts the Tuner produced exactly one best-hyperparameters output.

    Args:
      pipeline_name: Name of the pipeline whose root directory is inspected.
    """
    base_dir = os.path.join(
        self._pipeline_root(pipeline_name), 'Tuner', 'best_hyperparameters')

    # Exactly one entry is expected, i.e. a single Tuner execution.
    outputs = fileio.listdir(base_dir)
    self.assertEqual(1, len(outputs))

    # That single entry must resolve to an existing path.
    best_hyperparameters_uri = os.path.join(base_dir, outputs[0])
    self.assertTrue(fileio.exists(best_hyperparameters_uri))
def testCreatePipeline(self):
    """`create_pipeline` writes `pipeline_args.json` into the pipeline dir."""
    handler = kubeflow_v2_handler.KubeflowV2Handler({
        labels.ENGINE_FLAG: self.engine,
        labels.PIPELINE_DSL_PATH: self.pipeline_path,
    })
    handler.create_pipeline()

    # The trailing '' component makes the joined path end with a separator.
    handler_pipeline_path = os.path.join(
        handler._handler_home_dir, self.pipeline_args[labels.PIPELINE_NAME],
        '')
    args_file = os.path.join(handler_pipeline_path, 'pipeline_args.json')
    self.assertTrue(fileio.exists(args_file))
def testSavePipeline(self):
    """Saving extracted pipeline args creates the pipeline's handler path."""
    handler = airflow_handler.AirflowHandler({
        labels.ENGINE_FLAG: self.engine,
        labels.PIPELINE_DSL_PATH: self.pipeline_path,
    })
    extracted_args = handler._extract_pipeline_args()
    handler._save_pipeline(extracted_args)

    # The expected path is built from the fixture's pipeline args, not from
    # the args extracted above.
    saved_path = os.path.join(handler._handler_home_dir,
                              self.pipeline_args[labels.PIPELINE_NAME])
    self.assertTrue(fileio.exists(saved_path))
def testTaxiPipelineWarmstart(self):
    """Runs the warm-start taxi pipeline twice and checks execution counts."""

    def run_pipeline_once():
        # Builds a fresh pipeline from the fixture settings and runs it.
        BeamDagRunner().run(
            taxi_pipeline_warmstart._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

    run_pipeline_once()

    self.assertTrue(fileio.exists(self._serving_model_dir))
    self.assertTrue(fileio.exists(self._metadata_path))
    metadata_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(metadata_config) as m:
        artifact_count = len(m.store.get_artifacts())
        execution_count = len(m.store.get_executions())
        self.assertGreaterEqual(artifact_count, execution_count)
        self.assertEqual(10, execution_count)

    self.assertPipelineExecution()

    # Run pipeline again.
    run_pipeline_once()

    with metadata.Metadata(metadata_config) as m:
        # 10 more executions.
        self.assertEqual(20, len(m.store.get_executions()))

    # Two trainer outputs.
    self.assertExecutedTwice('Trainer')
def testDoValidation(self, exec_properties, blessed, has_baseline):
    """Runs the evaluator executor and checks its evaluation/blessing outputs.

    Args:
      exec_properties: Execution properties passed to the executor; the
        `EXAMPLE_SPLITS_KEY` entry is overwritten in place before the run.
      blessed: Whether a `BLESSED` (True) or `NOT_BLESSED` (False) marker file
        is expected in the blessing output.
      has_baseline: Whether a baseline model is added to the executor inputs.
    """
    source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create input dict.
    examples = standard_artifacts.Examples()
    examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
    examples.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])
    model = standard_artifacts.Model()
    baseline_model = standard_artifacts.Model()
    model.uri = os.path.join(source_data_dir, 'trainer/current')
    baseline_model.uri = os.path.join(source_data_dir, 'trainer/previous/')
    schema = standard_artifacts.Schema()
    schema.uri = os.path.join(source_data_dir, 'schema_gen')
    input_dict = {
        EXAMPLES_KEY: [examples],
        MODEL_KEY: [model],
        SCHEMA_KEY: [schema],
    }
    if has_baseline:
        input_dict[BASELINE_MODEL_KEY] = [baseline_model]

    # Create output dict. The blessing artifact is built once, here, because
    # it is an output of the executor.
    eval_output = standard_artifacts.ModelEvaluation()
    eval_output.uri = os.path.join(output_data_dir, 'eval_output')
    blessing_output = standard_artifacts.ModelBlessing()
    blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
    output_dict = {
        EVALUATION_KEY: [eval_output],
        BLESSING_KEY: [blessing_output],
    }

    # List needs to be serialized before being passed into Do function.
    exec_properties[EXAMPLE_SPLITS_KEY] = json_utils.dumps(None)

    # Run executor.
    evaluator = executor.Executor()
    evaluator.Do(input_dict, output_dict, exec_properties)

    # Check evaluator outputs.
    for output_name in ('eval_config.json', 'metrics', 'plots',
                        'validations'):
        self.assertTrue(
            fileio.exists(os.path.join(eval_output.uri, output_name)))

    # The blessing marker file name depends on the expected verdict.
    marker_name = 'BLESSED' if blessed else 'NOT_BLESSED'
    self.assertTrue(
        fileio.exists(os.path.join(blessing_output.uri, marker_name)))
def resolve(self, pipeline_root: Text):
    """Packages the component's user module file and rewires its properties.

    Args:
      pipeline_root: Directory passed to `package_user_module_file`; it is
        created first if it does not exist.

    Returns:
      The `dist_file_path` returned by `package_user_module_file`, or None when
      packaging is skipped (no module file, a non-string module file, or no
      pipeline root).

    Raises:
      ValueError: If the module file is a string path that does not exist.
    """
    exec_properties = self.component.spec.exec_properties
    module_file = exec_properties[self.module_file_key]

    # Validate the given `module_file`.
    if not module_file:
        return None
    if not isinstance(module_file, Text):
        # TODO(b/187753042): Deprecate and remove usage of RuntimeParameters
        # for `module_file` parameters and remove this code path.
        logging.warning(
            'Module file %r for component %s is not a path string; '
            'skipping Python user module wheel packaging.', module_file,
            self.component)
        return None
    if not fileio.exists(module_file):
        raise ValueError(
            'Specified module file %r for component %s does not exist.' %
            (module_file, self.component))

    # Validate the `pipeline_root`.
    if not pipeline_root:
        logging.warning(
            'No pipeline root provided; skipping Python user module '
            'wheel packaging for component %s.', self.component)
        return None
    if not fileio.exists(pipeline_root):
        fileio.makedirs(pipeline_root)

    # Package the user module.
    dist_file_path, user_module_path = package_user_module_file(
        self.component.id, module_file, pipeline_root)

    # Point the module path key at the packaged module and clear the module
    # file key before returning.
    exec_properties[self.module_path_key] = user_module_path
    exec_properties[self.module_file_key] = None
    return dist_file_path
def testGetSchema(self):
    """Exercises `get_schema` for missing root, missing SchemaGen and success."""
    creation_flags = {
        labels.ENGINE_FLAG: self.engine,
        labels.PIPELINE_DSL_PATH: self.pipeline_path,
        labels.ENDPOINT: self.endpoint,
        labels.IAP_CLIENT_ID: self.iap_client_id,
        labels.NAMESPACE: self.namespace,
        labels.PIPELINE_PACKAGE_PATH: self.pipeline_package_path
    }
    kubeflow_handler.KubeflowHandler(creation_flags).create_pipeline()

    schema_flags = {
        labels.ENGINE_FLAG: self.engine,
        labels.PIPELINE_NAME: self.pipeline_name,
    }
    handler = kubeflow_handler.KubeflowHandler(schema_flags)

    # No pipeline root.
    with self.assertRaises(SystemExit) as err:
        handler.get_schema()
    self.assertEqual(
        str(err.exception),
        'Create a run before inferring schema. If pipeline is already running, then wait for it to successfully finish.'
    )

    # No SchemaGen output.
    fileio.makedirs(self.pipeline_root)
    with self.assertRaises(SystemExit) as err:
        handler.get_schema()
    self.assertEqual(
        str(err.exception),
        'Either SchemaGen component does not exist or pipeline is still running. If pipeline is running, then wait for it to successfully finish.'
    )

    # Successful pipeline run: create a fake schema in the pipeline root.
    component_output_dir = os.path.join(self.pipeline_root, 'SchemaGen')
    schema_path = base_driver._generate_output_uri(  # pylint: disable=protected-access
        component_output_dir, 'schema', 3)
    fileio.makedirs(schema_path)
    with open(os.path.join(schema_path, 'schema.pbtxt'), 'w') as f:
        f.write('SCHEMA')

    with self.captureWritesToStream(sys.stdout) as captured:
        handler.get_schema()
        curr_dir_path = os.path.join(os.getcwd(), 'schema.pbtxt')
        path_line = 'Path to schema: {}'.format(curr_dir_path)
        self.assertIn(path_line, captured.contents())
        banner = '*********SCHEMA FOR {}**********'.format(
            self.pipeline_name.upper())
        self.assertIn(banner, captured.contents())
        self.assertTrue(fileio.exists(curr_dir_path))
def _update_execution_proto(
    self,
    execution: metadata_store_pb2.Execution,
    pipeline_info: Optional[data_types.PipelineInfo] = None,
    component_info: Optional[data_types.ComponentInfo] = None,
    state: Optional[Text] = None,
    exec_properties: Optional[Dict[Text, Any]] = None,
) -> metadata_store_pb2.Execution:
    """Fills in state, properties and pipeline/component info on `execution`.

    Args:
      execution: The execution proto to update in place.
      pipeline_info: Optional pipeline info; when given, pipeline name, root
        and (if set) run id are recorded as properties.
      component_info: Optional component info; when given, the component id is
        recorded as a property.
      state: Optional execution state string to record.
      exec_properties: Optional execution properties; each value is stored as
        a string property.

    Returns:
      The same `execution` proto that was passed in.
    """
    props = execution.properties

    if state is not None:
        props[_EXECUTION_TYPE_KEY_STATE].string_value = tf.compat.as_text(
            state)

    # Forward-compatible change to leverage built-in schema to track states.
    if state == EXECUTION_STATE_CACHED:
        execution.last_known_state = metadata_store_pb2.Execution.CACHED
    elif state == EXECUTION_STATE_COMPLETE:
        execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
    elif state == EXECUTION_STATE_NEW:
        execution.last_known_state = metadata_store_pb2.Execution.RUNNING

    exec_properties = exec_properties or {}
    # TODO(ruoyu): Enforce a formal rule for execution schema change.
    for key, value in exec_properties.items():
        # We always convert execution properties to unicode.
        props[key].string_value = tf.compat.as_text(
            tf.compat.as_str_any(value))

    # We also need to checksum UDF file to identify different binary being
    # used. Do we have a better way to checksum a file than hashlib.md5?
    # TODO(ruoyu): Find a better place / solution to the checksum logic.
    # TODO(ruoyu): SHA instead of MD5.
    has_module_file = ('module_file' in exec_properties and
                       exec_properties['module_file'] and
                       fileio.exists(exec_properties['module_file']))
    if has_module_file:
        contents = file_io.read_file_to_string(exec_properties['module_file'])
        digest = hashlib.md5(tf.compat.as_bytes(contents)).hexdigest()
        props['checksum_md5'].string_value = tf.compat.as_text(
            tf.compat.as_str_any(digest))

    if pipeline_info:
        props['pipeline_name'].string_value = pipeline_info.pipeline_name
        props['pipeline_root'].string_value = pipeline_info.pipeline_root
        if pipeline_info.run_id:
            props['run_id'].string_value = pipeline_info.run_id

    if component_info:
        props['component_id'].string_value = component_info.component_id

    return execution
def testCreatePipelineWithFlags(self):
    """Pipeline creation succeeds with an extra unknown flag in `sys.argv`."""
    handler = local_handler.LocalHandler({
        labels.ENGINE_FLAG: self.engine,
        labels.PIPELINE_DSL_PATH: self.pipeline_path,
    })

    # Pipeline creation should not be affected by additional flags.
    patched_argv = sys.argv + ['--unexpected_flag']
    with mock.patch.object(sys, 'argv', new=patched_argv):
        handler.create_pipeline()

    args_path = handler._get_pipeline_args_path(self.pipeline_name)
    self.assertTrue(fileio.exists(args_path))
def testIrisPipelineBeam(self):
    """Runs the iris pipeline on Beam and checks outputs and metadata counts."""
    iris_pipeline = iris_pipeline_beam._create_pipeline(
        pipeline_name=self._pipeline_name,
        data_root=self._data_root,
        module_file=self._module_file,
        serving_model_dir=self._serving_model_dir,
        pipeline_root=self._pipeline_root,
        metadata_path=self._metadata_path,
        beam_pipeline_args=[])
    BeamDagRunner().run(iris_pipeline)

    for produced_path in (self._serving_model_dir, self._metadata_path):
        self.assertTrue(fileio.exists(produced_path))

    metadata_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(metadata_config) as m:
        artifact_count = len(m.store.get_artifacts())
        execution_count = len(m.store.get_executions())
        self.assertGreaterEqual(artifact_count, execution_count)
        self.assertEqual(8, execution_count)  # 7 components + 1 resolver

    self.assertPipelineExecution()
def testTaxiPipelineBeam(self):
    """Runs the local taxi pipeline and checks outputs and metadata counts."""
    taxi_pipeline = taxi_pipeline_local._create_pipeline(
        pipeline_name=self._pipeline_name,
        data_root=self._data_root,
        module_file=self._module_file,
        serving_model_dir=self._serving_model_dir,
        pipeline_root=self._pipeline_root,
        metadata_path=self._metadata_path,
        beam_pipeline_args=[])
    LocalDagRunner().run(taxi_pipeline)

    for produced_path in (self._serving_model_dir, self._metadata_path):
        self.assertTrue(fileio.exists(produced_path))

    metadata_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(metadata_config) as m:
        artifact_count = len(m.store.get_artifacts())
        execution_count = len(m.store.get_executions())
        self.assertGreaterEqual(artifact_count, execution_count)
        self.assertEqual(10, execution_count)

    self.assertPipelineExecution()
def list_pipelines(self) -> None:
    """Prints the entries of the handler home directory, one per line."""
    # There is no managed storage for pipeline packages, so CLI consults
    # local dir to list pipelines.
    home_dir = self._handler_home_dir
    if not fileio.exists(home_dir):
        click.echo('No pipelines to display.')
        return

    # The listing is framed by separator lines.
    separator = '-' * 30
    pipeline_names = fileio.listdir(home_dir)
    click.echo(separator)
    click.echo('\n'.join(pipeline_names))
    click.echo(separator)
def test_do_with_empty_transform_splits(self):
    """With no transform splits, no split output exists but the model does."""
    splits_config = transform_pb2.SplitsConfig(analyze=['train'], transform=[])
    self._exec_properties['splits_config'] = proto_utils.proto_to_json(
        splits_config)
    self._exec_properties['module_file'] = self._module_file
    self._output_dict[executor.TRANSFORMED_EXAMPLES_KEY] = (
        self._transformed_example_artifacts[:1])

    self._transform_executor.Do(self._input_dict, self._output_dict,
                                self._exec_properties)

    # Neither split directory may exist under the transformed examples URI.
    for split in ('train', 'eval'):
        split_dir = os.path.join(self._transformed_example_artifacts[0].uri,
                                 split)
        self.assertFalse(fileio.exists(split_dir))

    # The transform function's saved model must still be present.
    path_to_saved_model = os.path.join(
        self._transformed_output.uri, tft.TFTransformOutput.TRANSFORM_FN_DIR,
        tf.saved_model.SAVED_MODEL_FILENAME_PB)
    self.assertTrue(fileio.exists(path_to_saved_model))
def testCreatePipeline(self):
    """`create_pipeline` places the DSL file in the pipeline info path."""
    handler = airflow_handler.AirflowHandler({
        labels.ENGINE_FLAG: self.engine,
        labels.PIPELINE_DSL_PATH: self.pipeline_path,
    })
    handler.create_pipeline()

    info_path = handler._get_pipeline_info_path(self.pipeline_name)
    dsl_copy = os.path.join(info_path, 'test_pipeline_airflow_1.py')
    self.assertTrue(fileio.exists(dsl_copy))
def testCreatePipeline(self):
    """`create_pipeline` writes the runner output file into the pipeline dir."""
    handler = kubeflow_v2_handler.KubeflowV2Handler({
        labels.ENGINE_FLAG: self.engine,
        labels.PIPELINE_DSL_PATH: self.pipeline_path,
    })
    handler.create_pipeline()

    pipeline_dir = os.path.join(handler._handler_home_dir, self.pipeline_name)
    output_file = os.path.join(pipeline_dir,
                               kubeflow_v2_dag_runner_patcher._OUTPUT_FILENAME)
    self.assertTrue(fileio.exists(output_file))
def testRunStatisticsGen(self):
    """Runs StatisticsGen via `run_component` and checks the written outputs."""
    # Prepare the paths.
    package_dir = os.path.dirname(os.path.dirname(__file__))
    test_data_dir = os.path.join(package_dir, 'components', 'testdata')
    output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                 tempfile.mkdtemp())
    output_data_dir = os.path.join(output_root, self._testMethodName)
    statistics_split_names_path = os.path.join(
        output_data_dir, 'statistics.properties', 'split_names')
    fileio.makedirs(output_data_dir)

    # Run StatisticsGen.
    run_component.run_component(
        full_component_class_name='tfx.components.StatisticsGen',
        examples_uri=os.path.join(test_data_dir, 'csv_example_gen'),
        examples_split_names=artifact_utils.encode_split_names(
            ['train', 'eval']),
        # Testing that we can set non-string artifact properties
        examples_version='1',
        statistics_path=output_data_dir,
        statistics_split_names_path=statistics_split_names_path,
    )

    # Check the statistics_gen outputs: one stats file per split.
    for split_dir in ('Split-train', 'Split-eval'):
        stats_file = os.path.join(output_data_dir, split_dir,
                                  'FeatureStats.pb')
        self.assertTrue(fileio.exists(stats_file))

    # The split names property must have been written out as well.
    self.assertTrue(os.path.exists(statistics_split_names_path))
    written_split_names = pathlib.Path(
        statistics_split_names_path).read_text()
    self.assertEqual(written_split_names, '["train", "eval"]')