def test_saving_preserve_unbuilt_state(self):
    """A subclassed model that was never built stays unbuilt across save/load.

    The model is constructed and saved with `_save_new` without being called
    or fit. Both the original instance and the reloaded one must report
    `built == False`.
    """
    model_dir = os.path.join(self.get_temp_dir(), "my_model")
    original = CustomModelX()
    original._save_new(model_dir)
    restored = saving_lib.load(model_dir)
    for candidate in (original, restored):
        self.assertFalse(candidate.built)
def test_functional_model_with_tf_op_lambda_layer(self):
    """A functional model with a raw tensor op survives save/load.

    A `Dense(1)` output is added directly to the model input. The model is
    fit briefly, saved, and reloaded. The reloaded model's `summary()` text
    must match the original's.

    NOTE(review): another method with this exact name (taking an extra
    `layer` argument) appears later in this chunk. If both are in the same
    class, the later definition replaces this one and this test never
    runs. Confirm, then rename or remove one of them.
    """

    class _SummaryCapture:
        # Accumulates every line handed to `print_fn` into `contents`.
        def __init__(self):
            self.contents = ""

        def __call__(self, msg):
            self.contents += msg + "\n"

    model_dir = os.path.join(self.get_temp_dir(), "my_model")

    model_input = keras.layers.Input(shape=(32,))
    dense_out = keras.layers.Dense(1)(model_input)
    # Adding the Keras tensors directly is the "tf op" under test.
    functional_model = keras.Model(model_input, dense_out + model_input)

    before = _SummaryCapture()
    functional_model.summary(print_fn=before)

    functional_model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    features = np.random.random((1000, 32))
    targets = np.random.random((1000, 1))
    functional_model.fit(features, targets, epochs=3)

    functional_model._save_new(model_dir)
    loaded_model = saving_lib.load(model_dir)

    after = _SummaryCapture()
    loaded_model.summary(print_fn=after)
    self.assertEqual(before.contents, after.contents)
def test_saving_after_compile_but_before_fit(self):
    """Save a compiled, never-fit subclassed model and reload it.

    Checks that the optimizer, the compiled loss container, and the built-in
    losses come back as the same classes/functions. It also checks that the
    custom loss is resolved through the custom-object registry at load time.
    A different function is registered under the same key between save and
    load, and only the loaded model should pick it up.
    """
    temp_dir = os.path.join(self.get_temp_dir(), "my_model")
    # Per the test name, this helper is expected to return a compiled model.
    subclassed_model = self._get_subclassed_model()
    subclassed_model._save_new(temp_dir)

    # Unregister the module-level function so that another function can be
    # registered under the same custom-object key. The test then checks that
    # the newly registered function is used while loading. The registry is
    # process-global, so the original entry is put back once the test ends;
    # otherwise the function-local replacement would leak into other tests.
    custom_object_key = "my_custom_package>my_mean_squared_error"
    original_custom_fn = generic_utils._GLOBAL_CUSTOM_OBJECTS.pop(
        custom_object_key
    )
    self.addCleanup(
        generic_utils._GLOBAL_CUSTOM_OBJECTS.__setitem__,
        custom_object_key,
        original_custom_fn,
    )

    @keras.utils.generic_utils.register_keras_serializable(
        package="my_custom_package"
    )
    def my_mean_squared_error(y_true, y_pred):
        """Function-local `mean_squared_error`."""
        return backend.mean(
            tf.math.squared_difference(y_pred, y_true), axis=-1
        )

    loaded_model = saving_lib.load(temp_dir)

    # Everything should be the same class or function for the original model
    # and the loaded model.
    for model in [subclassed_model, loaded_model]:
        self.assertIs(
            model.optimizer.__class__,
            keras.optimizers.optimizer_v2.adam.Adam,
        )
        self.assertIs(
            model.compiled_loss.__class__,
            keras.engine.compile_utils.LossesContainer,
        )
        self.assertEqual(model.compiled_loss._losses[0], "mse")
        self.assertIs(
            model.compiled_loss._losses[1], keras.losses.mean_squared_error
        )
        self.assertIs(
            model.compiled_loss._losses[2].__class__,
            keras.losses.MeanSquaredError,
        )
        self.assertIs(
            model.compiled_loss._total_loss_mean.__class__,
            keras.metrics.base_metric.Mean,
        )

    # The custom function is the exception. The original model keeps the
    # module-level function, while the loaded model must use the newly
    # registered one.
    self.assertIs(
        subclassed_model.compiled_loss._losses[3],
        module_my_mean_squared_error,
    )
    self.assertIs(
        loaded_model.compiled_loss._losses[3], my_mean_squared_error
    )
    self.assertIsNot(module_my_mean_squared_error, my_mean_squared_error)
