Example #1
0
def load(path, compile=True):  # pylint: disable=redefined-builtin
    """Loads Keras objects from a SavedModel.

    Any Keras layer or model saved to the SavedModel is revived as a Keras
    object; everything else comes back as a regular trackable object, the
    same way `tf.saved_model.load` would return it.

    Currently, Keras saving/loading only retains the Keras object's weights,
    losses, and call function. The loaded model can be re-compiled, but the
    original optimizer, compiled loss functions, and metrics are not
    retained. This is temporary, and `model.save` will soon be able to
    serialize compiled models.

    Args:
      path: Path to SavedModel.
      compile: If true, compile the model after loading it.

    Returns:
      Object loaded from SavedModel.
    """
    # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
    # TODO(kathywu): Add code to load from objects that contain all endpoints
    revived = tf_load.load_internal(path, loader_cls=KerasObjectLoader)

    # Only revived Keras models carry a serialized training config worth
    # recompiling from; plain trackables are returned untouched.
    if compile and isinstance(revived, RevivedModel):
        # TODO(kathywu): Use compiled objects from SavedModel, instead of
        # creating new objects from the training config.
        training_config = revived._training_config  # pylint: disable=protected-access
        if training_config is not None:
            compile_args = saving_utils.compile_args_from_training_config(
                training_config)
            revived.compile(**compile_args)

    return revived
    def __init__(self, saved_model_dir):
        """Init method for SavedModelLoader.

        Loads the SavedModel, then wires up `_wrapped`, `_func_graph`,
        `_structured_inputs` and `_structured_outputs` according to whether
        the export is a TF1-style (signature-based) or TF2-style
        (`transform_fn`) transform graph.

        Args:
          saved_model_dir: A SavedModel directory providing a transform
            graph. The MetaGraphDef and signature are selected from the
            SavedModel using keys defined in `../constants.py` ('transform'
            and 'transform_signature', respectively).
        """
        # NOTE(review): this is a lexicographic string compare, so e.g.
        # '2.10' < '2.5' is True — verify intended for the supported range.
        if tf.version.VERSION < '2.5':
            self._imported = load.load_internal(saved_model_dir,
                                                loader_cls=_Loader)
            # load_internal may hand back a dict keyed by 'root' instead of
            # the root object itself; unwrap in that case.
            if isinstance(self._imported, dict):
                self._imported = self._imported['root']
        else:
            # TODO(b/160294509): Stop using tf.compat.v2 when TF1.15 support is
            # dropped.
            self._imported = tf.compat.v2.saved_model.load(saved_model_dir)
        # Presence of the TF1 'transform_signature' signature decides which
        # of the two loading paths below is taken.
        self.load_v2_in_compat = (constants.TRANSFORM_SIGNATURE
                                  in self._imported.signatures)
        if self.load_v2_in_compat:
            # TF1/compat path: wrap the exported signature function and
            # recover the input structure from the v1 SavedModel on disk.
            self._wrapped = self._imported.signatures[
                constants.TRANSFORM_SIGNATURE]
            self._func_graph = self._wrapped.graph
            self._structured_inputs = self._get_input_signature_from_v1_saved_model(
                saved_model_dir)
            self._structured_outputs = self._wrapped.structured_outputs
        else:
            # TODO(b/160550490): Remove local import.
            from tensorflow_transform import tf2_utils  # pylint: disable=g-import-not-at-top

            # `transform_fn` may be a tf.function or already a
            # ConcreteFunction; handle both to maintain backward
            # compatibility. If it is a tf.function, an `input_signature`
            # was specified when exporting it to `SavedModel`, so there
            # should be exactly one concrete function present on loading.
            if hasattr(self._imported.transform_fn, 'concrete_functions'):
                concrete_functions = self._imported.transform_fn.concrete_functions
                assert len(concrete_functions) == 1, concrete_functions
                self._wrapped = concrete_functions[0]
            else:
                self._wrapped = self._imported.transform_fn
            self._func_graph = self._wrapped.graph
            self._structured_inputs = (
                tf2_utils.get_structured_inputs_from_func_graph(
                    self._func_graph))
            # Re-attach the flat graph output tensors to their original
            # nested structure, expanding composite tensors.
            self._structured_outputs = tf.nest.pack_sequence_as(
                self._func_graph.structured_outputs,
                self._func_graph.outputs,
                expand_composites=True)
        # Presumably maps each output to the inputs it depends on (used for
        # pruning) — confirm against _get_output_to_inputs_map.
        self._output_to_inputs_map = (self._get_output_to_inputs_map(
            self._structured_outputs))
        saved_transform_io._maybe_register_addon_ops()  # pylint: disable=protected-access
Example #3
0
def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
    """Loads Keras objects from a SavedModel.

    Any Keras layer or model saved to the SavedModel is revived as a Keras
    object; everything else comes back as a regular trackable object, the
    same way `tf.saved_model.load` would return it.

    Currently, Keras saving/loading only retains the Keras object's weights,
    losses, and call function. The loaded model can be re-compiled, but the
    original optimizer, compiled loss functions, and metrics are not
    retained. This is temporary, and `model.save` will soon be able to
    serialize compiled models.

    Args:
      path: Path to SavedModel.
      compile: If true, compile the model after loading it.
      options: Optional `tf.saved_model.LoadOptions` object that specifies
        options for loading from SavedModel.

    Returns:
      Object loaded from SavedModel.
    """
    # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
    # TODO(kathywu): Add code to load from objects that contain all endpoints
    revived = tf_load.load_internal(path,
                                    options=options,
                                    loader_cls=KerasObjectLoader)

    if compile and isinstance(revived, training_lib.Model):
        # TODO(kathywu): Use compiled objects from SavedModel, instead of
        # creating new objects from the training config.
        metadata = revived._serialized_attributes['metadata']  # pylint: disable=protected-access
        training_config = metadata.get('training_config', None)
        if training_config is None:
            logging.warning(
                'No training configuration found in save file, so the '
                'model was *not* compiled. Compile it manually.')
        else:
            compile_args = saving_utils.compile_args_from_training_config(
                training_config)
            revived.compile(**compile_args)
            saving_utils.try_build_compiled_arguments(revived)

    # In graph mode, force variables and resources to initialize.
    if not context.executing_eagerly():
        session = backend.get_session()  # Variables are initialized by this call.
        session.run(ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS))

    return revived
Example #4
0
"""Loads a Keras SavedModel via TF-internal loaders, recompiles, and predicts.

NOTE(review): everything here comes from `tensorflow.python.*`, which is
TF-internal API with no stability guarantees across releases.
"""
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model.load import KerasObjectLoader
from tensorflow.python.keras.saving.saved_model.load import RevivedModel
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model.load import load_internal

model_path = 'output/saved_model/cls/1599723701'

# Validate the directory actually holds a parseable SavedModel before the
# heavier load step below.
loader_impl.parse_saved_model(model_path)
model = load_internal(model_path, tags=['serve'], loader_cls=KerasObjectLoader)

if not isinstance(model, RevivedModel):
    raise RuntimeError("Can not load model")

# pylint: disable=protected-access
# _training_config is the serialized compile configuration saved with the
# model; without it we cannot rebuild loss/optimizer.
if model._training_config is None:
    raise RuntimeError("Model _training_config is None")

model.compile(
    **saving_utils.compile_args_from_training_config(model._training_config))
# pylint: enable=protected-access

# NOTE(review): four empty feature lists — presumably a smoke test input;
# confirm against the model's real input signature.
test_data = [[], [], [], []]

model.predict(test_data)