Пример #1
0
 def tearDown(self):
     """Clean up per-test artefacts and reset the module under test.

     Deletes every file in the model directory except ``.hdf5`` files
     (presumably fixtures shared across tests — confirm), removes the
     temporary resource directory tree, then calls ``Undertest.reset()``.
     """
     for name in os.listdir(self.__model_dir):
         # Keep the .hdf5 files; anything else in the directory is removed.
         if not name.endswith(".hdf5"):
             os.remove(os.path.join(self.__model_dir, name))
     shutil.rmtree(self.__resource_tmp)
     Undertest.reset()
Пример #2
0
 def test_load_model_and_weights(self):
     """Weights loaded from model.hdf5/weights.hdf5 differ from a fresh network's.

     ``get_weights()`` yields a list of arrays, one per weight tensor, whose
     shapes differ. Passing two such ragged lists to ``np.array_equal`` is
     unreliable: when numpy cannot build a single array it returns False,
     which makes this assertion pass vacuously. Compare tensor by tensor.
     """
     network_old = Undertest.get_network((2, 20), self.__hyperparameters)
     weights_old = network_old._Network__model.get_weights()
     network_new = Undertest.load_model_and_weights(
         "{}/model.hdf5".format(self.__model_dir),
         "{}/weights.hdf5".format(self.__weights_dir),
         self.__hyperparameters)
     weights_new = network_new._Network__model.get_weights()
     # A different tensor count already means the weights are not identical.
     identical = len(weights_old) == len(weights_new) and all(
         np.array_equal(old, new)
         for old, new in zip(weights_old, weights_new))
     self.assertFalse(identical)
Пример #3
0
 def test_exception_on_creating_an_unknown_network(self):
     """An unsupported ``network_type`` makes ``get_network`` raise ValueError."""
     self.__hyperparameters.network_type = "unknown"
     with self.assertRaises(ValueError) as context:
         Undertest.get_network((2, 20), self.__hyperparameters)
     self.assertIn("Unknown network type", str(context.exception))
Пример #4
0
 def test_save_model_and_weights(self):
     """After save_model_and_weights, model.hdf5 still loads with input shape (2, 20).

     Calls ``save_model_and_weights`` with the model file, the weights file
     and a combined output path in the temp directory. It then reloads the
     model file via ``get_from_model`` and checks its input shape.
     """
     model_path = "{}/{}".format(self.__model_dir, "model.hdf5")
     weights_path = "{}/{}".format(self.__weights_dir, "weights.hdf5")
     combined_path = "{}/{}".format(self.__resource_tmp, "model_combined.hdf5")
     Undertest.save_model_and_weights(model_path, weights_path, combined_path)
     reloaded = Undertest.get_from_model(model_path, self.__hyperparameters)
     self.assertEqual((2, 20), reloaded.input_shape)
Пример #5
0
 def test_exception_on_resume_with_no_previous_training_log(self):
     """Resuming against a training log that was never written raises AssertionError.

     The first (non-resumed) fit is given ``training_1.log``. The resumed
     fit is given ``training_2.log``, which presumably nothing has created,
     so resumption is expected to be rejected.
     """
     self.__hyperparameters.epochs = 2
     network = Undertest.get_network((2, 20), self.__hyperparameters)
     model_filepath = "{}/model.hdf5".format(self.__resource_tmp)
     weights_filepath = "{}/weights.hdf5".format(self.__resource_tmp)
     with open(self.__train_data, "rb") as file:
         train_data = pickle.load(file)
     with open(self.__labels, "rb") as file:
         labels = pickle.load(file)
     _, _ = network.fit_and_get_history(
         train_data,
         labels,
         model_filepath,
         weights_filepath,
         self.__resource_tmp,
         "{}/training_1.log".format(self.__resource_tmp),
         False,
     )
     # Only the call under test goes inside assertRaises, so unexpected
     # exception types surface as errors instead of being caught here.
     with self.assertRaises(AssertionError) as context:
         network.fit_and_get_history(
             train_data,
             labels,
             model_filepath,
             weights_filepath,
             self.__resource_tmp,
             "{}/training_2.log".format(self.__resource_tmp),
             True,
         )
     self.assertIn(
         "does not exist and is required by training resumption",
         str(context.exception))
