Пример #1
0
    def testAddModuleDependencyAndPackage(self):
        """Tests registering, packaging and importing a user module file.

        A tiny user module is written to a temporary directory and registered as
        a pip dependency of a test component. The dependency is then resolved
        (packaged into a wheel under a temporary pipeline root). The test checks
        that the module can be imported only while
        `udf_utils.TempPipInstallContext` is active.
        """
        import shutil  # pylint: disable=g-import-not-at-top

        # Do not test packaging in unsupported environments.
        if not udf_utils.should_package_user_modules():
            return

        # Create a component with a testing user module file. The directory is
        # removed after the test so repeated runs do not leak temporary files.
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        temp_module_file = os.path.join(temp_dir, 'my_user_module.py')
        with open(temp_module_file, 'w', encoding='utf-8') as f:
            f.write('# Test user module file.\nEXPOSED_VALUE="ABC123xyz"')
        component = _MyComponent(spec=_MyComponentSpec(
            my_module_file=temp_module_file))

        # Add the user module file pip dependency.
        udf_utils.add_user_module_dependency(component, 'my_module_file',
                                             'my_module_path')
        self.assertLen(component._pip_dependencies, 1)
        dependency = component._pip_dependencies[0]
        self.assertIsInstance(dependency,
                              udf_utils.UserModuleFilePipDependency)
        self.assertIs(dependency.component, component)
        self.assertEqual(dependency.module_file_key, 'my_module_file')
        self.assertEqual(dependency.module_path_key, 'my_module_path')

        # Resolve the pip dependency and package the user module. The pipeline
        # root holding the built wheel is also cleaned up after the test.
        temp_pipeline_root = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_pipeline_root, ignore_errors=True)
        component._resolve_pip_dependencies(temp_pipeline_root)
        self.assertLen(component._pip_dependencies, 1)
        dependency = component._pip_dependencies[0]

        # The hash version is based on the module names and contents and thus
        # should be stable.
        self.assertEqual(
            dependency,
            os.path.join(
                temp_pipeline_root, '_wheels', 'tfx_user_code_MyComponent-0.0+'
                '1c9b861db85cc54c56a56cbf64f77c1b9d1ded487d60a97d082ead6b250ee62c'
                '-py3-none-any.whl'))

        # Test import behavior within context manager.
        with udf_utils.TempPipInstallContext([dependency]):
            # Guarantee the module is evicted from this process's module cache
            # even if an assertion below fails, so other tests are unaffected.
            self.addCleanup(sys.modules.pop, 'my_user_module', None)

            # Test import from same process.
            import my_user_module  # pylint: disable=g-import-not-at-top
            self.assertEqual(my_user_module.EXPOSED_VALUE, 'ABC123xyz')
            del sys.modules['my_user_module']

            # Test import from a subprocess.
            self.assertEqual(
                subprocess.check_output([
                    sys.executable, '-c',
                    'import my_user_module; print(my_user_module.EXPOSED_VALUE)'
                ]), b'ABC123xyz\n')

        # Test that the import paths are cleaned up, so the user module can no
        # longer be imported.
        with self.assertRaises(ModuleNotFoundError):
            import my_user_module  # pylint: disable=g-import-not-at-top
