def testAddModuleDependencyAndPackage(self):
    """Exercise the user module dependency lifecycle end to end.

    The test registers a user module file as a pip dependency of a
    component, resolves (packages) that dependency into a wheel under a
    temporary pipeline root, and then checks that the packaged module is
    importable only while `TempPipInstallContext` is active.

    The test returns early without asserting anything when
    `udf_utils.should_package_user_modules()` is false.
    """
    # Do not test packaging in unsupported environments.
    if not udf_utils.should_package_user_modules():
        return

    # Create a component with a testing user module file. The directory is
    # registered for cleanup so it is removed even when an assertion fails.
    module_dir = tempfile.TemporaryDirectory()
    self.addCleanup(module_dir.cleanup)
    temp_module_file = os.path.join(module_dir.name, 'my_user_module.py')
    with open(temp_module_file, 'w') as f:
        f.write('# Test user module file.\nEXPOSED_VALUE="ABC123xyz"')
    component = _MyComponent(
        spec=_MyComponentSpec(my_module_file=temp_module_file))

    # Add the user module file pip dependency.
    udf_utils.add_user_module_dependency(component, 'my_module_file',
                                         'my_module_path')
    self.assertLen(component._pip_dependencies, 1)
    dependency = component._pip_dependencies[0]
    self.assertIsInstance(dependency, udf_utils.UserModuleFilePipDependency)
    self.assertIs(dependency.component, component)
    self.assertEqual(dependency.module_file_key, 'my_module_file')
    self.assertEqual(dependency.module_path_key, 'my_module_path')

    # Resolve the pip dependency and package the user module. The pipeline
    # root (and the wheel written below it) is likewise cleaned up afterwards.
    pipeline_root_dir = tempfile.TemporaryDirectory()
    self.addCleanup(pipeline_root_dir.cleanup)
    temp_pipeline_root = pipeline_root_dir.name
    component._resolve_pip_dependencies(temp_pipeline_root)
    self.assertLen(component._pip_dependencies, 1)
    dependency = component._pip_dependencies[0]

    # The hash version is based on the module names and contents and thus
    # should be stable.
    self.assertEqual(
        dependency,
        os.path.join(
            temp_pipeline_root, '_wheels', 'tfx_user_code_MyComponent-0.0+'
            '1c9b861db85cc54c56a56cbf64f77c1b9d1ded487d60a97d082ead6b250ee62c'
            '-py3-none-any.whl'))

    # Test import behavior within context manager.
    with udf_utils.TempPipInstallContext([dependency]):
        # Test import from same process.
        import my_user_module  # pylint: disable=g-import-not-at-top
        self.assertEqual(my_user_module.EXPOSED_VALUE, 'ABC123xyz')
        # Drop the cached module so the import at the end of the test has to
        # go through the import machinery again instead of hitting
        # `sys.modules`.
        del sys.modules['my_user_module']

        # Test import from a subprocess.
        self.assertEqual(
            subprocess.check_output([
                sys.executable, '-c',
                'import my_user_module; print(my_user_module.EXPOSED_VALUE)'
            ]), b'ABC123xyz\n')

    # Test that the import paths are cleaned up, so the user module can no
    # longer be imported.
    with self.assertRaises(ModuleNotFoundError):
        import my_user_module  # pylint: disable=g-import-not-at-top