Пример #6
0
 def test_exception_on_resume_with_no_extra_epochs(self):
     """Resuming without raising ``epochs`` above the trained count raises AssertionError.

     The network is fit for 2 epochs. Resumption is then requested with
     ``epochs`` still 2, and the exact error message is checked.
     """
     self.__hyperparameters.epochs = 2
     network = Undertest.get_network((2, 20), self.__hyperparameters)
     model_filepath = "{}/model.hdf5".format(self.__resource_tmp)
     weights_filepath = "{}/weights.hdf5".format(self.__resource_tmp)
     training_log = "{}/training.log".format(self.__resource_tmp)
     with open(self.__train_data, "rb") as file:
         train_data = pickle.load(file)
     with open(self.__labels, "rb") as file:
         labels = pickle.load(file)
     _, _ = network.fit_and_get_history(
         train_data,
         labels,
         model_filepath,
         weights_filepath,
         self.__resource_tmp,
         training_log,
         False,
     )
     with self.assertRaises(AssertionError) as context:
         network.fit_and_get_history(
             train_data,
             labels,
             model_filepath,
             weights_filepath,
             self.__resource_tmp,
             training_log,
             True,
         )
     self.assertEqual(
         "The existing model has been trained for 2 epochs. Make sure the total epochs are larger than 2",
         str(context.exception))
Пример #7
0
 def test_resume_and_get_history(self):
     """Resuming with a larger epoch budget returns list-typed validation histories.

     Fits for 2 epochs, raises ``epochs`` to 3, then calls
     ``fit_and_get_history`` again with the same paths and a final
     positional argument of ``True``.
     """
     params = self.__hyperparameters
     tmp_dir = self.__resource_tmp
     params.epochs = 2
     network = Undertest.get_network((2, 20), params)
     with open(self.__train_data, "rb") as data_fh:
         train_data = pickle.load(data_fh)
     with open(self.__labels, "rb") as labels_fh:
         labels = pickle.load(labels_fh)
     shared_args = (
         train_data,
         labels,
         "{}/model.hdf5".format(tmp_dir),
         "{}/weights.hdf5".format(tmp_dir),
         tmp_dir,
         "{}/training.log".format(tmp_dir),
     )
     _, _ = network.fit_and_get_history(*(shared_args + (False,)))
     params.epochs = 3
     val_loss, val_acc = network.fit_and_get_history(*(shared_args + (True,)))
     self.assertIs(type(val_loss), list)
     self.assertIs(type(val_acc), list)
Пример #8
0
 def test_get_predictions(self):
     """Predictions for the pickled training data have shape (11431, 1)."""
     model_path = "{}/model.hdf5".format(self.__model_dir)
     network = Undertest.get_from_model(model_path, self.__hyperparameters)
     with open(self.__train_data, "rb") as handle:
         samples = pickle.load(handle)
     weights_path = "{}/weights.hdf5".format(self.__weights_dir)
     predictions = network.get_predictions(samples, weights_path)
     self.assertEqual((11431, 1), predictions.shape)
Пример #9
0
 def test_simple_fit_with_generator(self):
     """simple_fit_with_generator returns one validation loss/accuracy per epoch.

     Reads ``train_data`` and ``labels`` datasets from the HDF5 training
     dump and checks that both returned histories are lists whose length
     equals the configured epoch count.
     """
     self.__hyperparameters.epochs = 3
     with h5py.File(self.__training_dump, "r") as hf:
         val_loss, val_acc = Undertest.simple_fit_with_generator(
             (2, 20),
             hf["train_data"],
             hf["labels"],
             self.__hyperparameters,
         )
         self.assertEqual(list, type(val_loss))
         self.assertEqual(list, type(val_acc))
         # assertEqual reports both lengths on failure, unlike assertTrue(a == b).
         self.assertEqual(self.__hyperparameters.epochs, len(val_loss))
         self.assertEqual(self.__hyperparameters.epochs, len(val_acc))
Пример #10
0
 def test_simple_fit(self):
     """simple_fit on the pickled fixtures returns list-typed validation histories."""
     def _load(path):
         # Fixtures are local, trusted pickle files.
         with open(path, "rb") as handle:
             return pickle.load(handle)

     train_data = _load(self.__train_data)
     labels = _load(self.__labels)
     val_loss, val_acc = Undertest.simple_fit(
         (2, 20), train_data, labels, self.__hyperparameters)
     for history in (val_loss, val_acc):
         self.assertEqual(list, type(history))
Пример #11
0
 def test_fit_with_generator(self):
     """fit_with_generator returns one validation loss/accuracy per epoch.

     NOTE(review): the log argument is the bare name "training.log", whereas
     other tests pass a path under the temp dir. Depending on how
     fit_with_generator resolves it, this may write into the CWD and escape
     tearDown cleanup. Confirm.
     """
     self.__hyperparameters.epochs = 3
     network = Undertest.get_network((2, 20), self.__hyperparameters)
     model_filepath = "{}/model.hdf5".format(self.__resource_tmp)
     weights_filepath = "{}/weights.hdf5".format(self.__resource_tmp)
     with h5py.File(self.__training_dump, "r") as hf:
         val_loss, val_acc = network.fit_with_generator(
             hf["train_data"],
             hf["labels"],
             model_filepath,
             weights_filepath,
             self.__resource_tmp,
             "training.log",
             False,
         )
         self.assertEqual(list, type(val_loss))
         self.assertEqual(list, type(val_acc))
         # assertEqual reports both lengths on failure, unlike assertTrue(a == b).
         self.assertEqual(self.__hyperparameters.epochs, len(val_loss))
         self.assertEqual(self.__hyperparameters.epochs, len(val_acc))