Пример #2
0
    def __init__(
            self,
            examples: types.Channel,
            schema: types.Channel,
            module_file: Optional[Union[Text,
                                        data_types.RuntimeParameter]] = None,
            preprocessing_fn: Optional[Union[
                Text, data_types.RuntimeParameter]] = None,
            splits_config: Optional[transform_pb2.SplitsConfig] = None,
            analyzer_cache: Optional[types.Channel] = None,
            materialize: bool = True,
            disable_analyzer_cache: bool = False,
            force_tf_compat_v1: bool = False,
            custom_config: Optional[Dict[Text, Any]] = None,
            disable_statistics: bool = False):
        """Construct a Transform component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples` (required).
        This should contain custom splits specified in splits_config. If
        custom split is not provided, this should contain two splits 'train'
        and 'eval'.
      schema: A Channel of type `standard_artifacts.Schema`. This should
        contain a single schema artifact.
      module_file: The file path to a python module file, from which the
        'preprocessing_fn' function will be loaded.
        Exactly one of 'module_file' or 'preprocessing_fn' must be supplied.

        The function needs to have the following signature:
        ```
        def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
          ...
        ```
        where the values of input and returned Dict are either tf.Tensor or
        tf.SparseTensor.

        If additional inputs are needed for preprocessing_fn, they can be passed
        in custom_config:

        ```
        def preprocessing_fn(inputs: Dict[Text, Any], custom_config:
                             Dict[Text, Any]) -> Dict[Text, Any]:
          ...
        ```
        Use of a RuntimeParameter for this argument is experimental.
      preprocessing_fn: The path to python function that implements a
        'preprocessing_fn'. See 'module_file' for expected signature of the
        function. Exactly one of 'module_file' or 'preprocessing_fn' must be
        supplied. Use of a RuntimeParameter for this argument is experimental.
      splits_config: A transform_pb2.SplitsConfig instance, providing splits
        that should be analyzed and splits that should be transformed. Note
        analyze and transform splits can have overlap. Default behavior (when
        splits_config is not set) is analyze the 'train' split and transform
        all splits. If splits_config is set, analyze cannot be empty.
      analyzer_cache: Optional input 'TransformCache' channel containing
        cached information from previous Transform runs. When provided,
        Transform will try to use the cached calculation if possible.
      materialize: If True, write transformed examples as an output.
      disable_analyzer_cache: If False, Transform will use input cache if
        provided and write cache output. If True, `analyzer_cache` must not be
        provided.
      force_tf_compat_v1: (Optional) If True, or if TF2 behaviors are
        disabled, Transform will use Tensorflow in compat.v1 mode irrespective
        of installed version of Tensorflow. Defaults to `False`.
      custom_config: A dict which contains additional parameters that will be
        passed to preprocessing_fn.
      disable_statistics: If True, do not invoke TFDV to compute pre-transform
        and post-transform statistics. When statistics are computed, they will
        be stored in the `pre_transform_feature_stats/` and
        `post_transform_feature_stats/` subfolders of the `transform_graph`
        export.

    Raises:
      ValueError: When both or neither of 'module_file' and 'preprocessing_fn'
        is supplied, or when `analyzer_cache` is supplied while
        `disable_analyzer_cache` is True.
    """
        # Validate the arguments before building any output channel.
        if bool(module_file) == bool(preprocessing_fn):
            raise ValueError(
                "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
            )
        if disable_analyzer_cache and analyzer_cache:
            raise ValueError(
                '`analyzer_cache` is set when disable_analyzer_cache is True.'
            )

        transform_graph = types.Channel(type=standard_artifacts.TransformGraph)

        # Transformed examples are only an output when materialization is on.
        transformed_examples = None
        if materialize:
            transformed_examples = types.Channel(
                type=standard_artifacts.Examples)
            transformed_examples.matching_channel_name = 'examples'

        # The five schema/statistics outputs exist only when statistics are
        # enabled; otherwise they are all passed to the spec as None.
        (pre_transform_schema, pre_transform_stats, post_transform_schema,
         post_transform_stats, post_transform_anomalies) = (None, ) * 5
        if not disable_statistics:
            pre_transform_schema = types.Channel(
                type=standard_artifacts.Schema)
            post_transform_schema = types.Channel(
                type=standard_artifacts.Schema)
            pre_transform_stats = types.Channel(
                type=standard_artifacts.ExampleStatistics)
            post_transform_stats = types.Channel(
                type=standard_artifacts.ExampleStatistics)
            post_transform_anomalies = types.Channel(
                type=standard_artifacts.ExampleAnomalies)

        # A cache output channel is only created when the analyzer cache is on.
        updated_analyzer_cache = None
        if not disable_analyzer_cache:
            updated_analyzer_cache = types.Channel(
                type=standard_artifacts.TransformCache)

        spec = standard_component_specs.TransformSpec(
            examples=examples,
            schema=schema,
            module_file=module_file,
            preprocessing_fn=preprocessing_fn,
            force_tf_compat_v1=int(force_tf_compat_v1),
            splits_config=splits_config,
            transform_graph=transform_graph,
            transformed_examples=transformed_examples,
            analyzer_cache=analyzer_cache,
            updated_analyzer_cache=updated_analyzer_cache,
            custom_config=json_utils.dumps(custom_config),
            disable_statistics=int(disable_statistics),
            pre_transform_schema=pre_transform_schema,
            pre_transform_stats=pre_transform_stats,
            post_transform_schema=post_transform_schema,
            post_transform_stats=post_transform_stats,
            post_transform_anomalies=post_transform_anomalies)
        super(Transform, self).__init__(spec=spec)

        if udf_utils.should_package_user_modules():
            # In this case, the `MODULE_PATH_KEY` execution property will be injected
            # as a reference to the given user module file after packaging, at which
            # point the `MODULE_FILE_KEY` execution property will be removed.
            udf_utils.add_user_module_dependency(
                self, standard_component_specs.MODULE_FILE_KEY,
                standard_component_specs.MODULE_PATH_KEY)