def test_saving_after_fit(self):
    """Save a fitted subclassed model, reload it, and keep training.

    Checks that the reloaded model uses the custom `train_step` and is an
    instance of the custom classes. It also checks that the custom methods
    are available on those classes. Finally, it checks that compile state
    (optimizer, loss container, per-loss classes, total-loss metric) matches
    between the original and the reloaded model.
    """
    temp_dir = os.path.join(self.get_temp_dir(), "my_model")
    subclassed_model = self._get_subclassed_model()

    x = np.random.random((100, 32))
    y = np.random.random((100, 1))
    subclassed_model.fit(x, y, epochs=1)
    subclassed_model._save_new(temp_dir)
    loaded_model = saving_lib.load(temp_dir)

    # Interactive logging is process-global state. If it was off before this
    # test, switch it back off afterwards so the change does not leak.
    if not io_utils.is_interactive_logging_enabled():
        self.addCleanup(io_utils.disable_interactive_logging)
    io_utils.enable_interactive_logging()

    # `tf.print` writes to stderr. This is to make sure the custom training
    # step is used.
    with self.captureWritesToStream(sys.stderr) as printed:
        loaded_model.fit(x, y, epochs=1)
        self.assertRegex(printed.contents(), train_step_message)

    # Check that the custom classes do get used.
    self.assertIsInstance(loaded_model, CustomModelX)
    self.assertIsInstance(loaded_model.dense1, MyDense)
    # Check that the custom method is available.
    self.assertEqual(loaded_model.one(), 1)
    self.assertEqual(loaded_model.dense1.two(), 2)

    # Expected class of each entry in `compiled_loss._losses`, by position.
    expected_loss_classes = (
        keras.losses.LossFunctionWrapper,
        keras.losses.LossFunctionWrapper,
        keras.losses.MeanSquaredError,
        keras.losses.LossFunctionWrapper,
    )

    # Everything should be the same class or function for the original model
    # and the loaded model.
    for model in [subclassed_model, loaded_model]:
        self.assertIs(
            model.optimizer.__class__,
            keras.optimizers.optimizer_v2.adam.Adam,
        )
        self.assertIs(
            model.compiled_loss.__class__,
            keras.engine.compile_utils.LossesContainer,
        )
        for index, expected_cls in enumerate(expected_loss_classes):
            self.assertIs(
                model.compiled_loss._losses[index].__class__, expected_cls
            )
        self.assertIs(
            model.compiled_loss._total_loss_mean.__class__,
            keras.metrics.base_metric.Mean,
        )
def test_saving_preserve_built_state(self):
    """A fitted subclassed model stays built, with the same input shape.

    After one epoch of training, the model is saved and reloaded. Both
    copies must report `built == True`, and both must share a
    `_build_input_shape` equal to `TensorShape([None, 32])`.
    """
    model_dir = os.path.join(self.get_temp_dir(), "my_model")
    original = self._get_subclassed_model()

    features = np.random.random((100, 32))
    labels = np.random.random((100, 1))
    original.fit(features, labels, epochs=1)

    original._save_new(model_dir)
    restored = saving_lib.load(model_dir)

    for candidate in (original, restored):
        self.assertTrue(candidate.built)
    self.assertEqual(
        original._build_input_shape, restored._build_input_shape
    )
    # 32 input features per sample; the batch dimension stays unknown.
    self.assertEqual(
        tf.TensorShape([None, 32]), restored._build_input_shape
    )
def test_functional_model_with_tf_op_lambda_layer(self, layer):
    """A functional model with a lambda-style layer round-trips through save/load.

    The model is fit, saved, reloaded, and fit again. The test then checks
    that the reloaded model's `summary()` text matches the original's.

    Args:
      layer: Which lambda-style layer to append after a `Dense(1)`:
        `"lambda"` wraps a Python function in `keras.layers.Lambda`;
        `"tf_op_lambda"` adds the Keras tensors directly (`outputs + inputs`).

    Raises:
      ValueError: if `layer` is not one of the values above.
    """

    class ToString:
        # Accumulates every line handed to `print_fn` into `contents`.
        def __init__(self):
            self.contents = ""

        def __call__(self, msg):
            self.contents += msg + "\n"

    temp_dir = os.path.join(self.get_temp_dir(), "my_model")

    # Both variants share the same Input -> Dense(1) trunk.
    inputs = keras.layers.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    if layer == "lambda":
        func = tf.function(lambda x: tf.math.cos(x) + tf.math.sin(x))
        # Wrap the underlying Python callable, not the `tf.function` object.
        outputs = keras.layers.Lambda(func._python_function)(outputs)
    elif layer == "tf_op_lambda":
        outputs = outputs + inputs
    else:
        # Fail loudly instead of crashing later on an unbuilt model.
        raise ValueError(f"Unsupported `layer` value: {layer!r}")

    functional_model = keras.Model(inputs, outputs)
    functional_to_string = ToString()
    functional_model.summary(print_fn=functional_to_string)

    functional_model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    x = np.random.random((1000, 32))
    y = np.random.random((1000, 1))
    functional_model.fit(x, y, epochs=3)
    functional_model._save_new(temp_dir)

    loaded_model = saving_lib.load(temp_dir)
    # The reloaded model must also be trainable.
    loaded_model.fit(x, y, epochs=3)
    loaded_to_string = ToString()
    loaded_model.summary(print_fn=loaded_to_string)

    # Confirming the original and saved/loaded model have same structure.
    self.assertEqual(
        functional_to_string.contents, loaded_to_string.contents
    )