def get_fn(exec_properties: Dict[Text, Any],
           fn_name: Text) -> Callable[..., Any]:
    """Loads and returns user-defined function.

    The source of the function is chosen in this order of precedence:

    1. `exec_properties[_MODULE_PATH_KEY]` is set: `fn_name` is imported from
       that module.
    2. `exec_properties[_MODULE_FILE_KEY]` is set: the function is imported
       from that source file. Its name is `exec_properties[fn_name]` when that
       entry is set, otherwise `fn_name` itself.
    3. `exec_properties[fn_name]` is set: the value is split on its last '.';
       the prefix is imported as a module and the suffix is looked up in it.

    Args:
      exec_properties: Execution properties to read the module path, module
        file and function entries from. Entries with falsy values are treated
        as absent.
      fn_name: Name of the function, and also the `exec_properties` key under
        which an overriding name or dotted path may be stored.

    Returns:
      The object returned by the `import_utils` loader that was selected.

    Raises:
      ValueError: If none of the module path, module file or `fn_name` entries
        is set in `exec_properties`.
    """
    # Diagnostic trace only: reaching this point is not a failure, so it is
    # logged at DEBUG rather than ERROR.
    logging.debug('udf_utils.get_fn %r %r', exec_properties, fn_name)

    if exec_properties.get(_MODULE_PATH_KEY):
        return import_utils.import_func_from_module(
            exec_properties[_MODULE_PATH_KEY], fn_name)

    user_fn = exec_properties.get(fn_name)
    module_file = exec_properties.get(_MODULE_FILE_KEY)
    if module_file:
        # An explicit entry under `fn_name` overrides the default name.
        return import_utils.import_func_from_source(
            module_file, user_fn if user_fn else fn_name)

    if user_fn:
        # 'pkg.module.fn' -> ('pkg.module', 'fn'); a value without any '.'
        # yields an empty module path, same as split/join would.
        module_path, _, attr_name = user_fn.rpartition('.')
        return import_utils.import_func_from_module(module_path, attr_name)

    raise ValueError(
        'Neither module file or user function have been supplied in `exec_properties`.'
    )
def testImportFuncFromSource(self):
    """Checks that loading one function twice from a file gives one object."""
    fn_source = os.path.join(
        os.path.dirname(__file__), 'testdata', 'test_fn.ext')

    first_load = import_utils.import_func_from_source(fn_source, 'test_fn')
    second_load = import_utils.import_func_from_source(fn_source, 'test_fn')

    # Both loads must resolve to the very same function object.
    self.assertIs(first_load, second_load)
    self.assertEqual(10, first_load([1, 2, 3, 4]))
def _GetPreprocessingFn(self, inputs: Mapping[Text, Any],
                        unused_outputs: Mapping[Text, Any]) -> Any:
    """Returns a user defined preprocessing_fn.

    Exactly one of MODULE_FILE and PREPROCESSING_FN must be supplied. With
    MODULE_FILE, the function named 'preprocessing_fn' is loaded from that
    source file. With PREPROCESSING_FN, the value is treated as a dotted path
    whose last component is the function name.

    Args:
      inputs: A dictionary of labelled input values.
      unused_outputs: A dictionary of labelled output values.

    Returns:
      User defined function.

    Raises:
      ValueError: When neither or both of MODULE_FILE and PREPROCESSING_FN
        are present in inputs.
    """
    module_file_given = bool(
        common.GetSoleValue(inputs, labels.MODULE_FILE, strict=False))
    fn_path_given = bool(
        common.GetSoleValue(inputs, labels.PREPROCESSING_FN, strict=False))

    # The two sources are mutually exclusive and one of them is required.
    if not (module_file_given ^ fn_path_given):
        raise ValueError(
            'Neither or both of MODULE_FILE and PREPROCESSING_FN have been '
            'supplied in inputs.')

    if not module_file_given:
        dotted_path = common.GetSoleValue(inputs, labels.PREPROCESSING_FN)
        path_parts = dotted_path.split('.')
        module_name = '.'.join(path_parts[:-1])
        return import_utils.import_func_from_module(module_name,
                                                    path_parts[-1])

    return import_utils.import_func_from_source(
        common.GetSoleValue(inputs, labels.MODULE_FILE), 'preprocessing_fn')
