# Example no. 1
# 0
    def testSeqtoSeqWithLuongMulAttentionSaveAndLoad(self):
        """Train the toy NMT block, then verify a save/load round trip.

        The block defined in
        ``full_ml_tests/toy_nmt/toy_nmt_luong_mul_attention.xml`` is trained
        and its predictions are compared against ``tgt.txt``: each stripped
        target line with `` EOS`` appended must equal the space-joined
        prediction at the same index. The model is then saved under
        ``stored_models/toy_nmt/luong`` (removing any previous copy). It is
        loaded into a second, freshly configured interface, which must
        reproduce the same predictions.
        """
        filename = "full_ml_tests/toy_nmt/toy_nmt_luong_mul_attention.xml"
        handler = self.setup_holder.filepath_handler

        block_filepath = handler.get_test_block_path(filename)
        data_filepath = handler.get_test_data_path("nmt/toy/")
        embedding_filepath = handler.get_test_data_path("embeddings/")

        def build_interface():
            # Configure a fresh interface from the same block definition so
            # the trained and the reloaded runs share identical settings.
            iface = BasicInterface()
            iface.load_file(block_filepath)
            iface.set_variable("data_folder", data_filepath)
            iface.set_variable("embedding_folder", embedding_filepath)
            iface.initialize()
            return iface

        def assert_predictions_match(predictions):
            # Same number of sentences, and each prediction (a sequence of
            # tokens) joined with spaces equals the gold sentence.
            self.assertEqual(len(gold_sentences), len(predictions))
            for i, (gold, tokens) in enumerate(zip(gold_sentences, predictions)):
                self.assertEqual(gold, " ".join(tokens),
                                 msg="sentence index %d" % i)

        interface = build_interface()

        # `with` guarantees the file is closed even if reading fails; an
        # explicit encoding avoids depending on the platform default.
        with open(os.path.join(data_filepath, "tgt.txt"),
                  encoding="utf-8") as f:
            gold_sentences = [line.strip() + " EOS" for line in f]

        interface.train()
        assert_predictions_match(interface.predict())

        model_filepath = handler.get_test_data_path(
            "stored_models/toy_nmt/luong")
        # Remove a model directory left over from a previous run.
        if os.path.exists(model_filepath):
            shutil.rmtree(model_filepath)

        interface.save(model_filepath)

        interface_2 = build_interface()
        interface_2.load(model_filepath)

        # The reloaded model must reproduce the trained model's output.
        assert_predictions_match(interface_2.predict())

    def testIrisSaveAndLoad(self):
        """Train on Iris, persist the model, reload it, and re-evaluate.

        Both the freshly trained interface and the reloaded one must score
        strictly above 0.9 and at most 1.0 from ``evaluate()``.
        """
        interface = BasicInterface()

        handler = self.setup_holder.filepath_handler
        block_filepath = handler.get_test_block_path(
            "iris_tests/full_iris_no_shuffling.xml")
        # NOTE(review): the data folder is resolved with get_test_block_path
        # rather than get_test_data_path; confirm the Iris data really lives
        # alongside the block definitions.
        data_filepath = handler.get_test_block_path("iris_tests")

        # Clear any model left over from a previous run before saving anew.
        model_filepath = handler.get_test_data_path(
            "stored_models/iris/iris.model")
        if os.path.exists(model_filepath):
            shutil.rmtree(model_filepath)

        interface.load_file(block_filepath)
        interface.set_variable("data_folder", data_filepath)
        interface.initialize()

        interface.train()
        trained_score = interface.evaluate()

        self.assertLessEqual(trained_score, 1.0)
        self.assertGreater(trained_score, 0.9)

        interface.save(model_filepath)

        reloaded = BasicInterface()
        reloaded.load_file(block_filepath)
        reloaded.set_variable("data_folder", data_filepath)
        reloaded.initialize()
        reloaded.load(model_filepath)

        reloaded_score = reloaded.evaluate()

        self.assertLessEqual(reloaded_score, 1.0)
        self.assertGreater(reloaded_score, 0.9)