Пример #1
0
 def make_model():
     """Return a compiled single-layer model: 10 input features -> Dense(10).

     The model is compiled with the 'adam' optimizer and mean-squared-error
     loss, and its training function is built eagerly through the private
     ``_make_train_function()`` Keras hook.
     """
     inputs = Input((10, ))
     model = Model(inputs, Dense(10)(inputs))
     model.compile('adam', 'mean_squared_error')
     # Private Keras API: materialize the train function right away.
     model._make_train_function()
     return model
Пример #2
0
def load_optimizer_weights(model: Model, load_path: pathlib.Path) -> None:
    """
    Load the optimizer states from a tf.keras model saved with
    tf.keras.models.save_model(). This implementation is lifted from
    tf.keras.models.load_model().

    Does nothing if the file has no "optimizer_weights" entry. Models that are
    not graph networks cannot build their train function here, so their
    optimizer state is skipped and a warning is logged instead. If the saved
    values cannot be assigned to the optimizer (``ValueError`` from
    ``set_weights``), a warning is logged and the optimizer keeps its freshly
    initialized state.

    Args:
        model: The model whose ``optimizer`` receives the saved state.
        load_path: Path to the HDF5 file written by ``save_model()``.
    """
    # The context manager guarantees the HDF5 handle is closed, even on error.
    # Everything, including set_weights(), stays inside the ``with``: the
    # values returned by the hdf5 helper may be lazy h5py datasets that are
    # only read when assigned.
    with h5py.File(str(load_path), mode="r") as f:
        if "optimizer_weights" not in f:
            return

        # Models that aren't graph networks must wait until they are called
        # with data to _make_train_function() and so can't load optimizer
        # weights.
        if not model._is_graph_network:  # pylint: disable=protected-access
            logging.warning("Sequential models without an `input_shape` "
                            "passed to the first layer cannot reload their "
                            "optimizer state. As a result, your model is "
                            "starting with a freshly initialized optimizer.")
            return

        # Build train function (to get weight updates).
        model._make_train_function()  # pylint: disable=protected-access
        optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
        try:
            model.optimizer.set_weights(optimizer_weight_values)
        except ValueError:
            logging.warning("Error in loading the saved optimizer "
                            "state. As a result, your model is "
                            "starting with a freshly initialized "
                            "optimizer.")
Пример #3
0
def load_optimizer_weights(
    model: Model, h5group: Any, optimizer: tf.keras.optimizers.Optimizer
) -> None:
    """
    Load the optimizer states from a tf.keras model saved with
    tf.keras.models.save_model(). This implementation is lifted from
    tf.keras.models.load_model().

    The optimizer's weight variables must exist before saved values can be
    assigned to them:

    * On TensorFlow >= 2.2 they are created with
      ``optimizer._create_all_weights(model.trainable_variables)``. If that
      raises ``NotImplementedError`` or ``AttributeError``, a warning is
      logged; loading the saved values is still attempted afterwards.
    * On older TensorFlow the model's train function is built instead. That is
      only possible for graph networks. For any other model a warning is
      logged and the saved state is skipped.

    If assigning the saved values raises ``ValueError``, a warning is logged
    and the optimizer keeps its freshly initialized state.

    Args:
        model: The model the optimizer's variables belong to.
        h5group: Source of the saved optimizer weights, passed to
            ``load_optimizer_weights_from_hdf5_group`` (presumably an h5py
            file/group -- confirm against callers).
        optimizer: The optimizer whose state is restored.
    """
    tf2_2_or_newer = version.parse(tf.__version__) >= version.parse("2.2.0")

    # Guard: before TF 2.2 only graph networks can build their train function.
    if not (model._is_graph_network or tf2_2_or_newer):  # pylint: disable=protected-access
        logging.warning(
            "Sequential models without an `input_shape` "
            "passed to the first layer cannot reload their "
            "optimizer state. As a result, your model is "
            "starting with a freshly initialized optimizer."
        )
        return

    if tf2_2_or_newer:
        try:
            optimizer._create_all_weights(model.trainable_variables)
        except (NotImplementedError, AttributeError):
            logging.warning(
                "Error when creating the weights of optimizer, making it "
                "impossible to restore the saved optimizer state. As a result, "
                "your model is starting with a freshly initialized optimizer."
            )
    else:
        # Build train function (to get weight updates).  Models that aren't
        # graph networks must wait until they are called with data to
        # _make_train_function() and so can't load optimizer weights.
        model._make_train_function()

    optimizer_weight_values = load_optimizer_weights_from_hdf5_group(h5group)
    try:
        optimizer.set_weights(optimizer_weight_values)
    except ValueError:
        logging.warning(
            "Error in loading the saved optimizer "
            "state. As a result, your model is "
            "starting with a freshly initialized "
            "optimizer."
        )