def testtestImportFuncFromModuleReload(self):
    """Checks that a rewritten source file is picked up on re-import/reload."""
    module_file = os.path.join(self.create_tempdir().full_path, 'fn.py')

    with tf.io.gfile.GFile(module_file, mode='w') as writer:
        writer.write("""def test_fn(inputs): return sum(inputs) """)

    # Captured before the first import; the test expects the imported module
    # to be found in sys.modules under 'user_module_<this value>'.
    registered_count = import_utils._tfx_module_finder.count_registered

    original_fn = import_utils.import_func_from_source(module_file, 'test_fn')
    self.assertEqual(10, original_fn([1, 2, 3, 4]))

    # Overwrite the same file with a different implementation.
    with tf.io.gfile.GFile(module_file, mode='w') as writer:
        writer.write("""def test_fn(inputs): return 1+sum(inputs) """)

    reimported_fn = import_utils.import_func_from_source(
        module_file, 'test_fn')
    self.assertEqual(11, reimported_fn([1, 2, 3, 4]))

    user_module = sys.modules['user_module_%d' % registered_count]
    reloaded_fn = getattr(importlib.reload(user_module), 'test_fn')
    self.assertEqual(11, reloaded_fn([1, 2, 3, 4]))
def _GetPreprocessingFn(self, inputs: Mapping[Text, Any],
                        unused_outputs: Mapping[Text, Any]) -> Any:
    """Returns a user defined preprocessing_fn.

    The sole PREPROCESSING_FN value in `inputs` is used as the source file
    from which the function named 'preprocessing_fn' is loaded.

    Args:
      inputs: A dictionary of labelled input values.
      unused_outputs: A dictionary of labelled output values.

    Returns:
      User defined function.
    """
    source_file = common.GetSoleValue(inputs, labels.PREPROCESSING_FN)
    return import_utils.import_func_from_source(source_file,
                                                'preprocessing_fn')
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
    """Loads the user's decoder factory, calls it and saves the decoder.

    Args:
      input_dict: Unused.
      output_dict: Output artifacts; the sole artifact under `_DATA_VIEW_KEY`
        supplies the URI the decoder is saved to.
      exec_properties: Execution properties. When the `_MODULE_FILE_KEY` entry
        is set, the function named by the `_CREATE_DECODER_FUNC_KEY` entry is
        loaded from that source file; otherwise resolution is delegated to
        `udf_utils.get_fn`.
    """
    del input_dict  # Unused.

    # Check the value rather than mere key presence: a `module_file` entry
    # that is present but empty/None must fall through to `udf_utils.get_fn`
    # instead of being handed to the source-file importer.
    module_file = exec_properties.get(_MODULE_FILE_KEY)
    if module_file:
        create_decoder_func = import_utils.import_func_from_source(
            module_file, exec_properties.get(_CREATE_DECODER_FUNC_KEY))
    else:
        create_decoder_func = udf_utils.get_fn(exec_properties,
                                               _CREATE_DECODER_FUNC_KEY)

    tf_graph_record_decoder.save_decoder(
        create_decoder_func(),
        value_utils.GetSoleValue(output_dict, _DATA_VIEW_KEY).uri)