def __init__(
        self,
        examples: types.Channel,
        schema: types.Channel,
        module_file: Optional[Union[Text,
                                    data_types.RuntimeParameter]] = None,
        preprocessing_fn: Optional[Union[
            Text, data_types.RuntimeParameter]] = None,
        splits_config: Optional[transform_pb2.SplitsConfig] = None,
        analyzer_cache: Optional[types.Channel] = None,
        materialize: bool = True,
        disable_analyzer_cache: bool = False,
        force_tf_compat_v1: bool = False,
        custom_config: Optional[Dict[Text, Any]] = None,
        disable_statistics: bool = False):
    """Construct a Transform component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples` (required).
        It should contain the custom splits named in `splits_config`; when
        no custom splits are given it should contain the two splits 'train'
        and 'eval'.
      schema: A Channel of type `standard_artifacts.Schema` holding a single
        schema artifact.
      module_file: File path of a python module from which the
        'preprocessing_fn' function is loaded. Exactly one of 'module_file'
        or 'preprocessing_fn' must be supplied. The function must have the
        following signature:

        ```
        def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
          ...
        ```

        where the values of the input and returned Dict are either tf.Tensor
        or tf.SparseTensor. If preprocessing_fn needs additional inputs they
        can be passed through custom_config:

        ```
        def preprocessing_fn(inputs: Dict[Text, Any],
                             custom_config: Dict[Text, Any]
                            ) -> Dict[Text, Any]:
          ...
        ```

        Use of a RuntimeParameter for this argument is experimental.
      preprocessing_fn: The path to a python function that implements a
        'preprocessing_fn'. See 'module_file' for the expected signature.
        Exactly one of 'module_file' or 'preprocessing_fn' must be supplied.
        Use of a RuntimeParameter for this argument is experimental.
      splits_config: A transform_pb2.SplitsConfig instance naming the splits
        to analyze and the splits to transform; the two sets may overlap.
        When it is not set, the 'train' split is analyzed and all splits are
        transformed. When it is set, analyze cannot be empty.
      analyzer_cache: Optional input 'TransformCache' channel containing
        cached information from previous Transform runs. When provided,
        Transform will try to use the cached calculation if possible.
      materialize: If True, write transformed examples as an output.
      disable_analyzer_cache: If False, Transform will use the input cache if
        provided and write cache output. If True, `analyzer_cache` must not
        be provided.
      force_tf_compat_v1: (Optional) If True and/or TF2 behaviors are
        disabled, Transform will use Tensorflow in compat.v1 mode
        irrespective of the installed version of Tensorflow. Defaults to
        `False`.
      custom_config: A dict of additional parameters that will be passed to
        preprocessing_fn.
      disable_statistics: If True, do not invoke TFDV to compute
        pre-transform and post-transform statistics. When statistics are
        computed, they will be stored in the `pre_transform_feature_stats/`
        and `post_transform_feature_stats/` subfolders of the
        `transform_graph` export.

    Raises:
      ValueError: When both or neither of 'module_file' and
        'preprocessing_fn' is supplied, or when `analyzer_cache` is given
        while `disable_analyzer_cache` is True.
    """
    has_module_file = bool(module_file)
    has_preprocessing_fn = bool(preprocessing_fn)
    if has_module_file == has_preprocessing_fn:
        raise ValueError(
            "Exactly one of 'module_file' or 'preprocessing_fn' must be "
            "supplied.")

    transform_graph = types.Channel(type=standard_artifacts.TransformGraph)

    # Transformed examples are only declared as an output when requested.
    if materialize:
        transformed_examples = types.Channel(
            type=standard_artifacts.Examples)
        transformed_examples.matching_channel_name = 'examples'
    else:
        transformed_examples = None

    # The five statistics-related outputs exist only when statistics are on.
    if disable_statistics:
        pre_transform_schema = None
        pre_transform_stats = None
        post_transform_schema = None
        post_transform_stats = None
        post_transform_anomalies = None
    else:
        pre_transform_schema = types.Channel(type=standard_artifacts.Schema)
        post_transform_schema = types.Channel(type=standard_artifacts.Schema)
        pre_transform_stats = types.Channel(
            type=standard_artifacts.ExampleStatistics)
        post_transform_stats = types.Channel(
            type=standard_artifacts.ExampleStatistics)
        post_transform_anomalies = types.Channel(
            type=standard_artifacts.ExampleAnomalies)

    # An input cache contradicts an explicit request to disable caching.
    if not disable_analyzer_cache:
        updated_analyzer_cache = types.Channel(
            type=standard_artifacts.TransformCache)
    elif analyzer_cache:
        raise ValueError(
            '`analyzer_cache` is set when disable_analyzer_cache is True.')
    else:
        updated_analyzer_cache = None

    # Booleans are handed to the spec as ints and the custom config as a
    # serialized string.
    force_tf_compat_v1_flag = int(force_tf_compat_v1)
    serialized_custom_config = json_utils.dumps(custom_config)
    disable_statistics_flag = int(disable_statistics)

    spec = standard_component_specs.TransformSpec(
        examples=examples,
        schema=schema,
        module_file=module_file,
        preprocessing_fn=preprocessing_fn,
        force_tf_compat_v1=force_tf_compat_v1_flag,
        splits_config=splits_config,
        transform_graph=transform_graph,
        transformed_examples=transformed_examples,
        analyzer_cache=analyzer_cache,
        updated_analyzer_cache=updated_analyzer_cache,
        custom_config=serialized_custom_config,
        disable_statistics=disable_statistics_flag,
        pre_transform_schema=pre_transform_schema,
        pre_transform_stats=pre_transform_stats,
        post_transform_schema=post_transform_schema,
        post_transform_stats=post_transform_stats,
        post_transform_anomalies=post_transform_anomalies)
    super(Transform, self).__init__(spec=spec)

    if udf_utils.should_package_user_modules():
        # When user modules are packaged, the `MODULE_PATH_KEY` execution
        # property is injected as a reference to the given user module file
        # after packaging, and the `MODULE_FILE_KEY` execution property is
        # removed at that point.
        udf_utils.add_user_module_dependency(
            self, standard_component_specs.MODULE_FILE_KEY,
            standard_component_specs.MODULE_PATH_KEY)
