def make_model():
    """Build and compile a minimal single-layer Keras model.

    The graph maps a 10-wide input through one ``Dense`` layer with 10 units.
    It is compiled with the ``'adam'`` optimizer and ``'mean_squared_error'``
    loss, and its train function is built before the model is returned.
    """
    inputs = Input(shape=(10,))
    outputs = Dense(units=10)(inputs)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam', loss='mean_squared_error')
    # Build the training function eagerly (private Keras API).
    model._make_train_function()  # pylint: disable=protected-access
    return model
def load_optimizer_weights(model: Model, load_path: pathlib.Path) -> None:
    """
    Load the optimizer states from a tf.keras model saved with
    tf.keras.models.save_model().

    Does nothing if the file has no ``optimizer_weights`` entry. Logs a
    warning and leaves the optimizer freshly initialized when the model is
    not a graph network (its train function cannot be built yet), or when
    ``model.optimizer.set_weights()`` rejects the saved values with a
    ValueError.

    This implementation is lifted from tf.keras.models.load_model().

    Args:
        model: compiled model whose ``optimizer`` receives the saved state.
        load_path: path to the HDF5 file written by ``save_model()``.
    """
    # The file is closed on every exit path. set_weights() runs inside the
    # `with` because the values returned by
    # load_optimizer_weights_from_hdf5_group() may still reference the open
    # file (NOTE(review): confirm whether they are h5py datasets or arrays).
    with h5py.File(str(load_path), mode="r") as f:
        if "optimizer_weights" not in f:
            return

        # Models that aren't graph networks must wait until they are called
        # with data to _make_train_function(), so they can't load optimizer
        # weights here.
        if not model._is_graph_network:  # pylint: disable=protected-access
            logging.warning("Sequential models without an `input_shape` "
                            "passed to the first layer cannot reload their "
                            "optimizer state. As a result, your model is "
                            "starting with a freshly initialized optimizer.")
            return

        # Build train function (to get weight updates).
        model._make_train_function()  # pylint: disable=protected-access
        optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
        try:
            model.optimizer.set_weights(optimizer_weight_values)
        except ValueError:
            logging.warning("Error in loading the saved optimizer "
                            "state. As a result, your model is "
                            "starting with a freshly initialized "
                            "optimizer.")
def load_optimizer_weights(
    model: Model, h5group: Any, optimizer: tf.keras.optimizers.Optimizer
) -> None:
    """
    Load the optimizer states from a tf.keras model saved with
    tf.keras.models.save_model().

    The optimizer's weights must exist before the saved values can be
    assigned:

    * On TensorFlow >= 2.2 they are created with
      ``optimizer._create_all_weights(model.trainable_variables)``.
    * On older TensorFlow the model's train function is built (to get weight
      updates). That is only possible for graph networks, so other models
      are skipped with a warning.

    A warning is logged, and the optimizer is left as-is, when the model
    cannot be handled on this TensorFlow version, when weight creation
    raises NotImplementedError or AttributeError, or when
    ``optimizer.set_weights()`` raises ValueError.

    This implementation is lifted from tf.keras.models.load_model().

    Args:
        model: model whose trainable variables the optimizer updates.
        h5group: HDF5 group passed to load_optimizer_weights_from_hdf5_group().
        optimizer: optimizer whose state is restored in place.
    """
    tf2_2_or_newer = version.parse(tf.__version__) >= version.parse("2.2.0")

    # Before TF 2.2, models that aren't graph networks must wait until they
    # are called with data to _make_train_function(), so they can't load
    # optimizer weights here.
    if not (model._is_graph_network or tf2_2_or_newer):  # pylint: disable=protected-access
        logging.warning(
            "Sequential models without an `input_shape` "
            "passed to the first layer cannot reload their "
            "optimizer state. As a result, your model is "
            "starting with a freshly initialized optimizer."
        )
        return

    if tf2_2_or_newer:
        try:
            optimizer._create_all_weights(  # pylint: disable=protected-access
                model.trainable_variables
            )
        except (NotImplementedError, AttributeError):
            logging.warning(
                "Error when creating the weights of optimizer, making it "
                "impossible to restore the saved optimizer state. As a result, "
                "your model is starting with a freshly initialized optimizer."
            )
            # Stop here so behavior matches the message above. Previously the
            # restore was still attempted, which could log a second warning
            # or contradict this one.
            return
    else:
        # Build train function (to get weight updates).
        model._make_train_function()  # pylint: disable=protected-access

    optimizer_weight_values = load_optimizer_weights_from_hdf5_group(h5group)
    try:
        optimizer.set_weights(optimizer_weight_values)
    except ValueError:
        logging.warning(
            "Error in loading the saved optimizer "
            "state. As a result, your model is "
            "starting with a freshly initialized "
            "optimizer."
        )