def _GetTrainerFn(self, exec_properties: Dict[Text, Any]) -> Any: """Loads and returns user-defined trainer_fn.""" has_module_file = bool(exec_properties.get('module_file')) has_trainer_fn = bool(exec_properties.get('trainer_fn')) if has_module_file == has_trainer_fn: raise ValueError( "Neither or both of 'module_file' 'trainer_fn' have been supplied in " "'exec_properties'.") if has_module_file: return import_utils.import_func_from_source( exec_properties['module_file'], 'trainer_fn') trainer_fn_path_split = exec_properties['trainer_fn'].split('.') return import_utils.import_func_from_module( '.'.join(trainer_fn_path_split[0:-1]), trainer_fn_path_split[-1])
def _GetFn(self, exec_properties: Dict[Text, Any], fn_name: Text) -> Any: """Loads and returns user-defined function.""" has_module_file = bool(exec_properties.get('module_file')) has_fn = bool(exec_properties.get(fn_name)) if has_module_file == has_fn: raise ValueError( 'Neither or both of module file and user function have been supplied in ' "'exec_properties'.") if has_module_file: return import_utils.import_func_from_source( exec_properties['module_file'], fn_name) fn_path_split = exec_properties[fn_name].split('.') return import_utils.import_func_from_module('.'.join(fn_path_split[0:-1]), fn_path_split[-1])
def get_fn(exec_properties: Dict[Text, Any],
           fn_name: Text) -> Callable[..., Any]:
    """Loads and returns user-defined function.

    Exactly one of the `_MODULE_FILE_KEY` entry and the `fn_name` entry must
    be set. With the module file, `fn_name` is loaded from that source file.
    With the `fn_name` entry, its value is a dotted path whose last component
    is the function name.

    Args:
      exec_properties: Execution properties holding one of the two entries.
      fn_name: Function name, and the key of the dotted-path entry.

    Returns:
      The object returned by the selected `import_utils` loader.

    Raises:
      ValueError: If neither or both of the entries are set.
    """
    module_file = exec_properties.get(_MODULE_FILE_KEY)
    dotted_path = exec_properties.get(fn_name)

    # Exactly one of the two sources has to be provided.
    if bool(module_file) == bool(dotted_path):
        raise ValueError(
            'Neither or both of module file and user function have been supplied '
            "in 'exec_properties'.")

    if not module_file:
        # 'pkg.module.fn' -> module 'pkg.module', attribute 'fn'.
        module_name, _, attr_name = exec_properties[fn_name].rpartition('.')
        return import_utils.import_func_from_module(module_name, attr_name)

    return import_utils.import_func_from_source(
        exec_properties[_MODULE_FILE_KEY], fn_name)
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
    """Uses a user-supplied tf.estimator to train a TensorFlow model locally.

    The Trainer Executor invokes a trainer_fn callback function provided by
    the user via the module_file parameter. trainer_fn is called with the
    assembled hyperparameters and the parsed schema, and must return a mapping
    with the keys 'estimator', 'train_spec', 'eval_spec' and
    'eval_input_receiver_fn'. The estimator is trained and evaluated, and an
    eval saved model is then exported for TFMA.

    Args:
      input_dict: Input dict from input key to a list of ML-Metadata Artifacts.
        - examples: Examples used for training, must include 'train' and 'eval'
          splits.
        - transform_output: Optional input transform graph. May be missing or
          empty.
        - schema: Schema of the data.
      output_dict: Output dict from output key to a list of Artifacts.
        - output: Exported model.
      exec_properties: A dict of execution properties.
        - train_args: JSON string of trainer_pb2.TrainArgs instance, providing
          args for training.
        - eval_args: JSON string of trainer_pb2.EvalArgs instance, providing
          args for eval.
        - module_file: Python module file containing UDF model definition.
        - warm_starting: Whether or not we need to do warm starting.
        - warm_start_from: Optional. If warm_starting is True, this is the
          directory to find previous model to warm start on.
        - custom_config: Optional. If it carries 'cmle_training_args', the
          work is handed to `runner.start_cmle_training` (deprecated path).

    Returns:
      None. On the deprecated 'cmle_training_args' path, whatever
      `runner.start_cmle_training` returns is passed through instead.
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    # TODO(zhitaoli): Deprecate this in a future version.
    custom_config = exec_properties.get('custom_config', None)
    if custom_config:
        cmle_args = custom_config.get('cmle_training_args')
        if cmle_args:
            executor_class_path = '.'.join(
                [Executor.__module__, Executor.__name__])
            tf.logging.warn(
                'Passing \'cmle_training_args\' to trainer directly is deprecated, '
                'please use extension executor at '
                'tfx.extensions.google_cloud_ai_platform.trainer.executor instead'
            )
            return runner.start_cmle_training(input_dict, output_dict,
                                              exec_properties,
                                              executor_class_path, cmle_args)

    trainer_fn = import_utils.import_func_from_source(
        exec_properties['module_file'], 'trainer_fn')

    # Set up training parameters
    train_files = [
        _all_files_pattern(
            artifact_utils.get_split_uri(input_dict['examples'], 'train'))
    ]
    # The transform graph is optional: tolerate the key being absent as well
    # as it mapping to an empty list.
    transform_artifacts = input_dict.get('transform_output')
    transform_output = (
        artifact_utils.get_single_uri(transform_artifacts)
        if transform_artifacts else None)
    eval_files = [
        _all_files_pattern(
            artifact_utils.get_split_uri(input_dict['examples'], 'eval'))
    ]
    schema_file = io_utils.get_only_uri_in_dir(
        artifact_utils.get_single_uri(input_dict['schema']))

    train_args = trainer_pb2.TrainArgs()
    eval_args = trainer_pb2.EvalArgs()
    json_format.Parse(exec_properties['train_args'], train_args)
    json_format.Parse(exec_properties['eval_args'], eval_args)

    # https://github.com/tensorflow/tfx/issues/45: Replace num_steps=0 with
    # num_steps=None. Conversion of the proto to python will set the default
    # value of an int as 0 so modify the value here. Tensorflow will raise an
    # error if num_steps <= 0.
    train_steps = train_args.num_steps or None
    eval_steps = eval_args.num_steps or None

    output_path = artifact_utils.get_single_uri(output_dict['output'])
    serving_model_dir = path_utils.serving_model_dir(output_path)
    eval_model_dir = path_utils.eval_model_dir(output_path)

    # Assemble warm start path if needed. Warm starting is only enabled when
    # the previous model directory actually contains a checkpoint file.
    warm_start_from = None
    if exec_properties.get('warm_starting') and exec_properties.get(
            'warm_start_from'):
        previous_model_dir = os.path.join(
            exec_properties['warm_start_from'], path_utils.SERVING_MODEL_DIR)
        if previous_model_dir and tf.gfile.Exists(
                os.path.join(previous_model_dir, self._CHECKPOINT_FILE_NAME)):
            warm_start_from = previous_model_dir

    # TODO(b/126242806) Use PipelineInputs when it is available in third_party.
    hparams = tf.contrib.training.HParams(
        # A list of uris for train files.
        train_files=train_files,
        # An optional single uri for transform graph produced by TFT. Will be
        # None if not specified.
        transform_output=transform_output,
        # A single uri for the output directory of the serving model.
        serving_model_dir=serving_model_dir,
        # A list of uris for eval files.
        eval_files=eval_files,
        # A single uri for schema file.
        schema_file=schema_file,
        # Number of train steps.
        train_steps=train_steps,
        # Number of eval steps.
        eval_steps=eval_steps,
        # A single uri for the model directory to warm start from.
        warm_start_from=warm_start_from)

    schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())

    training_spec = trainer_fn(hparams, schema)

    # Train the model
    tf.logging.info('Training model.')
    tf.estimator.train_and_evaluate(training_spec['estimator'],
                                    training_spec['train_spec'],
                                    training_spec['eval_spec'])
    tf.logging.info('Training complete.  Model written to %s',
                    serving_model_dir)

    # Export an eval savedmodel for TFMA
    tf.logging.info('Exporting eval_savedmodel for TFMA.')
    tfma.export.export_eval_savedmodel(
        estimator=training_spec['estimator'],
        export_dir_base=eval_model_dir,
        eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])

    tf.logging.info('Exported eval_savedmodel to %s.', eval_model_dir)
def testImportFuncFromSourceMissingFunction(self):
    """Checks that asking for an undefined name raises AttributeError."""
    fn_source = os.path.join(
        os.path.dirname(__file__), 'testdata', 'test_fn.ext')

    with self.assertRaises(AttributeError):
        import_utils.import_func_from_source(fn_source, 'non_existing')