def run_fn(fn_args: executor.TrainerFnArgs):
  """Trains an Estimator, then exports an eval SavedModel for TFMA.

  Args:
    fn_args: Holds args used to train the model as name/value pairs. The
      fields read here are `schema_file`, `serving_model_dir`,
      `eval_model_dir` and `model_run_dir`.
  """
  parsed_schema = io_utils.parse_pbtxt_file(fn_args.schema_file,
                                            schema_pb2.Schema())
  spec = trainer_fn(fn_args, parsed_schema)
  estimator = spec['estimator']

  absl.logging.info('Training model.')
  tf.estimator.train_and_evaluate(estimator, spec['train_spec'],
                                  spec['eval_spec'])
  absl.logging.info('Training complete. Model written to %s',
                    fn_args.serving_model_dir)

  # NOTE: In a distributed training cluster, only the chief worker should
  # export the eval SavedModel.
  absl.logging.info('Exporting eval_savedmodel for TFMA.')
  tfma.export.export_eval_savedmodel(
      estimator=estimator,
      export_dir_base=fn_args.eval_model_dir,
      eval_input_receiver_fn=spec['eval_input_receiver_fn'])

  # Write an empty placeholder under model_run_dir to simulate run logs.
  fake_log_path = os.path.join(fn_args.model_run_dir, 'fake_log.txt')
  io_utils.write_string_file(fake_log_path, '')

  absl.logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir)
def setUp(self):
  # Outputs go under TEST_UNDECLARED_OUTPUTS_DIR when the environment sets it,
  # otherwise under a temp dir; each test method gets its own subdirectory.
  outputs_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                self.get_temp_dir())
  self._output_data_dir = os.path.join(outputs_root, self._testMethodName)
  self._job_dir = os.path.join(self._output_data_dir, 'jobDir')
  self._fake_package = os.path.join(self._output_data_dir, 'fake_package')
  self._project_id = '12345'
  io_utils.write_string_file(self._fake_package, 'fake package content')

  self._mock_api_client = mock.Mock()
  self._inputs = {}
  self._outputs = {}
  self._training_inputs = dict(project=self._project_id, jobDir=self._job_dir)
  self._exec_properties = {
      'custom_config': dict(gaip_training_args=self._training_inputs),
  }
  self._cmle_serving_args = dict(
      model_name='model_name', project_id=self._project_id)
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """Runs the user-supplied `tuner_fn` search and saves the best hparams.

  Args:
    input_dict: Input artifacts, forwarded to
      `fn_args_utils.get_common_fn_args`.
    output_dict: Output artifacts. The single artifact under
      `_BEST_HYPERPARAMETERS_KEY` receives the best hyperparameter config as
      JSON in `_DEFAULT_FILE_NAME`.
    exec_properties: Execution properties. Must provide `tuner_fn` and must
      not set `_TUNE_ARGS_KEY`.

  Raises:
    ValueError: if tune args are present in `exec_properties`.
  """
  if exec_properties.get(_TUNE_ARGS_KEY):
    raise ValueError(
        "TuneArgs is not supported for default Tuner's Executor.")

  tuner_fn = udf_utils.get_fn(exec_properties, 'tuner_fn')
  fn_args = fn_args_utils.get_common_fn_args(input_dict, exec_properties,
                                             self._get_tmp_dir())

  tuner_fn_result = tuner_fn(fn_args)
  tuner = tuner_fn_result.tuner
  fit_kwargs = tuner_fn_result.fit_kwargs

  # TODO(b/156966497): set logger for printing.
  tuner.search_space_summary()
  absl.logging.info('Start tuning...')
  tuner.search(**fit_kwargs)
  tuner.results_summary()

  best_hparams_config = tuner.get_best_hyperparameters()[0].get_config()
  # Values are passed as logging args so formatting is deferred to the logger
  # (only done if the record is emitted).
  absl.logging.info('Best hyperParameters: %s', best_hparams_config)
  best_hparams_path = os.path.join(
      artifact_utils.get_single_uri(output_dict[_BEST_HYPERPARAMETERS_KEY]),
      _DEFAULT_FILE_NAME)
  io_utils.write_string_file(best_hparams_path, json.dumps(best_hparams_config))
  absl.logging.info('Best Hyperparameters are written to %s.',
                    best_hparams_path)
def testDriverJsonContract(self):
  # Same scenario as testDriverWithoutSpan, but the driver is fed raw JSON and
  # its raw JSON output is compared, documenting the driver's JSON I/O
  # contract.
  for split_name, content, mtime in (('split1', 'testing', 1),
                                     ('split2', 'testing2', 3)):
    split_file = os.path.join(_TEST_INPUT_DIR, split_name, 'data')
    io_utils.write_string_file(split_file, content)
    os.utime(split_file, (0, mtime))

  driver.main([
      'driver.py', '--json_serialized_invocation_args',
      self._executor_invocation_from_file
  ])

  def _normalize(json_text):
    # Drop all whitespace so formatting differences do not affect comparison.
    return json.loads(re.sub(r'\s+', '', json_text))

  with open(_TEST_OUTPUT_METADATA_JSON) as output_meta_json:
    actual = _normalize(output_meta_json.read())
  self.assertDictEqual(actual, _normalize(self._expected_result_from_file))
