def test_save_load_hdf5_pathlib(self):
    """Save to HDF5 and reload using a `pathlib.Path` destination.

    Only checks that saving and loading complete without raising; the
    object returned by `load_model` is not inspected.
    """
    # NOTE(review): pathlib itself predates 3.6; this guard presumably
    # targets `os.PathLike` support -- confirm before rewording the message.
    interpreter_too_old = sys.version_info < (3, 6)
    if interpreter_too_old:
        self.skipTest(
            'pathlib is only available for python version >= 3.6')

    model_file = pathlib.Path(self.get_temp_dir()).joinpath('model')
    save.save_model(self.model, model_file, save_format='h5')
    save.load_model(model_file)
def test_save_format_defaults_pathlib(self):
    """Save to a `pathlib.Path` without naming a format, then verify it.

    The result is checked with `assert_saved_model`, i.e. whatever format
    `save_model` picks by default must satisfy that helper.
    """
    # NOTE(review): pathlib itself predates 3.6; this guard presumably
    # targets `os.PathLike` support -- confirm before rewording the message.
    interpreter_too_old = sys.version_info < (3, 6)
    if interpreter_too_old:
        self.skipTest(
            'pathlib is only available for python version >= 3.6')

    export_dir = pathlib.Path(self.get_temp_dir()).joinpath('model_path')
    save.save_model(self.model, export_dir)
    self.assert_saved_model(export_dir)
def test_save_hdf5(self):
    """Check 'h5' saving works for `self.model` and is rejected otherwise.

    `self.model` must save and pass `assert_h5_format`; saving
    `self.subclassed_model` with the same format must raise
    `NotImplementedError` with the expected message.
    """
    tmp_root = self.get_temp_dir()
    h5_target = os.path.join(tmp_root, 'model')

    save.save_model(self.model, h5_target, save_format='h5')
    self.assert_h5_format(h5_target)

    expected_message = (
        'requires the model to be a Functional model or a Sequential model.')
    with self.assertRaisesRegex(NotImplementedError, expected_message):
        save.save_model(self.subclassed_model, h5_target, save_format='h5')
def test_layer_saving_with_h5(self):
    """Saving a model that wraps the vocabulary layer to 'h5' must fail.

    Builds a one-layer model around the class returned by
    `get_layer_class()` and expects `save_model(..., save_format="h5")`
    to raise `NotImplementedError`.
    """
    vocabulary = ["earth", "wind", "and", "fire"]

    # Keep the creation order (input first, then layer) unchanged.
    string_input = keras.Input(shape=(None, ), dtype=tf.string)
    layer_cls = get_layer_class()
    vocab_layer = layer_cls(max_tokens=10)
    vocab_layer.set_vocabulary(vocabulary)
    encoded = vocab_layer(string_input)
    wrapped_model = keras.Model(inputs=string_input, outputs=encoded)

    h5_target = os.path.join(self.get_temp_dir(), "model")
    expected_message = "Save or restore weights that is not.*"
    with self.assertRaisesRegex(NotImplementedError, expected_message):
        save.save_model(wrapped_model, h5_target, save_format="h5")
def test_save_tf(self):
    """Exercise 'tf' saving for `self.model` and `self.subclassed_model`.

    `self.model` saves directly. `self.subclassed_model` is expected to
    raise `ValueError` on the first attempt, and to save successfully once
    `predict` has been called on it with a (3, 5) random array.
    """
    export_dir = os.path.join(self.get_temp_dir(), 'model')

    save.save_model(self.model, export_dir, save_format='tf')
    self.assert_saved_model(export_dir)

    # First attempt, made before the model has been called on any data.
    unset_shapes_message = 'input shapes have not been set'
    with self.assertRaisesRegex(ValueError, unset_shapes_message):
        save.save_model(
            self.subclassed_model, export_dir, save_format='tf')

    # Call the model once, then retry the save.
    sample_batch = np.random.random((3, 5))
    self.subclassed_model.predict(sample_batch)
    save.save_model(self.subclassed_model, export_dir, save_format='tf')
    self.assert_saved_model(export_dir)
def test_save_tf(self):
    """Exercise 'tf' saving for `self.model` and `self.subclassed_model`.

    `self.model` saves directly. `self.subclassed_model` is expected to
    raise `ValueError` on the first attempt, and to save successfully once
    `predict` has been called on it with a (3, 5) random array.
    """
    tmp_root = self.get_temp_dir()
    export_dir = os.path.join(tmp_root, 'model')

    save.save_model(self.model, export_dir, save_format='tf')
    self.assert_saved_model(export_dir)

    # First attempt, made before the model has been called on any data.
    cannot_save_pattern = (
        r'Model.*cannot be saved.*as opposed to `model.call\(\).*')
    with self.assertRaisesRegex(ValueError, cannot_save_pattern):
        save.save_model(
            self.subclassed_model, export_dir, save_format='tf')

    # Call the model once, then retry the save.
    sample_batch = np.random.random((3, 5))
    self.subclassed_model.predict(sample_batch)
    save.save_model(self.subclassed_model, export_dir, save_format='tf')
    self.assert_saved_model(export_dir)
def _save_model(self, model, saved_dir):
    """Write `model` to `saved_dir`, always requesting the 'tf' format."""
    tf_format = 'tf'
    save.save_model(model, saved_dir, save_format=tf_format)
def test_save_load_hdf5_pathlib(self):
    """Save to HDF5 and reload using a `pathlib.Path` destination.

    Only checks that saving and loading complete without raising; the
    object returned by `load_model` is not inspected.
    """
    tmp_root = pathlib.Path(self.get_temp_dir())
    model_file = tmp_root.joinpath('model')

    save.save_model(self.model, model_file, save_format='h5')
    save.load_model(model_file)
def test_save_format_defaults_pathlib(self):
    """Save to a `pathlib.Path` without naming a format, then verify it.

    The result is checked with `assert_saved_model`, i.e. whatever format
    `save_model` picks by default must satisfy that helper.
    """
    tmp_root = pathlib.Path(self.get_temp_dir())
    export_dir = tmp_root.joinpath('model_path')

    save.save_model(self.model, export_dir)
    self.assert_saved_model(export_dir)
def test_save_format_defaults(self):
    """Save to a string path without naming a format, then verify it.

    The result is checked with `assert_saved_model`, i.e. whatever format
    `save_model` picks by default must satisfy that helper.
    """
    tmp_root = self.get_temp_dir()
    export_dir = os.path.join(tmp_root, 'model_path')

    save.save_model(self.model, export_dir)
    self.assert_saved_model(export_dir)
def test_save_load_tf_string(self):
    """Save in 'tf' format to a string path and load it back.

    Only checks that saving and loading complete without raising; the
    object returned by `load_model` is not inspected.
    """
    tmp_root = self.get_temp_dir()
    export_dir = os.path.join(tmp_root, 'model')

    save.save_model(self.model, export_dir, save_format='tf')
    save.load_model(export_dir)