Esempio n. 1
0
 def test_save_load_hdf5_pathlib(self):
     """Saves `self.model` to a `pathlib.Path` in HDF5 format and reloads it.

     Skipped on Python < 3.6. `pathlib` itself exists since Python 3.4, but
     the `os.PathLike` protocol (PEP 519) only arrived in 3.6; presumably that
     is what lets the save/load code accept a `Path` object here.
     """
     if sys.version_info < (3, 6):
         self.skipTest(
             'os.PathLike (PEP 519) support requires python version >= 3.6')
     path = pathlib.Path(self.get_temp_dir()) / 'model'
     save.save_model(self.model, path, save_format='h5')
     save.load_model(path)
Esempio n. 2
0
 def test_save_format_defaults_pathlib(self):
     """Saves `self.model` to a `pathlib.Path` without an explicit format.

     The result is checked with `assert_saved_model`. Skipped on Python < 3.6:
     `pathlib` exists since 3.4, but the `os.PathLike` protocol (PEP 519) only
     arrived in 3.6; presumably that is what the save code relies on.
     """
     if sys.version_info < (3, 6):
         self.skipTest(
             'os.PathLike (PEP 519) support requires python version >= 3.6')
     path = pathlib.Path(self.get_temp_dir()) / 'model_path'
     save.save_model(self.model, path)
     self.assert_saved_model(path)
Esempio n. 3
0
 def test_save_hdf5(self):
   """Saves `self.model` as HDF5, then checks a subclassed model is rejected."""
   h5_target = os.path.join(self.get_temp_dir(), 'model')
   save.save_model(self.model, h5_target, save_format='h5')
   self.assert_h5_format(h5_target)
   # Writing the subclassed model to the same HDF5 target is expected to
   # raise NotImplementedError with the message below.
   expected_msg = (
       'requires the model to be a Functional model or a Sequential model.')
   with self.assertRaisesRegex(NotImplementedError, expected_msg):
     save.save_model(self.subclassed_model, h5_target, save_format='h5')
Esempio n. 4
0
    def test_layer_saving_with_h5(self):
        """Checks that saving a model built on this layer to HDF5 raises."""
        vocab = ["earth", "wind", "and", "fire"]

        inputs = keras.Input(shape=(None,), dtype=tf.string)
        lookup = get_layer_class()(max_tokens=10)
        lookup.set_vocabulary(vocab)
        outputs = lookup(inputs)
        model = keras.Model(inputs=inputs, outputs=outputs)

        h5_target = os.path.join(self.get_temp_dir(), "model")
        # The HDF5 save is expected to fail with NotImplementedError whose
        # message starts with "Save or restore weights that is not".
        expected_msg = "Save or restore weights that is not.*"
        with self.assertRaisesRegex(NotImplementedError, expected_msg):
            save.save_model(model, h5_target, save_format="h5")
Esempio n. 5
0
 def test_save_tf(self):
   """Saves `self.model` and `self.subclassed_model` with save_format='tf'.

   Before `predict` has been run on the subclassed model, saving it is
   expected to fail because its input shapes are unset. After one
   `predict` call, the same save is expected to succeed.
   """
   export_dir = os.path.join(self.get_temp_dir(), 'model')

   def save_as_tf(model):
     save.save_model(model, export_dir, save_format='tf')

   save_as_tf(self.model)
   self.assert_saved_model(export_dir)
   with self.assertRaisesRegex(ValueError, 'input shapes have not been set'):
     save_as_tf(self.subclassed_model)
   self.subclassed_model.predict(np.random.random((3, 5)))
   save_as_tf(self.subclassed_model)
   self.assert_saved_model(export_dir)
Esempio n. 6
0
 def test_save_tf(self):
   """Saves `self.model` and `self.subclassed_model` with save_format='tf'.

   Before `predict` has been run on the subclassed model, saving it is
   expected to fail with a ValueError that points at `model.call()`. After
   one `predict` call, the same save is expected to succeed.
   """
   export_dir = os.path.join(self.get_temp_dir(), 'model')

   def save_as_tf(model):
     save.save_model(model, export_dir, save_format='tf')

   save_as_tf(self.model)
   self.assert_saved_model(export_dir)
   unbuilt_error = (
       r'Model.*cannot be saved.*as opposed to `model.call\(\).*')
   with self.assertRaisesRegex(ValueError, unbuilt_error):
     save_as_tf(self.subclassed_model)
   self.subclassed_model.predict(np.random.random((3, 5)))
   save_as_tf(self.subclassed_model)
   self.assert_saved_model(export_dir)
 def _save_model(self, model, saved_dir):
     """Write `model` to `saved_dir` using save_format='tf'."""
     target_format = 'tf'
     save.save_model(model, saved_dir, save_format=target_format)
Esempio n. 8
0
 def test_save_load_hdf5_pathlib(self):
   """Saves `self.model` to a `pathlib.Path` as HDF5 and loads it back."""
   h5_path = pathlib.Path(self.get_temp_dir(), 'model')
   save.save_model(self.model, h5_path, save_format='h5')
   save.load_model(h5_path)
Esempio n. 9
0
 def test_save_format_defaults_pathlib(self):
   """Saves to a `pathlib.Path` without `save_format`; checks via assert_saved_model."""
   target = pathlib.Path(self.get_temp_dir()).joinpath('model_path')
   save.save_model(self.model, target)
   self.assert_saved_model(target)
Esempio n. 10
0
 def test_save_format_defaults(self):
   """Saves to a string path without `save_format`; checks via assert_saved_model."""
   tmp_dir = self.get_temp_dir()
   target = os.path.join(tmp_dir, 'model_path')
   save.save_model(self.model, target)
   self.assert_saved_model(target)
Esempio n. 11
0
 def test_save_load_tf_string(self):
   """Saves `self.model` with save_format='tf' to a str path and reloads it."""
   tmp_dir = self.get_temp_dir()
   export_dir = os.path.join(tmp_dir, 'model')
   save.save_model(self.model, export_dir, save_format='tf')
   save.load_model(export_dir)