def _mock_subprocess_call(cmd: Sequence[Optional[Text]],
                          env: Mapping[Text, Text]) -> int:
  """Mocks the subprocess call.

  Expects `cmd` to be `[interpreter, dsl_path]`. A DSL path ending in
  'test_pipeline_bad.py' exits the process with status 1. Paths ending in
  'test_pipeline_1.py' or 'test_pipeline_2.py' write a fake PipelineJob to
  'pipeline.json' and return 0. Any other path raises ValueError.
  """
  assert len(cmd) == 2, 'Unexpected number of commands: {}'.format(cmd)
  del env
  dsl_path = cmd[1]

  # Simulates a DSL script that fails.
  if dsl_path.endswith('test_pipeline_bad.py'):
    sys.exit(1)
  if not (dsl_path.endswith('test_pipeline_1.py') or
          dsl_path.endswith('test_pipeline_2.py')):
    raise ValueError('Unexpected dsl path: {}'.format(dsl_path))

  spec_pb = pipeline_pb2.PipelineSpec(
      pipeline_info=pipeline_pb2.PipelineInfo(name='chicago_taxi_kubeflow'))
  output_dir = os.path.join(os.environ['HOME'], 'tfx', 'pipelines',
                            'chicago_taxi_kubeflow')
  job_pb = pipeline_pb2.PipelineJob(
      runtime_config=pipeline_pb2.PipelineJob.RuntimeConfig(
          gcs_output_directory=output_dir))
  job_pb.pipeline_spec.update(json_format.MessageToDict(spec_pb))
  serialized_job = json_format.MessageToJson(message=job_pb, sort_keys=True)
  io_utils.write_string_file(
      file_name='pipeline.json', string_value=serialized_job)
  return 0
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """Runs a KerasTuner search and writes the best hyperparameters as JSON.

  Args:
    input_dict: Input artifacts. Reads the 'train' and 'eval' splits of
      'examples' and the single schema file under 'schema'.
    output_dict: Output artifacts. The best trial's hyperparameter config is
      written as JSON to `_DEFAULT_FILE_NAME` inside the single
      'study_best_hparams_path' artifact.
    exec_properties: Execution properties used to locate the tuner function.
  """
  # KerasTuner generates tuning state (e.g., oracle, trials) to working dir.
  working_dir = self._get_tmp_dir()

  train_path = artifact_utils.get_split_uri(input_dict['examples'], 'train')
  eval_path = artifact_utils.get_split_uri(input_dict['examples'], 'eval')
  schema_file = io_utils.get_only_uri_in_dir(
      artifact_utils.get_single_uri(input_dict['schema']))
  schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())

  tuner_fn = self._GetTunerFn(exec_properties)
  tuner_spec = tuner_fn(working_dir, io_utils.all_files_pattern(train_path),
                        io_utils.all_files_pattern(eval_path), schema)
  tuner = tuner_spec.tuner

  tuner.search_space_summary()
  # TODO(jyzhao): assert v2 behavior as KerasTuner doesn't work in v1.
  # TODO(jyzhao): make epochs configurable.
  tuner.search(
      tuner_spec.train_dataset,
      epochs=5,
      validation_data=tuner_spec.eval_dataset)
  tuner.results_summary()

  best_hparams = tuner.oracle.get_best_trials(
      1)[0].hyperparameters.get_config()
  best_hparams_path = os.path.join(
      artifact_utils.get_single_uri(output_dict['study_best_hparams_path']),
      _DEFAULT_FILE_NAME)
  io_utils.write_string_file(best_hparams_path, json.dumps(best_hparams))
  # Pass the path as a logging argument so formatting is deferred.
  absl.logging.info('Best HParams is written to %s.', best_hparams_path)