def __init__(
        self,
        examples: Optional[types.Channel] = None,
        transformed_examples: Optional[types.Channel] = None,
        transform_graph: Optional[types.Channel] = None,
        schema: Optional[types.Channel] = None,
        base_model: Optional[types.Channel] = None,
        hyperparameters: Optional[types.Channel] = None,
        module_file: Optional[Union[Text,
                                    data_types.RuntimeParameter]] = None,
        run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
        # TODO(b/147702778): deprecate trainer_fn.
        trainer_fn: Optional[Union[Text,
                                   data_types.RuntimeParameter]] = None,
        train_args: Optional[Union[trainer_pb2.TrainArgs,
                                   Dict[Text, Any]]] = None,
        eval_args: Optional[Union[trainer_pb2.EvalArgs,
                                  Dict[Text, Any]]] = None,
        custom_config: Optional[Dict[Text, Any]] = None,
        custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
    """Construct a Trainer component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, serving as
        the source of examples used in training (required). May be raw or
        transformed.
      transformed_examples: Deprecated (no compatibility guarantee). Please
        set 'examples' instead.
      transform_graph: An optional Channel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present.
      schema: An optional Channel of type `standard_artifacts.Schema`,
        serving as the schema of training and eval data. Schema is optional
        when 1) transform_graph is provided which contains schema. 2) user
        module bypasses the usage of schema, e.g., hardcoded.
      base_model: A Channel of type `Model`, containing the model that will
        be used for training. This can be used for warmstart, transfer
        learning or model ensembling.
      hyperparameters: A Channel of type
        `standard_artifacts.HyperParameters`, serving as the hyperparameters
        for the training module. Tuner's output best hyperparameters can be
        fed into this.
      module_file: A path to a python module file containing the UDF model
        definition.

        The module_file must implement a function named `run_fn` at its top
        level with function signature:
          `def run_fn(trainer.fn_args_utils.FnArgs)`,
        and the trained model must be saved to FnArgs.serving_model_dir when
        this function is executed.

        For the Estimator based Executor, the module_file must implement a
        function named `trainer_fn` at its top level. The function must have
        the following signature.
          def trainer_fn(trainer.fn_args_utils.FnArgs,
                         tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
            ...
          where the returned Dict has the following key-values.
            'estimator': an instance of tf.estimator.Estimator
            'train_spec': an instance of tf.estimator.TrainSpec
            'eval_spec': an instance of tf.estimator.EvalSpec
            'eval_input_receiver_fn': an instance of tfma EvalInputReceiver.

        Exactly one of 'module_file' or 'run_fn' must be supplied if Trainer
        uses GenericExecutor (default). Use of a RuntimeParameter for this
        argument is experimental.
      run_fn: A python path to the UDF model definition function for the
        generic trainer. See 'module_file' for details. Exactly one of
        'module_file' or 'run_fn' must be supplied if Trainer uses
        GenericExecutor (default). Use of a RuntimeParameter for this
        argument is experimental.
      trainer_fn: A python path to the UDF model definition function for the
        estimator based trainer. See 'module_file' for the required
        signature of the UDF. Exactly one of 'module_file' or 'trainer_fn'
        must be supplied if Trainer uses the Estimator based Executor. Use
        of a RuntimeParameter for this argument is experimental.
      train_args: A proto.TrainArgs instance or a dict, containing args used
        for training. Currently only splits and num_steps are available. If
        it's provided as a dict and any field is a RuntimeParameter, it
        should have the same field names as a TrainArgs proto message.
        Default behavior (when splits is empty) is train on `train` split.
      eval_args: A proto.EvalArgs instance or a dict, containing args used
        for evaluation. Currently only splits and num_steps are available.
        If it's provided as a dict and any field is a RuntimeParameter, it
        should have the same field names as a EvalArgs proto message.
        Default behavior (when splits is empty) is evaluate on `eval` split.
      custom_config: A dict which contains additional training job
        parameters that will be passed into the user module.
      custom_executor_spec: Optional custom executor spec. Deprecated (no
        compatibility guarantee), please customize component directly.

    Raises:
      ValueError:
        - When both or neither of 'module_file' and user function
          (e.g., trainer_fn and run_fn) is supplied.
        - When both or neither of 'examples' and 'transformed_examples'
          is supplied.
        - When 'transformed_examples' is supplied but 'transform_graph'
          is not supplied.
    """
    if [bool(module_file), bool(run_fn), bool(trainer_fn)].count(True) != 1:
        raise ValueError(
            "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
            "supplied.")

    if bool(examples) == bool(transformed_examples):
        # The message names the actual constructor arguments so that the
        # error points callers at parameters that exist.
        raise ValueError(
            "Exactly one of 'examples' or 'transformed_examples' must be "
            "supplied.")

    if transformed_examples and not transform_graph:
        raise ValueError("If 'transformed_examples' is supplied, "
                         "'transform_graph' must be supplied too.")

    if custom_executor_spec:
        logging.warning(
            "`custom_executor_spec` is deprecated. Please customize "
            "component directly.")
    if transformed_examples:
        logging.warning(
            "`transformed_examples` is deprecated. Please use `examples` "
            "instead.")
    # Exactly one of the two was supplied (checked above); from here on the
    # deprecated channel is treated as the regular examples input.
    examples = examples or transformed_examples

    model = types.Channel(type=standard_artifacts.Model)
    model_run = types.Channel(type=standard_artifacts.ModelRun)
    spec = standard_component_specs.TrainerSpec(
        examples=examples,
        transform_graph=transform_graph,
        schema=schema,
        base_model=base_model,
        hyperparameters=hyperparameters,
        # Missing args fall back to empty protos.
        train_args=train_args or trainer_pb2.TrainArgs(),
        eval_args=eval_args or trainer_pb2.EvalArgs(),
        module_file=module_file,
        run_fn=run_fn,
        trainer_fn=trainer_fn,
        custom_config=json_utils.dumps(custom_config),
        model=model,
        model_run=model_run)
    super(Trainer, self).__init__(
        spec=spec, custom_executor_spec=custom_executor_spec)

    if udf_utils.should_package_user_modules():
        # In this case, the `MODULE_PATH_KEY` execution property will be
        # injected as a reference to the given user module file after
        # packaging, at which point the `MODULE_FILE_KEY` execution property
        # will be removed.
        udf_utils.add_user_module_dependency(
            self, standard_component_specs.MODULE_FILE_KEY,
            standard_component_specs.MODULE_PATH_KEY)
def __init__(self,
             examples: types.Channel = None,
             schema: Optional[types.Channel] = None,
             transform_graph: Optional[types.Channel] = None,
             module_file: Optional[Text] = None,
             tuner_fn: Optional[Text] = None,
             train_args: trainer_pb2.TrainArgs = None,
             eval_args: trainer_pb2.EvalArgs = None,
             tune_args: Optional[tuner_pb2.TuneArgs] = None,
             custom_config: Optional[Dict[Text, Any]] = None):
    """Construct a Tuner component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, serving as
        the source of examples that are used in tuning (required).
      schema: An optional Channel of type `standard_artifacts.Schema`,
        serving as the schema of training and eval data. This is used when
        raw examples are provided.
      transform_graph: An optional Channel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present. This is used when transformed examples are
        provided.
      module_file: A path to a python module file containing the UDF tuner
        definition. The module_file must implement a function named
        `tuner_fn` at its top level. The function must have the following
        signature.
            def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
        Exactly one of 'module_file' or 'tuner_fn' must be supplied.
      tuner_fn: A python path to the UDF model definition function. See
        'module_file' for the required signature of the UDF. Exactly one of
        'module_file' or 'tuner_fn' must be supplied.
      train_args: A trainer_pb2.TrainArgs instance, containing args used for
        training. Currently only splits and num_steps are available. Default
        behavior (when splits is empty) is train on `train` split.
      eval_args: A trainer_pb2.EvalArgs instance, containing args used for
        eval. Currently only splits and num_steps are available. Default
        behavior (when splits is empty) is evaluate on `eval` split.
      tune_args: A tuner_pb2.TuneArgs instance, containing args used for
        tuning. Currently only num_parallel_trials is available.
      custom_config: A dict which contains additional training job
        parameters that will be passed into the user module.

    Raises:
      ValueError: When both or neither of 'module_file' and 'tuner_fn' is
        supplied.
    """
    module_file_given = bool(module_file)
    tuner_fn_given = bool(tuner_fn)
    if module_file_given == tuner_fn_given:
        raise ValueError(
            "Exactly one of 'module_file' or 'tuner_fn' must be supplied")

    # Collect the spec arguments first, then build the spec from them.
    spec_kwargs = dict(
        examples=examples,
        schema=schema,
        transform_graph=transform_graph,
        module_file=module_file,
        tuner_fn=tuner_fn,
        train_args=train_args,
        eval_args=eval_args,
        tune_args=tune_args,
        best_hyperparameters=types.Channel(
            type=standard_artifacts.HyperParameters),
        custom_config=json_utils.dumps(custom_config),
    )
    tuner_spec = standard_component_specs.TunerSpec(**spec_kwargs)
    super(Tuner, self).__init__(spec=tuner_spec)

    if udf_utils.should_package_user_modules():
        # When user modules are packaged, the `MODULE_PATH_KEY` execution
        # property is injected as a reference to the given user module file
        # after packaging, and the `MODULE_FILE_KEY` execution property is
        # removed at that point.
        udf_utils.add_user_module_dependency(
            self, standard_component_specs.MODULE_FILE_KEY,
            standard_component_specs.MODULE_PATH_KEY)
def __init__(
        self,
        examples: types.Channel = None,
        model: types.Channel = None,
        baseline_model: Optional[types.Channel] = None,
        # TODO(b/148618405): deprecate feature_slicing_spec.
        feature_slicing_spec: Optional[Union[
            evaluator_pb2.FeatureSlicingSpec, Dict[Text, Any]]] = None,
        fairness_indicator_thresholds: Optional[List[Union[
            float, data_types.RuntimeParameter]]] = None,
        example_splits: Optional[List[Text]] = None,
        eval_config: Optional[tfma.EvalConfig] = None,
        schema: Optional[types.Channel] = None,
        module_file: Optional[Text] = None,
        module_path: Optional[Text] = None):
    """Construct an Evaluator component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, usually
        produced by an ExampleGen component. _required_
      model: A Channel of type `standard_artifacts.Model`, usually produced
        by a Trainer component.
      baseline_model: An optional channel of type 'standard_artifacts.Model'
        used as the baseline model for model diff and model validation
        purposes.
      feature_slicing_spec: Deprecated, please use eval_config instead. Only
        supports estimator.
        [evaluator_pb2.FeatureSlicingSpec](https://github.com/tensorflow/tfx/blob/master/tfx/proto/evaluator.proto)
        instance that describes how Evaluator should slice the data. If any
        field is provided as a RuntimeParameter, feature_slicing_spec should
        be constructed as a dict with the same field names as the
        FeatureSlicingSpec proto message.
      fairness_indicator_thresholds: Optional list of float (or
        RuntimeParameter) threshold values for use with TFMA fairness
        indicators. Experimental functionality: this interface and
        functionality may change at any time. TODO(b/142653905): add a link
        to additional documentation for TFMA fairness indicators here.
      example_splits: Names of splits on which the metrics are computed.
        Default behavior (when example_splits is set to None or Empty) is
        using the 'eval' split.
      eval_config: Instance of tfma.EvalConfig containing configuration
        settings for running the evaluation. This config has options for
        both estimator and Keras.
      schema: A `Schema` channel to use for TFXIO.
      module_file: A path to a python module file containing UDFs for
        Evaluator customization. This functionality is experimental and may
        change at any time. The module_file can implement the following
        functions at its top level.
          def custom_eval_shared_model(
             eval_saved_model_path, model_name, eval_config, **kwargs,
          ) -> tfma.EvalSharedModel:
          def custom_extractors(
            eval_shared_model, eval_config, tensor_adapter_config,
          ) -> List[tfma.extractors.Extractor]:
      module_path: A python path to the custom module that contains the
        UDFs. See 'module_file' for the required signature of UDFs. This
        functionality is experimental and this API may change at any time.
        Note this can not be set together with module_file.

    Raises:
      ValueError: When 'module_file' and 'module_path' are both set, or when
        'eval_config' and 'feature_slicing_spec' are both supplied.
    """
    if module_file and module_path:
        raise ValueError(
            'Python module path can not be set together with module file '
            'path.')

    eval_config_given = eval_config is not None
    slicing_spec_given = feature_slicing_spec is not None
    if eval_config_given and slicing_spec_given:
        raise ValueError(
            "Exactly one of 'eval_config' or 'feature_slicing_spec' "
            "must be supplied.")
    if not (eval_config_given or slicing_spec_given):
        # With neither configuration present, fall back to an empty slicing
        # spec.
        feature_slicing_spec = evaluator_pb2.FeatureSlicingSpec()
        logging.info(
            'Neither eval_config nor feature_slicing_spec is passed, '
            'the model is treated as estimator.')

    if feature_slicing_spec:
        logging.warning(
            'feature_slicing_spec is deprecated, please use '
            'eval_config instead.')

    blessing_channel = types.Channel(type=standard_artifacts.ModelBlessing)
    evaluation_channel = types.Channel(
        type=standard_artifacts.ModelEvaluation)
    serialized_example_splits = json_utils.dumps(example_splits)

    spec = standard_component_specs.EvaluatorSpec(
        examples=examples,
        model=model,
        baseline_model=baseline_model,
        feature_slicing_spec=feature_slicing_spec,
        fairness_indicator_thresholds=fairness_indicator_thresholds,
        example_splits=serialized_example_splits,
        evaluation=evaluation_channel,
        eval_config=eval_config,
        blessing=blessing_channel,
        schema=schema,
        module_file=module_file,
        module_path=module_path)
    super(Evaluator, self).__init__(spec=spec)

    if udf_utils.should_package_user_modules():
        # When user modules are packaged, the `MODULE_PATH_KEY` execution
        # property is injected as a reference to the given user module file
        # after packaging, and the `MODULE_FILE_KEY` execution property is
        # removed at that point.
        udf_utils.add_user_module_dependency(
            self, standard_component_specs.MODULE_FILE_KEY,
            standard_component_specs.MODULE_PATH_KEY)