def testSeqtoSeqWithLuongMulAttentionSaveAndLoad(self):
    """Train the toy NMT block and check predictions before and after reload.

    Loads ``toy_nmt_luong_mul_attention.xml`` into a ``BasicInterface``,
    trains it, and asserts that each prediction (tokens joined by single
    spaces) equals the corresponding stripped line of ``tgt.txt`` followed
    by " EOS". The model is then saved, loaded into a second freshly
    initialized interface, and the same comparison is repeated.
    """
    path_handler = self.setup_holder.filepath_handler
    block_filepath = path_handler.get_test_block_path(
        "full_ml_tests/toy_nmt/toy_nmt_luong_mul_attention.xml")
    data_filepath = path_handler.get_test_data_path("nmt/toy/")
    embedding_filepath = path_handler.get_test_data_path("embeddings/")

    def build_interface():
        # Both the trained and the reloaded interface need identical
        # configuration before train()/load() is called on them.
        new_interface = BasicInterface()
        new_interface.load_file(block_filepath)
        new_interface.set_variable("data_folder", data_filepath)
        new_interface.set_variable("embedding_folder", embedding_filepath)
        new_interface.initialize()
        return new_interface

    def assert_predictions_match(predictions):
        # Length is checked first so that zip() cannot silently truncate.
        self.assertEqual(len(gold_sentences), len(predictions))
        for gold, predicted_tokens in zip(gold_sentences, predictions):
            self.assertEqual(gold, " ".join(predicted_tokens))

    interface = build_interface()

    # The context manager closes the file even if reading raises;
    # os.path.join does not depend on data_filepath ending with a separator.
    with open(os.path.join(data_filepath, "tgt.txt")) as target_file:
        gold_sentences = [line.strip() + " EOS" for line in target_file]

    interface.train()
    assert_predictions_match(interface.predict())

    model_filepath = path_handler.get_test_data_path(
        "stored_models/toy_nmt/luong")
    # Remove any model left over from an earlier run before saving.
    if os.path.exists(model_filepath):
        shutil.rmtree(model_filepath)
    interface.save(model_filepath)

    interface_2 = build_interface()
    interface_2.load(model_filepath)
    assert_predictions_match(interface_2.predict())
def testIrisSaveAndLoad(self):
    """Train the iris block, save it, reload it, and re-check the score.

    The value returned by ``evaluate()`` must satisfy ``0.9 < score <= 1.0``
    both on the interface that was trained and on a second interface that
    only loaded the saved model.
    """
    trained = BasicInterface()

    path_handler = self.setup_holder.filepath_handler
    xml_path = path_handler.get_test_block_path(
        "iris_tests/full_iris_no_shuffling.xml")
    data_dir = path_handler.get_test_block_path("iris_tests")
    saved_model = path_handler.get_test_data_path(
        "stored_models/iris/iris.model")

    def configure(target):
        # Shared setup applied to both the trained and the reloaded interface.
        target.load_file(xml_path)
        target.set_variable("data_folder", data_dir)
        target.initialize()

    def check_score(target):
        score = target.evaluate()
        self.assertGreaterEqual(1.0, score)
        self.assertLess(0.9, score)

    # Clear whatever an earlier run stored at the model path.
    if os.path.exists(saved_model):
        shutil.rmtree(saved_model)

    configure(trained)
    trained.train()
    check_score(trained)
    trained.save(saved_model)

    reloaded = BasicInterface()
    configure(reloaded)
    reloaded.load(saved_model)
    check_score(reloaded)