def test_save_load_trackable(self, distribution, optimizer):
    """Round-trips model weights through the default (non-`.h5`) save path.

    Trains one model for a single step under `distribution`, saves its
    weights, loads them into a second freshly compiled model, and checks that
    the second model can still predict and train.

    Args:
      distribution: the distribution strategy under test.
      optimizer: zero-argument callable returning a fresh optimizer instance
        (one is created per model).
    """
    # Function-scope imports: only the temp-directory handling needs them.
    import os
    import shutil

    # TODO(b/123533246): Enable the test for TPU once bug is fixed
    if (isinstance(distribution,
                   (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1))
            and distribution.extended.steps_per_run > 1):
        self.skipTest('MultiStep TPU Strategy deadlocks with optimizer restore.')

    with self.cached_session():
        dataset = keras_test_lib.get_dataset(distribution)
        with distribution.scope():
            model = keras_test_lib.get_model()
            model.compile(optimizer(), 'mse')
            model.fit(dataset, epochs=1, steps_per_epoch=1)

            # tempfile.mktemp() is deprecated: it only picks a name (another
            # process can race for it) and nothing removes what gets written
            # there. Use a private directory that is deleted after the test.
            tmp_dir = tempfile.mkdtemp()
            self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
            weights_file = os.path.join(tmp_dir, 'weights')
            model.save_weights(weights_file)

            # A second, independently compiled model must accept the saved
            # weights and remain usable for both inference and training.
            model_2 = keras_test_lib.get_model()
            model_2.compile(optimizer(), 'mse')
            model_2.load_weights(weights_file)
            model_2.predict(
                keras_test_lib.get_predict_dataset(distribution), steps=2)
            model_2.fit(dataset, epochs=1, steps_per_epoch=1)
def test_callbacks_in_predict(self, distribution, experimental_run_tf_function):
    """Checks that `predict()` fires each predict callback hook as expected.

    Builds and compiles a model under `distribution`, runs 5 prediction
    steps with a `Counter` callback, and expects one begin/end pair for the
    whole call plus one begin/end pair per batch.

    Args:
      distribution: the distribution strategy under test.
      experimental_run_tf_function: forwarded unchanged to `compile()`.
    """
    with distribution.scope():
        model = keras_test_lib.get_model()
        model.compile(
            optimizer='sgd',
            loss='mse',
            metrics=['mae'],
            experimental_run_tf_function=experimental_run_tf_function)

    # Counter presumably records how often each callback hook runs (it exposes
    # `method_counts`); confirm against its definition in this file.
    counter = Counter()

    # get_predict_dataset() is given the strategy, as at every other call site
    # in this file. Previously a dataset from get_dataset() was passed here.
    model.predict(
        keras_test_lib.get_predict_dataset(distribution),
        steps=5,
        callbacks=[counter])

    self.assertDictEqual(
        counter.method_counts, {
            'on_predict_batch_begin': 5,
            'on_predict_batch_end': 5,
            'on_predict_begin': 1,
            'on_predict_end': 1
        })
def test_save_load_h5(self, distribution, optimizer, cloning):
    """Round-trips model weights through a `.h5`-suffixed weights file.

    Trains one model for a single step under `distribution`, saves its
    weights to a `.h5` path, loads them into a second freshly compiled model,
    and checks that the second model can still predict and train.

    Args:
      distribution: the distribution strategy under test.
      optimizer: zero-argument callable returning a fresh optimizer instance
        (one is created per model).
      cloning: forwarded unchanged to both `compile()` calls.
    """
    # Function-scope imports: only the temp-directory handling needs them.
    import os
    import shutil

    with self.cached_session():
        dataset = keras_test_lib.get_dataset(distribution)
        with distribution.scope():
            model = keras_test_lib.get_model()
            model.compile(optimizer(), 'mse', cloning=cloning)
            model.fit(dataset, epochs=1, steps_per_epoch=1)

            # tempfile.mktemp() is deprecated: it only picks a name (another
            # process can race for it) and nothing removes the saved file.
            # Use a private directory that is deleted after the test.
            tmp_dir = tempfile.mkdtemp()
            self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
            weights_file = os.path.join(tmp_dir, 'weights.h5')
            model.save_weights(weights_file)

            # A second, independently compiled model must accept the saved
            # weights and remain usable for both inference and training.
            model_2 = keras_test_lib.get_model()
            model_2.compile(optimizer(), 'mse', cloning=cloning)
            model_2.load_weights(weights_file)
            model_2.predict(
                keras_test_lib.get_predict_dataset(distribution), steps=2)
            model_2.fit(dataset, epochs=1, steps_per_epoch=1)