def Do(self, input_dict: Dict[Text, List[types.TfxType]],
       output_dict: Dict[Text, List[types.TfxType]],
       exec_properties: Dict[Text, Any]) -> None:
  """Combines the Model Validator result with a human review on Slack.

  The model counts as blessed only if Model Validator wrote a 'BLESSED' file
  AND `_fetch_slack_blessing()` reports a human approval, all within
  `timeout_sec`. A timeout yields "not blessed". The result is recorded as an
  empty 'BLESSED' or 'NOT_BLESSED' marker file plus an int custom property
  'blessed' (1 or 0) on the output artifact.

  Args:
    input_dict: Input dict from input key to a list of artifacts, including:
      - model_export: exported model from trainer.
      - model_blessing: model blessing path from model_validator.
    output_dict: Output dict from key to a list of artifacts, including:
      - slack_blessing: model blessing result.
    exec_properties: A dict of execution properties, including:
      - slack_token: Token used to set up the connection with the Slack
        server.
      - channel_id: The id of the Slack channel to send and receive messages.
      - timeout_sec: How long to wait for a response, in seconds.

  Returns:
    None
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  slack_token = exec_properties['slack_token']
  channel_id = exec_properties['channel_id']
  timeout_sec = exec_properties['timeout_sec']

  model_export_uri = types.get_single_uri(input_dict['model_export'])
  model_blessing_uri = types.get_single_uri(input_dict['model_blessing'])
  slack_blessing = types.get_single_instance(output_dict['slack_blessing'])

  # The `and` short-circuits: Slack is only consulted when the validator has
  # already blessed the model.
  try:
    with Timeout(timeout_sec):
      validator_blessed = tf.gfile.Exists(
          os.path.join(model_blessing_uri, 'BLESSED'))
      blessed = validator_blessed and self._fetch_slack_blessing(
          slack_token, channel_id, model_export_uri)
  except TimeoutError:  # pylint: disable=undefined-variable
    tf.logging.info('Timeout fetching manual model evaluation result.')
    blessed = False

  marker_name, blessed_flag = ('BLESSED', 1) if blessed else ('NOT_BLESSED', 0)
  io_utils.write_string_file(os.path.join(slack_blessing.uri, marker_name), '')
  slack_blessing.set_int_custom_property('blessed', blessed_flag)
  tf.logging.info('Blessing result %s written to %s.', blessed,
                  slack_blessing.uri)
def testGetOnlyDirInDir(self):
  parent_dir = os.path.join(self._base_dir, 'dir_1')
  child_dir = os.path.join(parent_dir, 'dir_2')
  # The test relies on write_string_file to create dir_1/dir_2 for the file.
  io_utils.write_string_file(os.path.join(child_dir, 'file'), 'testing')
  only_uri = io_utils.get_only_uri_in_dir(parent_dir)
  self.assertEqual('dir_2', os.path.basename(only_uri))
def testStubExecutor(self, mock_publisher):
  # Verify that the base stub executor is substituted for the real one.
  mock_publisher.return_value.publish_execution.return_value = {}
  recorded_file = os.path.join(self.record_dir, 'output', 'recorded.txt')
  io_utils.write_string_file(recorded_file, 'hello world')

  stub_launcher_cls = stub_component_launcher.get_stub_launcher_class(
      test_data_dir=self.record_dir,
      stubbed_component_ids=['_FakeComponent.FakeComponent'],
      stubbed_component_map={})
  launcher = stub_launcher_cls.create(
      component=self.component,
      pipeline_info=self.pipeline_info,
      driver_args=self.driver_args,
      metadata_connection=self.metadata_connection,
      beam_pipeline_args=[],
      additional_pipeline_args={})
  launcher.launch()

  output_uri = self.component.outputs['output'].get()[0].uri
  copied_file = os.path.join(output_uri, 'recorded.txt')
  self.assertTrue(tf.io.gfile.exists(copied_file))
  self.assertEqual('hello world', io_utils.read_string_file(copied_file))
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """Stores `custom_config` as an `artifacts.PipelineConfiguration` artifact.

  The configuration string is written verbatim to `custom_config.json` inside
  the output artifact's URI.

  Args:
    input_dict: Empty
    output_dict: Output dict from key to a list of artifacts, including:
      - pipeline_configuration: A list of type
        `artifacts.PipelineConfiguration`
    exec_properties: A dict of execution properties, including:
      - custom_config: the configuration to save. If the key is absent or its
        value is None, an empty JSON object ("{}") is written instead.

  Returns:
    None

  Raises:
    OSError and its subclasses
    ValueError
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  pipeline_configuration = artifact_utils.get_single_instance(
      output_dict[PIPELINE_CONFIGURATION_KEY])
  custom_config = exec_properties.get(CUSTOM_CONFIG_KEY)
  # `.get(key, default)` only applies the default when the key is missing; a
  # key present with a None value (presumably an unset optional property)
  # would otherwise reach write_string_file as None.
  if custom_config is None:
    custom_config = "{}"

  output_dir = artifact_utils.get_single_uri([pipeline_configuration])
  output_file = os.path.join(output_dir, 'custom_config.json')
  io_utils.write_string_file(output_file, custom_config)