Пример #12
0
 def test_throw_exception_on_fit_with_generator(self, mock_fit):
     """A failing fit inside fit_with_generator surfaces as TerminalException.

     ``mock_fit`` is injected by a patch decorator outside this view;
     presumably it makes the wrapped fit call raise. Confirm against the
     decorator. The raised exception must mention "interrupted".
     """
     self.__hyperparameters.epochs = 3
     network = Undertest.get_network((2, 20), self.__hyperparameters)
     model_filepath = "{}/model.hdf5".format(self.__resource_tmp)
     weights_filepath = "{}/weights.hdf5".format(self.__resource_tmp)
     with h5py.File(self.__training_dump, "r") as hf:
         with self.assertRaises(TerminalException) as context:
             network.fit_with_generator(
                 hf["train_data"],
                 hf["labels"],
                 model_filepath,
                 weights_filepath,
                 self.__resource_tmp,
                 "training.log",
                 False,
             )
     self.assertTrue(mock_fit.called)
     self.assertIn("interrupted", str(context.exception))
Пример #13
0
 def test_throw_exception_on_fit_and_get_history(self, mock_fit):
     """A failing fit inside fit_and_get_history surfaces as TerminalException.

     ``mock_fit`` is injected by a patch decorator outside this view;
     presumably it makes the wrapped fit call raise. Confirm against the
     decorator. Only the call under test sits inside ``assertRaises``, so a
     failure while building the network or loading fixtures is reported as
     itself. Previously it was caught and misreported as "wrong exception".
     """
     network = Undertest.get_network((2, 20), self.__hyperparameters)
     model_filepath = "{}/model.hdf5".format(self.__resource_tmp)
     weights_filepath = "{}/weights.hdf5".format(self.__resource_tmp)
     with open(self.__train_data, "rb") as file:
         train_data = pickle.load(file)
     with open(self.__labels, "rb") as file:
         labels = pickle.load(file)
     with self.assertRaises(TerminalException) as context:
         network.fit_and_get_history(
             train_data,
             labels,
             model_filepath,
             weights_filepath,
             self.__resource_tmp,
             "training.log",
             False,
         )
     self.assertTrue(mock_fit.called)
     self.assertIn("interrupted", str(context.exception))
Пример #14
0
 def test_summary(self):
     """Reading ``network.summary`` (without calling it) evaluates to None.

     NOTE(review): the original author asked why this is None. The attribute
     is read, not called; a plain method would give a bound method here, not
     None. So ``summary`` is presumably a property that prints the model
     summary and returns None. Confirm in the Network class.
     """
     network = Undertest.get_network((2, 20), self.__hyperparameters)
     self.assertIsNone(network.summary)
Пример #15
0
 def test_create_conv_1d_network(self):
     """get_network builds a network whose n_type is "conv_1d" when requested."""
     expected_type = "conv_1d"
     self.__hyperparameters.network_type = expected_type
     created = Undertest.get_network((2, 20), self.__hyperparameters)
     self.assertEqual(expected_type, created.n_type)
Пример #16
0
 def test_create_bi_lstm_network(self):
     """get_network builds a network whose n_type is "bi_lstm" when requested."""
     expected_type = "bi_lstm"
     self.__hyperparameters.network_type = expected_type
     created = Undertest.get_network((2, 20), self.__hyperparameters)
     self.assertEqual(expected_type, created.n_type)
Пример #17
0
 def test_input_shape(self):
     """A network built for input shape (2, 20) reports that same shape."""
     shape = (2, 20)
     built = Undertest.get_network(shape, self.__hyperparameters)
     self.assertEqual(shape, built.input_shape)
Пример #18
0
 def test_get_from_model(self):
     """A network restored from model.hdf5 has input shape (2, 20) and n_type "unknown"."""
     restored = Undertest.get_from_model(
         "{}/{}".format(self.__model_dir, "model.hdf5"),
         self.__hyperparameters)
     self.assertEqual((2, 20), restored.input_shape)
     self.assertEqual("unknown", restored.n_type)
Пример #19
0
 def test_layers(self):
     """With the fixture hyperparameters, a (2, 20) network exposes 16 layers."""
     expected_layer_count = 16
     layers = Undertest.get_network((2, 20), self.__hyperparameters).layers
     self.assertEqual(expected_layer_count, len(layers))