def test_save_load_h5(self, distribution):
    """Round-trips model weights through a `.h5`-suffixed weights file.

    Trains one model (RMSprop, learning rate 0.01) for a single step under
    `distribution`, saves its weights to a `.h5` path, loads them into a
    second freshly compiled model, and checks that the second model can still
    predict and train.

    Args:
      distribution: the distribution strategy under test.
    """
    # Function-scope imports: only the temp-directory handling needs them.
    import os
    import shutil

    with self.cached_session():
        dataset = keras_test_lib.get_dataset(distribution)
        with distribution.scope():
            model = keras_test_lib.get_model()
            model.compile(rms_prop_keras.RMSprop(learning_rate=0.01), 'mse')
            model.fit(dataset, epochs=1, steps_per_epoch=1)

            # tempfile.mktemp() is deprecated: it only picks a name (another
            # process can race for it) and nothing removes the saved file.
            # Use a private directory that is deleted after the test.
            tmp_dir = tempfile.mkdtemp()
            self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
            weights_file = os.path.join(tmp_dir, 'weights.h5')
            model.save_weights(weights_file)

            # A second model with its own optimizer instance must accept the
            # saved weights and remain usable for inference and training.
            model_2 = keras_test_lib.get_model()
            model_2.compile(rms_prop_keras.RMSprop(learning_rate=0.01), 'mse')
            model_2.load_weights(weights_file)
            model_2.predict(
                keras_test_lib.get_predict_dataset(distribution), steps=2)
            model_2.fit(dataset, epochs=1, steps_per_epoch=1)
def test_callbacks_in_predict(self, distribution):
    """Checks that `predict()` fires each predict callback hook as expected.

    Builds and compiles a model under `distribution`, runs 5 prediction
    steps with a `Counter` callback, and expects one begin/end pair for the
    whole call plus one begin/end pair per batch.

    Args:
      distribution: the distribution strategy under test.
    """
    with distribution.scope():
        model = keras_test_lib.get_model()
        model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

    # Counter presumably records how often each callback hook runs (it exposes
    # `method_counts`); confirm against its definition in this file.
    counter = Counter()

    # get_predict_dataset() is given the strategy, as at every other call site
    # in this file. Previously a dataset from get_dataset() was passed here.
    model.predict(
        keras_test_lib.get_predict_dataset(distribution),
        steps=5,
        callbacks=[counter])

    self.assertDictEqual(
        counter.method_counts, {
            'on_predict_batch_begin': 5,
            'on_predict_batch_end': 5,
            'on_predict_begin': 1,
            'on_predict_end': 1
        })
def test_save_load_trackable(self, distribution, optimizer, cloning):
    """Round-trips model weights through the default (non-`.h5`) save path.

    Trains one model for a single step under `distribution`, saves its
    weights, loads them into a second freshly compiled model, and checks that
    the second model can still predict and train.

    Args:
      distribution: the distribution strategy under test.
      optimizer: zero-argument callable returning a fresh optimizer instance
        (one is created per model).
      cloning: forwarded unchanged to both `compile()` calls.
    """
    # Function-scope imports: only the temp-directory handling needs them.
    import os
    import shutil

    # TODO(b/123533246): Enable the test for TPU once bug is fixed
    if (isinstance(distribution,
                   (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1))
            and distribution.extended.steps_per_run > 1):
        self.skipTest('MultiStep TPU Strategy deadlocks with optimizer restore.')

    with self.cached_session():
        dataset = keras_test_lib.get_dataset(distribution)
        with distribution.scope():
            model = keras_test_lib.get_model()
            model.compile(optimizer(), 'mse', cloning=cloning)
            model.fit(dataset, epochs=1, steps_per_epoch=1)

            # tempfile.mktemp() is deprecated: it only picks a name (another
            # process can race for it) and nothing removes what gets written
            # there. Use a private directory that is deleted after the test.
            tmp_dir = tempfile.mkdtemp()
            self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
            weights_file = os.path.join(tmp_dir, 'weights')
            model.save_weights(weights_file)

            # A second, independently compiled model must accept the saved
            # weights and remain usable for both inference and training.
            model_2 = keras_test_lib.get_model()
            model_2.compile(optimizer(), 'mse', cloning=cloning)
            model_2.load_weights(weights_file)
            model_2.predict(
                keras_test_lib.get_predict_dataset(distribution), steps=2)
            model_2.fit(dataset, epochs=1, steps_per_epoch=1)