def testPipelineCompile(self):

  def _compile(pipeline_path):
    return self.runner.invoke(cli_group, [
        'pipeline', 'compile', '--engine', 'beam', '--pipeline_path',
        pipeline_path
    ])

  def _assert_compile_output(result, expected_message):
    self.assertIn('CLI', result.output)
    self.assertIn('Compiling pipeline', result.output)
    self.assertIn(expected_message, result.output)

  # Invalid DSL path.
  invalid_path = os.path.join(self._testdata_dir, 'test_pipeline_flink.py')
  _assert_compile_output(
      _compile(invalid_path), 'Invalid pipeline path: {}'.format(invalid_path))

  # Wrong Runner: an empty DSL file contains no BeamDagRunner.run() call.
  empty_path = os.path.join(self.tmp_dir, 'empty_file.py')
  io_utils.write_string_file(empty_path, '')
  _assert_compile_output(
      _compile(empty_path), 'Cannot find BeamDagRunner.run()')

  # Successful compilation.
  valid_path = os.path.join(self._testdata_dir, 'test_pipeline_beam_2.py')
  _assert_compile_output(
      _compile(valid_path), 'Pipeline compiled successfully')
def Do(self, input_dict: Dict[Text, List[types.TfxArtifact]],
       output_dict: Dict[Text, List[types.TfxArtifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """Appends a marker string to the input example and writes it as output.

  Reads the input example file (if it exists), appends
  " changed by an awesome custom executor!", and writes the result to
  `_DEFAULT_FILE_NAME` under the output artifact. The output artifact's `uri`
  is rewritten to point at that file, and two demo custom properties are set.

  Args:
    input_dict: Input dict from input key to a list of artifacts, including:
      - input_example: an example for an input
    output_dict: Output dict from key to a list of artifacts, including:
      - output_example: an example for an output
    exec_properties: A dict of execution properties, including:
      - string_execution_parameter: A string execution parameter (only used
        here, not persisted or shared with other components)
      - integer_execution_parameter: An integer execution parameter (only
        used here, not persisted or shared with other components)
      - input_config: not of concern here, only relevant for Driver
      - output_config: not of concern here, only relevant for Driver

  Returns:
    None
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  # Fetch execution properties from exec_properties dict.
  string_parameter = exec_properties['string_execution_parameter']
  integer_parameter = exec_properties['integer_execution_parameter']

  # Fetch input URIs from input_dict.
  input_example_uri = types.get_single_uri(input_dict['input_example'])

  # Fetch output artifact from output_dict.
  output_example = types.get_single_instance(output_dict['output_example'])

  print("I AM RUNNING!")
  print(string_parameter)
  print(integer_parameter)
  print(input_example_uri)
  print(output_example)

  # Load the input. Read through tf.gfile (like the Exists check) rather than
  # the builtin open(), so non-local URIs such as gs:// paths also work.
  input_data = ""
  if tf.gfile.Exists(input_example_uri):
    with tf.gfile.GFile(input_example_uri, "r") as input_file:
      input_data = input_file.read()

  # make some changes
  output_data = input_data + " changed by an awesome custom executor!"

  # Point the output artifact's uri at the concrete file so downstream
  # components know the filename.
  output_example.uri = os.path.join(output_example.uri, _DEFAULT_FILE_NAME)

  # write the changes back to your output
  io_utils.write_string_file(output_example.uri, output_data)

  # Optional: custom properties let downstream components run quick checks.
  output_example.set_string_custom_property('stringProperty', "Awesome")
  output_example.set_int_custom_property('intProperty', 42)
def testRangeConfigSpanWidthPresence(self):
  # Test RangeConfig.static_range behavior when span width is not given.
  span1_split1 = os.path.join(self._input_base_path, 'span01', 'split1',
                              'data')
  io_utils.write_string_file(span1_split1, 'testing11')

  range_config = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=1, end_span_number=1))
  splits1 = [
      example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN}/split1/*')
  ]
  # RangeConfig cannot find zero padding span without width modifier.
  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias, which
  # Python 3.12 removed.
  with self.assertRaisesRegex(ValueError, 'Cannot find matching for split'):
    utils.calculate_splits_fingerprint_span_and_version(
        self._input_base_path, splits1, range_config=range_config)

  splits2 = [
      example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN:2}/split1/*')
  ]
  # With width modifier in span spec, RangeConfig.static_range makes
  # correct zero-padded substitution.
  _, span, version = utils.calculate_splits_fingerprint_span_and_version(
      self._input_base_path, splits2, range_config=range_config)
  self.assertEqual(span, 1)
  self.assertIsNone(version)
  self.assertEqual(splits2[0].pattern, 'span01/split1/*')
def testStubExecutor(self, mock_publisher):
  # Verify that the base stub executor is substituted for the real one.
  mock_publisher.return_value.publish_execution.return_value = {}
  recorded_file = os.path.join(self.record_dir, self.component.id,
                               self.output_key, '0', 'recorded.txt')
  io_utils.write_string_file(recorded_file, 'hello world')

  launcher_cls = stub_component_launcher.StubComponentLauncher
  launcher_cls.initialize(test_data_dir=self.record_dir, test_component_ids=[])
  launcher = launcher_cls.create(
      component=self.component,
      pipeline_info=self.pipeline_info,
      driver_args=self.driver_args,
      metadata_connection=self.metadata_connection,
      beam_pipeline_args=[],
      additional_pipeline_args={})
  launcher.launch()

  output_uri = self.component.outputs[self.output_key].get()[0].uri
  copied_file = os.path.join(output_uri, 'recorded.txt')
  self.assertTrue(fileio.exists(copied_file))
  self.assertEqual('hello world', io_utils.read_string_file(copied_file))
def testExecutor(self, mock_publisher):
  # Verify that the original (non-stub) executor can still run.
  mock_publisher.return_value.publish_execution.return_value = {}
  io_utils.write_string_file(
      os.path.join(self.input_dir, 'result.txt'), 'test')

  launcher_cls = stub_component_launcher.StubComponentLauncher
  launcher_cls.initialize(
      test_data_dir=self.record_dir, test_component_ids=[self.component.id])
  launcher = launcher_cls.create(
      component=self.component,
      pipeline_info=self.pipeline_info,
      driver_args=self.driver_args,
      metadata_connection=self.metadata_connection,
      beam_pipeline_args=[],
      additional_pipeline_args={})

  fake_component_cls = test_utils._FakeComponent  # pylint: disable=protected-access
  expected_type = '{}.{}'.format(fake_component_cls.__module__,
                                 fake_component_cls.__name__)
  actual_type = launcher._component_info.component_type  # pylint: disable=protected-access
  self.assertEqual(actual_type, expected_type)

  launcher.launch()
  output_uri = self.component.outputs[self.output_key].get()[0].uri
  self.assertTrue(fileio.exists(output_uri))
  self.assertEqual('test', io_utils.read_string_file(output_uri))
def testVersionWidth(self):
  split1 = os.path.join(self._input_base_path, 'span1', 'ver1', 'split1',
                        'data')
  io_utils.write_string_file(split1, 'testing')

  # A width-2 version spec must not match the single-digit 'ver1' directory.
  splits = [
      example_gen_pb2.Input.Split(
          name='s1', pattern='span{SPAN}/ver{VERSION:2}/split1/*')
  ]
  # TODO(jjma): find a better way of describing this error to user.
  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias, which
  # Python 3.12 removed.
  with self.assertRaisesRegex(ValueError,
                              'Glob pattern does not match regex pattern'):
    utils.calculate_splits_fingerprint_span_and_version(
        self._input_base_path, splits)

  splits = [
      example_gen_pb2.Input.Split(
          name='s1', pattern='span{SPAN}/ver{VERSION:1}/split1/*')
  ]
  _, span, version = utils.calculate_splits_fingerprint_span_and_version(
      self._input_base_path, splits)
  self.assertEqual(span, 1)
  self.assertEqual(version, 1)
def Do(self, input_dict: Dict[str, List[types.Artifact]],
       output_dict: Dict[str, List[types.Artifact]],
       exec_properties: Dict[str, Any]) -> None:
  """Tunes a model, optionally preceded by a metalearning warmup stage.

  When `exec_properties['metalearning_algorithm']` is truthy, `self.warmup`
  runs first, and its trial count is subtracted from the final tuner's
  `max_trials` (never below 1). If the warmup produced trial data,
  `merge_trial_data` decides whether the warmup or the final tuner is reported
  as best; otherwise the final tuner is used. The best hyperparameters go
  through `tfx_tuner.write_best_hyperparameters`, and the trial progress data
  is written as JSON to 'tuner_plot_data.txt' in the 'trial_summary_plot'
  artifact.

  Raises:
    ValueError: if tune args are present in `exec_properties`.
  """
  if tfx_tuner.get_tune_args(exec_properties):
    raise ValueError(
        "TuneArgs is not supported by this Tuner's Executor.")

  metalearning_algorithm = exec_properties.get('metalearning_algorithm')

  warmup_trials = 0
  warmup_trial_data = None
  if not metalearning_algorithm:
    logging.info('MetaLearning Algorithm not provided.')
  else:
    warmup_tuner, warmup_trials = self.warmup(input_dict, exec_properties,
                                              metalearning_algorithm)
    warmup_trial_data = extract_tuner_trial_progress(warmup_tuner)

  # Build fresh fn_args for the final tuning stage.
  fn_args = fn_args_utils.get_common_fn_args(
      input_dict, exec_properties, working_dir=self._get_tmp_dir())
  tuner_fn = udf_utils.get_fn(exec_properties, 'tuner_fn')
  tuner_fn_result = tuner_fn(fn_args)
  final_oracle = tuner_fn_result.tuner.oracle
  # Warmup trials come out of the overall budget, keeping at least one trial.
  final_oracle.max_trials = max(final_oracle.max_trials - warmup_trials, 1)
  tuner = self.search(tuner_fn_result)
  tuner_trial_data = extract_tuner_trial_progress(tuner)

  if not warmup_trial_data:
    best_tuner = tuner
  else:
    merged_data, best_tuner_ix = merge_trial_data(warmup_trial_data,
                                                  tuner_trial_data)
    merged_data['warmup_trial_data'] = warmup_trial_data[BEST_CUMULATIVE_SCORE]
    merged_data['tuner_trial_data'] = tuner_trial_data[BEST_CUMULATIVE_SCORE]
    objective = tuner.oracle.objective
    if isinstance(objective, kerastuner.Objective):
      merged_data['objective'] = objective.name
    else:
      merged_data['objective'] = 'objective not understood'
    tuner_trial_data = merged_data
    best_tuner = warmup_tuner if best_tuner_ix == 0 else tuner

  tfx_tuner.write_best_hyperparameters(best_tuner, output_dict)
  tuner_plot_path = os.path.join(
      artifact_utils.get_single_uri(output_dict['trial_summary_plot']),
      'tuner_plot_data.txt')
  io_utils.write_string_file(tuner_plot_path, json.dumps(tuner_trial_data))
  logging.info('Tuner plot data written at: %s', tuner_plot_path)
def testSpanWrongFormat(self):
  # 'spanx' contains no numeric span, so span resolution must fail.
  wrong_span = os.path.join(self._input_base_path, 'spanx', 'split1', 'data')
  io_utils.write_string_file(wrong_span, 'testing_wrong_span')

  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias, which
  # Python 3.12 removed. NOTE(review): the expected text 'Cannot not find'
  # must match the driver's actual error message; confirm it still does.
  with self.assertRaisesRegex(ValueError, 'Cannot not find span number'):
    self._example_gen_driver.resolve_input_artifacts(
        self._input_channels, self._exec_properties, None, None)
def testResolveInputArtifacts(self):
  # Create input splits with distinct modification times.
  for split_name, content, mtime in (('split1', 'testing', 1),
                                     ('split2', 'testing2', 3)):
    split_file = os.path.join(self._input_base_path, split_name, 'data')
    io_utils.write_string_file(split_file, content)
    os.utime(split_file, (0, mtime))

  # Mock artifacts with ids 4, 3, 2, 1. Only odd ids carry the fingerprint
  # that matches the splits written above.
  matching_fingerprint = (
      'split:s1,num_files:1,total_bytes:7,xor_checksum:1,sum_checksum:1\n'
      'split:s2,num_files:1,total_bytes:8,xor_checksum:3,sum_checksum:3')
  artifacts = []
  for artifact_id in [4, 3, 2, 1]:
    artifact = metadata_store_pb2.Artifact()
    artifact.id = artifact_id
    artifact.uri = self._input_base_path
    artifact.custom_properties['span'].string_value = '0'
    artifact.custom_properties['input_fingerprint'].string_value = (
        matching_fingerprint if artifact_id % 2 == 1 else 'not_match')
    artifacts.append(artifact)

  # Create exec properties.
  exec_properties = {
      'input_config':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(name='s1', pattern='split1/*'),
                  example_gen_pb2.Input.Split(name='s2', pattern='split2/*')
              ]),
              preserving_proto_field_name=True),
  }

  def _resolve_single_input_base():
    resolved = self._example_gen_driver.resolve_input_artifacts(
        self._input_channels, exec_properties, None, None)
    self.assertEqual(1, len(resolved))
    self.assertEqual(1, len(resolved['input_base']))
    return resolved['input_base'][0]

  # Cache not hit.
  self._mock_metadata.get_artifacts_by_uri.return_value = [artifacts[0]]
  self._mock_metadata.publish_artifacts.return_value = [artifacts[3]]
  input_base = _resolve_single_input_base()
  self.assertEqual(1, input_base.id)
  self.assertEqual(self._input_base_path, input_base.uri)

  # Cache hit.
  self._mock_metadata.get_artifacts_by_uri.return_value = artifacts
  self._mock_metadata.publish_artifacts.return_value = []
  input_base = _resolve_single_input_base()
  self.assertEqual(3, input_base.id)
  self.assertEqual(self._input_base_path, input_base.uri)
def testCopyFile(self):
  file_path = os.path.join(self._base_dir, 'temp_file')
  io_utils.write_string_file(file_path, 'testing')
  copy_path = os.path.join(self._base_dir, 'copy_file')
  io_utils.copy_file(file_path, copy_path)
  self.assertTrue(file_io.file_exists(copy_path))
  # Read back the copy (not the source) so the test actually verifies the
  # copied content. The context manager closes the handle.
  with file_io.FileIO(copy_path, mode='r') as f:
    self.assertEqual('testing', f.read())
    self.assertEqual(7, f.tell())
def testCopyDir(self):
  old_path = os.path.join(self._base_dir, 'old', 'path')
  new_path = os.path.join(self._base_dir, 'new', 'path')
  io_utils.write_string_file(old_path, 'testing')
  io_utils.copy_dir(os.path.dirname(old_path), os.path.dirname(new_path))
  self.assertTrue(file_io.file_exists(new_path))
  # Use a context manager so the FileIO handle is closed.
  with file_io.FileIO(new_path, mode='r') as f:
    self.assertEqual('testing', f.read())
    self.assertEqual(7, f.tell())
def copy_and_change_pipeline_name(orig_path: str, new_path: str,
                                  origin_pipeline_name: str,
                                  new_pipeline_name: str) -> None:
  """Copies a pipeline DSL file to `new_path`, renaming its pipeline.

  Args:
    orig_path: Path of the DSL file to read.
    new_path: Path the modified copy is written to.
    origin_pipeline_name: Pipeline name to replace. It must occur exactly once
      in the file.
    new_pipeline_name: Replacement pipeline name.

  Raises:
    AssertionError: if `origin_pipeline_name` does not occur exactly once
      (only while assertions are enabled, i.e. not under `python -O`).
  """
  source_text = io_utils.read_string_file(orig_path)
  occurrences = source_text.count(origin_pipeline_name)
  assert occurrences == 1, 'DSL file can only contain one pipeline name'
  renamed_text = source_text.replace(origin_pipeline_name, new_pipeline_name)
  io_utils.write_string_file(new_path, renamed_text)
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """Writes 'custom component' to result.txt under every output artifact.

  super().Do() is deliberately not called, to skip its copying step.
  """
  absl.logging.info('Running CustomStubExecutor')
  for artifacts in output_dict.values():
    for output_artifact in artifacts:
      result_path = os.path.join(output_artifact.uri, 'result.txt')
      io_utils.write_string_file(result_path, 'custom component')
def _uncommentMultiLineVariables(self, filepath: Text,
                                 variables: Iterable[Text]) -> Text:
  """Update given file by uncommenting the given variable definitions.

  Each variable should be defined in the following form.

    # ....
    # VARIABLE_NAME = ...
    #     long indented line
    #
    #     long indented line
    # }
    # OTHER STUFF

  Above comments will become

    # ....
    VARIABLE_NAME = ...
        long indented line

        long indented line
    }
    # OTHER STUFF

  Arguments:
    filepath: file to modify, relative to the project directory.
    variables: List of variables.

  Returns:
    Path of the modified file (`filepath` joined onto the project directory).
  """
  path = os.path.join(self._project_dir, filepath)
  result = []
  commented_variables = ['# ' + variable + ' =' for variable in variables]
  in_variable_definition = False

  with open(path) as fp:
    for line in fp:
      if in_variable_definition:
        # Continuation lines are indented past '# ', or are a closing '# }'.
        # A plain '# ' prefix would also match an ordinary comment such as
        # '# OTHER STUFF' right after the definition and wrongly uncomment it.
        if line.startswith('#  ') or line.startswith('# }'):
          result.append(line[2:])
          continue
        elif line == '#\n':
          # Keep blank lines inside the definition as plain newlines.
          result.append(line[1:])
          continue
        else:
          in_variable_definition = False

      for commented_var in commented_variables:
        if line.startswith(commented_var):
          in_variable_definition = True
          result.append(line[2:])
          break
      else:
        # doesn't include a variable definition to uncomment.
        result.append(line)

  io_utils.write_string_file(path, ''.join(result))
  return path
def _replaceFileContent(self, filepath: Text,
                        replacements: Iterable[Tuple[Text, Text]]) -> Text:
  """Rewrites a project file, applying each (old, new) replacement in order.

  Args:
    filepath: File to modify, relative to the project directory.
    replacements: (old, new) substring pairs. Each is applied with
      `str.replace`, so every occurrence is replaced.

  Returns:
    Path of the rewritten file.
  """
  target_path = os.path.join(self._project_dir, filepath)
  with open(target_path) as source_file:
    text = source_file.read()
  for pattern, replacement in replacements:
    text = text.replace(pattern, replacement)
  io_utils.write_string_file(target_path, text)
  return target_path
def testSpanWrongFormat(self):
  # 'spanx' has no numeric span, so span extraction must fail.
  bad_span_file = os.path.join(self._input_base_path, 'spanx', 'split1',
                               'data')
  io_utils.write_string_file(bad_span_file, 'testing_wrong_span')
  span_split = example_gen_pb2.Input.Split(
      name='s1', pattern='span{SPAN}/split1/*')
  with self.assertRaisesRegex(ValueError, 'Cannot find span number'):
    utils.calculate_splits_fingerprint_span_and_version(
        self._input_base_path, [span_split])
def _addAllComponents(self):
  """Uncomments every `components.append(` call in 'pipeline.py'.

  Returns:
    Path of the rewritten pipeline definition file.
  """
  pipeline_file = os.path.join(self._project_dir, 'pipeline.py')
  with open(pipeline_file) as source:
    original_text = source.read()
  # In the initial template these calls are commented out.
  updated_text = original_text.replace('# components.append(',
                                       'components.append(')
  io_utils.write_string_file(pipeline_file, updated_text)
  return pipeline_file
def testVersionNoMatching(self):
  # The file sits under 'wrong' rather than 'version<N>/split1', so nothing
  # can match the split pattern below.
  data_file = os.path.join(self._input_base_path, 'span01', 'wrong', 'data')
  io_utils.write_string_file(data_file, 'testing_version_no_matching')
  version_split = example_gen_pb2.Input.Split(
      name='s1', pattern='span{SPAN}/version{VERSION}/split1/*')
  with self.assertRaisesRegex(ValueError, 'Cannot find matching for split'):
    utils.calculate_splits_fingerprint_span_and_version(
        self._input_base_path, [version_split])
def testNoSpanOrVersion(self):
  # With neither {SPAN} nor {VERSION} in the pattern, the span is expected to
  # be 0 and the version None.
  data_file = os.path.join(self._input_base_path, 'split1', 'data')
  io_utils.write_string_file(data_file, 'testing')
  plain_splits = [example_gen_pb2.Input.Split(name='s1', pattern='split1/*')]
  fingerprint_span_version = (
      utils.calculate_splits_fingerprint_span_and_version(
          self._input_base_path, plain_splits))
  _, span, version = fingerprint_span_version
  self.assertEqual(span, 0)
  self.assertIsNone(version)
def build_ephemeral_package() -> Text:
  """Repackage current installation of TFX into a tfx_ephemeral sdist.

  Copies the last 'tfx' directory found in this module's path into a fresh
  temp dir, writes a generated setup.py next to it, and runs
  `setup.py sdist` there. Build output goes to `setup.log` in that dir.

  Returns:
    Path to ephemeral sdist package.

  Raises:
    RuntimeError: if no 'tfx' directory is found in this module's path, or if
      the dist directory has zero or multiple files (including the case where
      a failed sdist never created it).
  """
  tmp_dir = os.path.join(tempfile.mkdtemp(), 'build', 'tfx')
  # Find the last directory named 'tfx' in this file's path and package it.
  path_split = __file__.split(os.path.sep)
  last_index = max(
      (i for i, part in enumerate(path_split) if part == 'tfx'), default=-1)
  if last_index < 0:
    raise RuntimeError('Cannot locate directory \'tfx\' in the path %s' %
                       __file__)
  tfx_root_dir = os.path.sep.join(path_split[0:last_index + 1])
  absl.logging.info('Copying all content from install dir %s to temp dir %s',
                    tfx_root_dir, tmp_dir)
  shutil.copytree(tfx_root_dir, os.path.join(tmp_dir, 'tfx'))
  # Source directory default permission is 0555 but we need to be able to
  # create new setup.py file.
  os.chmod(tmp_dir, 0o720)
  setup_file = os.path.join(tmp_dir, 'setup.py')
  absl.logging.info('Generating a temp setup file at %s', setup_file)
  install_requires = dependencies.make_required_install_packages()
  io_utils.write_string_file(
      setup_file,
      _ephemeral_setup_file.format(
          version=version.__version__, install_requires=install_requires))

  # Create the package. Run sdist with cwd=tmp_dir instead of os.chdir, so the
  # process-wide working directory is never changed (and so never left
  # changed if the subprocess call raises).
  temp_log = os.path.join(tmp_dir, 'setup.log')
  with open(temp_log, 'w') as f:
    absl.logging.info('Creating temporary sdist package, logs available at %s',
                      temp_log)
    cmd = [sys.executable, setup_file, 'sdist']
    return_code = subprocess.call(cmd, stdout=f, stderr=f, cwd=tmp_dir)
  if return_code != 0:
    absl.logging.warning('sdist exited with code %d; see %s for details.',
                         return_code, temp_log)

  # Return the package dir+filename. A failed sdist may not create dist_dir;
  # report that as "no package files" rather than a listdir NotFoundError.
  dist_dir = os.path.join(tmp_dir, 'dist')
  files = tf.io.gfile.listdir(dist_dir) if tf.io.gfile.exists(dist_dir) else []
  if not files:
    raise RuntimeError('Found no package files in %s' % dist_dir)
  elif len(files) > 1:
    raise RuntimeError('Found multiple package files in %s' % dist_dir)
  return os.path.join(dist_dir, files[0])
def testGetOnlyFileInDir(self):
  only_file = os.path.join(self._base_dir, 'file', 'path')
  io_utils.write_string_file(only_file, 'testing')
  containing_dir = os.path.dirname(only_file)
  # The single entry of the directory is the file just written.
  self.assertEqual(only_file, io_utils.get_only_uri_in_dir(containing_dir))
def testDeleteDir(self):
  target_file = os.path.join(self._base_dir, 'file', 'path')
  io_utils.write_string_file(target_file, 'testing')
  self.assertTrue(tf.gfile.Exists(target_file))
  # Deleting the parent directory must remove the file along with it.
  parent_dir = os.path.dirname(target_file)
  io_utils.delete_dir(parent_dir)
  self.assertFalse(tf.gfile.Exists(target_file))