Пример #3
0
    def __init__(
            self,
            examples: Optional[types.Channel] = None,
            transformed_examples: Optional[types.Channel] = None,
            transform_graph: Optional[types.Channel] = None,
            schema: Optional[types.Channel] = None,
            base_model: Optional[types.Channel] = None,
            hyperparameters: Optional[types.Channel] = None,
            module_file: Optional[Union[Text,
                                        data_types.RuntimeParameter]] = None,
            run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
            # TODO(b/147702778): deprecate trainer_fn.
            trainer_fn: Optional[Union[Text,
                                       data_types.RuntimeParameter]] = None,
            train_args: Optional[Union[trainer_pb2.TrainArgs,
                                       Dict[Text, Any]]] = None,
            eval_args: Optional[Union[trainer_pb2.EvalArgs, Dict[Text,
                                                                 Any]]] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
        """Construct a Trainer component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, serving as
        the source of examples used in training (required). May be raw or
        transformed.
      transformed_examples: Deprecated (no compatibility guarantee). Please set
        'examples' instead.
      transform_graph: An optional Channel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present.
      schema: An optional Channel of type `standard_artifacts.Schema`, serving
        as the schema of training and eval data. Schema is optional when
        1) transform_graph is provided which contains schema.
        2) user module bypasses the usage of schema, e.g., hardcoded.
      base_model: A Channel of type `Model`, containing model that will be used
        for training. This can be used for warmstart, transfer learning or
        model ensembling.
      hyperparameters: A Channel of type `standard_artifacts.HyperParameters`,
        serving as the hyperparameters for training module. Tuner's output best
        hyperparameters can be fed into this.
      module_file: A path to python module file containing UDF model definition.
        The module_file must implement a function named `run_fn` at its top
        level with function signature:
          `def run_fn(trainer.fn_args_utils.FnArgs)`,
        and the trained model must be saved to FnArgs.serving_model_dir when
        this function is executed.

        For Estimator based Executor, the module_file must implement a function
        named `trainer_fn` at its top level. The function must have the
        following signature.
          def trainer_fn(trainer.fn_args_utils.FnArgs,
                         tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
            ...
          where the returned Dict has the following key-values.
            'estimator': an instance of tf.estimator.Estimator
            'train_spec': an instance of tf.estimator.TrainSpec
            'eval_spec': an instance of tf.estimator.EvalSpec
            'eval_input_receiver_fn': an instance of tfma EvalInputReceiver.
        Exactly one of 'module_file' or 'run_fn' must be supplied if Trainer
        uses GenericExecutor (default). Use of a RuntimeParameter for this
        argument is experimental.
      run_fn: A python path to UDF model definition function for generic
        trainer. See 'module_file' for details. Exactly one of 'module_file' or
        'run_fn' must be supplied if Trainer uses GenericExecutor (default).
        Use of a RuntimeParameter for this argument is experimental.
      trainer_fn: A python path to UDF model definition function for estimator
        based trainer. See 'module_file' for the required signature of the UDF.
        Exactly one of 'module_file' or 'trainer_fn' must be supplied if Trainer
        uses Estimator based Executor. Use of a RuntimeParameter for this
        argument is experimental.
      train_args: A proto.TrainArgs instance or a dict, containing args
        used for training. Currently only splits and num_steps are available. If
        it's provided as a dict and any field is a RuntimeParameter, it should
        have the same field names as a TrainArgs proto message. Default
        behavior (when splits is empty) is train on `train` split.
      eval_args: A proto.EvalArgs instance or a dict, containing args
        used for evaluation. Currently only splits and num_steps are available.
        If it's provided as a dict and any field is a RuntimeParameter, it
        should have the same field names as a EvalArgs proto message. Default
        behavior (when splits is empty) is evaluate on `eval` split.
      custom_config: A dict which contains additional training job parameters
        that will be passed into user module.
      custom_executor_spec: Optional custom executor spec. Deprecated (no
        compatibility guarantee), please customize component directly.

    Raises:
      ValueError:
        - When the number of supplied values among 'module_file', 'run_fn' and
          'trainer_fn' is not exactly one.
        - When both or neither of 'examples' and 'transformed_examples'
            is supplied.
        - When 'transformed_examples' is supplied but 'transform_graph'
            is not supplied.
    """
        # Exactly one source of user code is accepted, whatever the executor.
        user_code_args = (module_file, run_fn, trainer_fn)
        if sum(bool(arg) for arg in user_code_args) != 1:
            raise ValueError(
                "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
                "supplied.")

        if bool(examples) == bool(transformed_examples):
            raise ValueError(
                "Exactly one of 'examples' or 'transformed_examples' must be "
                "supplied.")

        if transformed_examples and not transform_graph:
            raise ValueError("If 'transformed_examples' is supplied, "
                             "'transform_graph' must be supplied too.")

        if custom_executor_spec:
            logging.warning(
                "`custom_executor_spec` is deprecated. Please customize component directly."
            )
        if transformed_examples:
            logging.warning(
                "`transformed_examples` is deprecated. Please use `examples` instead."
            )
        # Exactly one of the two was supplied (checked above); use whichever it is.
        examples = examples or transformed_examples
        model = types.Channel(type=standard_artifacts.Model)
        model_run = types.Channel(type=standard_artifacts.ModelRun)
        spec = standard_component_specs.TrainerSpec(
            examples=examples,
            transform_graph=transform_graph,
            schema=schema,
            base_model=base_model,
            hyperparameters=hyperparameters,
            train_args=train_args or trainer_pb2.TrainArgs(),
            eval_args=eval_args or trainer_pb2.EvalArgs(),
            module_file=module_file,
            run_fn=run_fn,
            trainer_fn=trainer_fn,
            custom_config=json_utils.dumps(custom_config),
            model=model,
            model_run=model_run)
        super(Trainer,
              self).__init__(spec=spec,
                             custom_executor_spec=custom_executor_spec)

        if udf_utils.should_package_user_modules():
            # In this case, the `MODULE_PATH_KEY` execution property will be injected
            # as a reference to the given user module file after packaging, at which
            # point the `MODULE_FILE_KEY` execution property will be removed.
            udf_utils.add_user_module_dependency(
                self, standard_component_specs.MODULE_FILE_KEY,
                standard_component_specs.MODULE_PATH_KEY)
Пример #4
0
    def __init__(self,
                 examples: Optional[types.Channel] = None,
                 schema: Optional[types.Channel] = None,
                 transform_graph: Optional[types.Channel] = None,
                 module_file: Optional[Text] = None,
                 tuner_fn: Optional[Text] = None,
                 train_args: Optional[trainer_pb2.TrainArgs] = None,
                 eval_args: Optional[trainer_pb2.EvalArgs] = None,
                 tune_args: Optional[tuner_pb2.TuneArgs] = None,
                 custom_config: Optional[Dict[Text, Any]] = None):
        """Construct a Tuner component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, serving as the
        source of examples that are used in tuning (required).
      schema: An optional Channel of type `standard_artifacts.Schema`, serving
        as the schema of training and eval data. This is used when raw examples
        are provided.
      transform_graph: An optional Channel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present. This is used when transformed examples are provided.
      module_file: A path to python module file containing UDF tuner definition.
        The module_file must implement a function named `tuner_fn` at its top
        level. The function must have the following signature.
            def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
        Exactly one of 'module_file' or 'tuner_fn' must be supplied.
      tuner_fn: A python path to UDF model definition function. See
        'module_file' for the required signature of the UDF. Exactly one of
        'module_file' or 'tuner_fn' must be supplied.
      train_args: A trainer_pb2.TrainArgs instance, containing args used for
        training. Currently only splits and num_steps are available. Default
        behavior (when splits is empty) is train on `train` split.
      eval_args: A trainer_pb2.EvalArgs instance, containing args used for eval.
        Currently only splits and num_steps are available. Default behavior
        (when splits is empty) is evaluate on `eval` split.
      tune_args: A tuner_pb2.TuneArgs instance, containing args used for tuning.
        Currently only num_parallel_trials is available.
      custom_config: A dict which contains additional training job parameters
        that will be passed into user module.

    Raises:
      ValueError: When both or neither of 'module_file' and 'tuner_fn' is
        supplied.
    """
        if bool(module_file) == bool(tuner_fn):
            raise ValueError(
                "Exactly one of 'module_file' or 'tuner_fn' must be supplied")

        best_hyperparameters = types.Channel(
            type=standard_artifacts.HyperParameters)
        spec = standard_component_specs.TunerSpec(
            examples=examples,
            schema=schema,
            transform_graph=transform_graph,
            module_file=module_file,
            tuner_fn=tuner_fn,
            train_args=train_args,
            eval_args=eval_args,
            tune_args=tune_args,
            best_hyperparameters=best_hyperparameters,
            custom_config=json_utils.dumps(custom_config),
        )
        super(Tuner, self).__init__(spec=spec)

        if udf_utils.should_package_user_modules():
            # In this case, the `MODULE_PATH_KEY` execution property will be injected
            # as a reference to the given user module file after packaging, at which
            # point the `MODULE_FILE_KEY` execution property will be removed.
            udf_utils.add_user_module_dependency(
                self, standard_component_specs.MODULE_FILE_KEY,
                standard_component_specs.MODULE_PATH_KEY)
Пример #5
0
    def __init__(
            self,
            examples: Optional[types.Channel] = None,
            model: Optional[types.Channel] = None,
            baseline_model: Optional[types.Channel] = None,
            # TODO(b/148618405): deprecate feature_slicing_spec.
            feature_slicing_spec: Optional[Union[
                evaluator_pb2.FeatureSlicingSpec, Dict[Text, Any]]] = None,
            fairness_indicator_thresholds: Optional[List[Union[
                float, data_types.RuntimeParameter]]] = None,
            example_splits: Optional[List[Text]] = None,
            eval_config: Optional[tfma.EvalConfig] = None,
            schema: Optional[types.Channel] = None,
            module_file: Optional[Text] = None,
            module_path: Optional[Text] = None):
        """Construct an Evaluator component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, usually
        produced by an ExampleGen component. _required_
      model: A Channel of type `standard_artifacts.Model`, usually produced by
        a Trainer component.
      baseline_model: An optional channel of type 'standard_artifacts.Model' as
        the baseline model for model diff and model validation purpose.
      feature_slicing_spec:
        Deprecated, please use eval_config instead. Only support estimator.
        [evaluator_pb2.FeatureSlicingSpec](https://github.com/tensorflow/tfx/blob/master/tfx/proto/evaluator.proto)
          instance that describes how Evaluator should slice the data. If any
          field is provided as a RuntimeParameter, feature_slicing_spec should
          be constructed as a dict with the same field names as
          FeatureSlicingSpec proto message.
      fairness_indicator_thresholds: Optional list of float (or
        RuntimeParameter) threshold values for use with TFMA fairness
          indicators. Experimental functionality: this interface and
          functionality may change at any time. TODO(b/142653905): add a link
          to additional documentation for TFMA fairness indicators here.
      example_splits: Names of splits on which the metrics are computed.
        Default behavior (when example_splits is set to None or Empty) is using
        the 'eval' split.
      eval_config: Instance of tfma.EvalConfig containing configuration settings
        for running the evaluation. This config has options for both estimator
        and Keras. At most one of 'eval_config' and 'feature_slicing_spec' may
        be supplied; when neither is, an empty FeatureSlicingSpec is used and
        the model is treated as an estimator.
      schema: A `Schema` channel to use for TFXIO.
      module_file: A path to python module file containing UDFs for Evaluator
        customization. This functionality is experimental and may change at any
        time. The module_file can implement following functions at its top
        level.
          def custom_eval_shared_model(
             eval_saved_model_path, model_name, eval_config, **kwargs,
          ) -> tfma.EvalSharedModel:
          def custom_extractors(
            eval_shared_model, eval_config, tensor_adapter_config,
          ) -> List[tfma.extractors.Extractor]:
      module_path: A python path to the custom module that contains the UDFs.
        See 'module_file' for the required signature of UDFs. This functionality
        is experimental and this API may change at any time. Note this can
        not be set together with module_file.

    Raises:
      ValueError: When both 'module_file' and 'module_path' are supplied, or
        when both 'eval_config' and 'feature_slicing_spec' are supplied.
    """
        if module_file and module_path:
            raise ValueError(
                'Python module path can not be set together with module file path.'
            )

        if eval_config is not None and feature_slicing_spec is not None:
            raise ValueError(
                "Exactly one of 'eval_config' or 'feature_slicing_spec' "
                "must be supplied.")
        # With neither config supplied, fall back to an empty (legacy)
        # FeatureSlicingSpec.
        if eval_config is None and feature_slicing_spec is None:
            feature_slicing_spec = evaluator_pb2.FeatureSlicingSpec()
            logging.info(
                'Neither eval_config nor feature_slicing_spec is passed, '
                'the model is treated as estimator.')

        if feature_slicing_spec:
            logging.warning('feature_slicing_spec is deprecated, please use '
                            'eval_config instead.')

        blessing = types.Channel(type=standard_artifacts.ModelBlessing)
        evaluation = types.Channel(type=standard_artifacts.ModelEvaluation)
        spec = standard_component_specs.EvaluatorSpec(
            examples=examples,
            model=model,
            baseline_model=baseline_model,
            feature_slicing_spec=feature_slicing_spec,
            fairness_indicator_thresholds=fairness_indicator_thresholds,
            example_splits=json_utils.dumps(example_splits),
            evaluation=evaluation,
            eval_config=eval_config,
            blessing=blessing,
            schema=schema,
            module_file=module_file,
            module_path=module_path)
        super(Evaluator, self).__init__(spec=spec)

        if udf_utils.should_package_user_modules():
            # In this case, the `MODULE_PATH_KEY` execution property will be injected
            # as a reference to the given user module file after packaging, at which
            # point the `MODULE_FILE_KEY` execution property will be removed.
            udf_utils.add_user_module_dependency(
                self, standard_component_specs.MODULE_FILE_KEY,
                standard_component_specs.MODULE_PATH_KEY)