# Example #1
# 0
    def __init__(self,
                 tft_output: TFTransformOutput,
                 exported_as_v1: Optional[bool] = None):
        """Initializes the layer from a `TFTransformOutput`.

        Args:
          tft_output: A `TFTransformOutput` wrapping the transform SavedModel.
          exported_as_v1: If provided, whether the SavedModel was exported
            using TF 1.x APIs. If `None`, this is detected by inspecting the
            SavedModel directory itself.
        """
        super().__init__(trainable=False)
        self._tft_output = tft_output
        # Honor an explicit caller choice; otherwise detect the export format
        # from the SavedModel on disk.
        if exported_as_v1 is not None:
            self._exported_as_v1 = exported_as_v1
        else:
            self._exported_as_v1 = saved_transform_io.exported_as_v1(
                tft_output.transform_savedmodel_dir)
        self._saved_model_loader_value = None
        self._loaded_saved_model_graph = None
        # TODO(b/160294509): Use tf.compat.v1 when we stop supporting TF 1.15.
        if ops.executing_eagerly_outside_functions():
            _check_tensorflow_version()
            # The model must be tracked by assigning to an attribute of this
            # Keras layer; mirroring the loader's attribute dict here makes
            # its contents visible to Keras tracking.
            self._saved_model_loader_tracked_dict = self._saved_model_loader.__dict__

        # TODO(b/162055065): This is needed because otherwise we'd get an error
        # in some cases:
        # ValueError: Your Layer or Model is in an invalid state. This can
        # happen if you are interleaving estimator/non-estimator models or
        # interleaving models/layers made in tf.compat.v1.Graph.as_default()
        # with models/layers created outside of it. Converting a model to an
        # estimator (via model_to_estimator) invalidates all models/layers made
        # before the conversion (even if they were not the model converted to
        # an estimator). Similarly, making a layer or a model inside a
        # tf.compat.v1.Graph invalidates all layers/models you previously made
        # outside of the graph.
        self._originally_built_as_v1 = True
# Example #2
# 0
    def _exported_as_v1(self) -> bool:
        """A boolean.

    Indicates whether the SavedModel was exported using TF 1.x or TF 2.x APIs.
    """
        if self._exported_as_v1_value is None:
            self._exported_as_v1_value = saved_transform_io.exported_as_v1(
                self.transform_savedmodel_dir)
        return self._exported_as_v1_value
# Example #3
# 0
    def load_transform_graph(self):
        """Loads the transform graph without replacing any placeholders.

        This is necessary to ensure that variables in the transform graph are
        included in the training checkpoint when using tf.Estimator. It should
        be called in the training input_fn.
        """
        # Lazily determine the export format on first use.
        if self._exported_as_v1 is None:
            self._exported_as_v1 = saved_transform_io.exported_as_v1(
                self.transform_savedmodel_dir)

        if not self._exported_as_v1:
            # Note: This should use the same mechanism as
            # `transform_raw_features` to load the SavedModel into the current
            # graph context.
            _ = self.transform_features_layer()({})
        else:
            saved_transform_io.partially_apply_saved_transform_internal(
                self.transform_savedmodel_dir, {})