Esempio n. 1
0
def load_model(filepath, custom_objects=None, compile=True, options=None):  # pylint: disable=redefined-builtin
    """Load a Keras model previously written with `model.save()`.

    Two formats are recognised, checked in this order:

    1. HDF5 -- used when `h5py` is importable and `filepath` is either an
       open `h5py.File` or something `h5py.is_hdf5` reports as HDF5. Loading
       is delegated to `hdf5_format.load_model_from_hdf5`.
    2. SavedModel -- otherwise `filepath` is normalised with
       `path_to_string`; if that yields a string, the SavedModel is parsed
       and then loaded with `saved_model_load.load`.

    Both loaders run inside a shared-object loading scope, a custom-object
    scope built from `custom_objects`, and a load context built from
    `options`.

    Example:

    >>> model = tf.keras.Sequential([
    ...     tf.keras.layers.Dense(5, input_shape=(3,)),
    ...     tf.keras.layers.Softmax()])
    >>> model.save('/tmp/model')
    >>> loaded_model = tf.keras.models.load_model('/tmp/model')
    >>> x = tf.random.uniform((10, 3))
    >>> assert np.allclose(model.predict(x), loaded_model.predict(x))

    Scoped variable names (model/layer names such as `"dense_1/kernel:0"`)
    may differ after loading; look variables up through layer properties
    instead, e.g. `model.get_layer("dense_1").kernel`.

    Args:
        filepath: A string or `pathlib.Path` pointing at the saved model, or
            an open `h5py.File` to read the model from.
        custom_objects: Optional dict mapping names (strings) to custom
            classes or functions to consider during deserialization.
        compile: Boolean; whether to compile the model after loading.
        options: Optional `tf.saved_model.LoadOptions` that controls how a
            SavedModel is loaded.

    Returns:
        A Keras model instance, as produced by the selected loader. If the
        original model was compiled and saved with its optimizer, the
        returned model is compiled; otherwise it is left uncompiled, and a
        warning is shown when `compile` is `True`.

    Raises:
        IOError: If `filepath` is not recognised as HDF5 (this includes the
            case where `h5py` is unavailable) and does not normalise to a
            string path that can be loaded as a SavedModel. Errors raised
            by the underlying loaders propagate unchanged.
    """
    # A multi-manager `with` is equivalent to nesting: the scopes are entered
    # left to right and exited in reverse order.
    with generic_utils.SharedObjectLoadingScope(), \
            generic_utils.CustomObjectScope(custom_objects or {}), \
            load_context.load_context(options):
        # Short-circuit order matters: h5py may be None, and isinstance is
        # checked before probing the file on disk.
        is_hdf5_source = h5py is not None and (
            isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))
        if is_hdf5_source:
            return hdf5_format.load_model_from_hdf5(
                filepath, custom_objects, compile)

        path_str = path_to_string(filepath)
        if isinstance(path_str, six.string_types):
            loader_impl.parse_saved_model(path_str)
            return saved_model_load.load(path_str, compile, options)

    raise IOError(
        'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
        'available) or SavedModel.')
 def test_shared_object_loading_scope_returns_shared_obj(self):
     # An object registered under an id inside the scope must come back as
     # the very same instance (identity, not mere equality).
     shared = MaybeSharedObject()
     key = 1
     with generic_utils.SharedObjectLoadingScope() as loading_scope:
         loading_scope.set(key, shared)
         self.assertIs(loading_scope.get(key), shared)
Esempio n. 3
0
    def load(
        self,
        *,
        timestamp: Optional[Timestamp] = None,
        compile_model: bool = False,
        custom_objects: Optional[Mapping[str, Any]] = None,
        input_shape: Optional[Tuple[int, ...]] = None,
    ) -> tf.keras.Model:
        """
        Load a Tensorflow model from a TileDB array.

        The JSON model config is read from the array metadata key ``model_config``.
        ``Sequential`` and ``Functional`` models are rebuilt with ``from_config`` and
        their weights restored from the pickled ``model_weights`` attribute. Any other
        model class is treated as a custom subclassed model: it is rebuilt with
        ``model_from_config``, built with ``input_shape`` if needed, and its weights are
        restored by ``_load_custom_subclassed_model``.

        :param timestamp: Range of timestamps to load fragments of the array which live
            in the specified time range.
        :param compile_model: Whether to compile the model after loading or not. When
            True, the ``training_config`` metadata and the pickled ``optimizer_weights``
            attribute are read, the model is compiled, and the optimizer state is
            restored on a best-effort basis (failures are logged, not raised).
        :param custom_objects: Mapping of names to custom classes or functions to be
            considered during deserialization (for every model class) and compilation.
        :param input_shape: The shape that the custom model expects as input; passed to
            ``model.build`` when a subclassed model is not built after deserialization.
        :return: Tensorflow model.
        """
        # TODO: Change timestamp when issue in core is resolved

        with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
            model_array_results = model_array[:]
            # json.loads returns a dict here (it is subscripted right below), so no
            # bytes-decoding step is needed afterwards.
            model_config = json.loads(model_array.meta["model_config"])
            model_class = model_config["class_name"]

            if model_class in ("Sequential", "Functional"):
                cls = (
                    tf.keras.Sequential
                    if model_class == "Sequential"
                    else tf.keras.Model
                )
                # Forward custom_objects: without them, custom layers referenced by a
                # Sequential/Functional config cannot be resolved.
                model = cls.from_config(
                    model_config["config"], custom_objects=custom_objects
                )
                # NOTE(review): pickle.loads can execute arbitrary code; only load
                # arrays from trusted sources.
                model_weights = pickle.loads(
                    model_array_results["model_weights"].item(0)
                )
                model.set_weights(model_weights)
            else:
                with generic_utils.SharedObjectLoadingScope(), \
                        generic_utils.CustomObjectScope(custom_objects or {}):
                    model = model_config_lib.model_from_config(
                        model_config, custom_objects=custom_objects
                    )
                    if not model.built:
                        model.build(input_shape)

                    # Load weights for layers
                    self._load_custom_subclassed_model(model, model_array)

            if compile_model:
                # NOTE(review): same pickle trust caveat as above.
                optimizer_weights = pickle.loads(
                    model_array_results["optimizer_weights"].item(0)
                )
                training_config = json.loads(model_array.meta["training_config"])

                # Compile model.
                model.compile(
                    **saving_utils.compile_args_from_training_config(
                        training_config, custom_objects
                    )
                )
                saving_utils.try_build_compiled_arguments(model)

                # Set optimizer weights. This is best effort: on failure a warning is
                # logged and the model keeps a freshly initialized optimizer.
                if optimizer_weights:
                    try:
                        # Private Keras API; it may be missing or unimplemented for
                        # some optimizers, hence the except clause.
                        model.optimizer._create_all_weights(model.trainable_variables)
                    except (NotImplementedError, AttributeError):
                        logging.warning(
                            (
                                "Error when creating the weights of optimizer {}, "
                                "making it impossible to restore the saved optimizer "
                                "state. As a result, your model is starting with a "
                                "freshly initialized optimizer."
                            ).format(type(model.optimizer).__name__)
                        )

                    try:
                        model.optimizer.set_weights(optimizer_weights)
                    except ValueError:
                        logging.warning(
                            "Error in loading the saved optimizer "
                            "state. As a result, your model is "
                            "starting with a freshly initialized "
                            "optimizer."
                        )
            return model