예제 #1
0
def test_train_save_load_hybrid_self_attention(datasets):
    """Check that a Hybrid/self-attention MatchingModel survives a save/load round trip.

    Trains a model for one epoch with ``best_save_path`` pointing at a local
    checkpoint file, evaluates it on ``datasets.test``, then loads that
    checkpoint into a freshly constructed model of the same architecture and
    asserts both evaluations are exactly equal.

    Args:
        datasets: object exposing ``train``, ``valid`` and ``test`` splits
            (presumably a pytest fixture defined elsewhere -- confirm).
    """
    model_save_path = "self_att_hybrid_model.pth"
    try:
        model = MatchingModel(attr_summarizer=attr_summarizers.Hybrid(
            word_contextualizer="self-attention"))

        model.run_train(
            datasets.train,
            datasets.valid,
            epochs=1,
            batch_size=8,
            best_save_path=model_save_path,
            pos_neg_ratio=3,
        )

        s1 = model.run_eval(datasets.test)

        # A fresh model with the same architecture must reproduce the
        # evaluation result once the saved weights are loaded into it.
        model2 = MatchingModel(attr_summarizer=attr_summarizers.Hybrid(
            word_contextualizer="self-attention"))
        model2.load_state(model_save_path)
        s2 = model2.run_eval(datasets.test)

        assert s1 == s2
    finally:
        # Remove the checkpoint even if training, loading or the assertion
        # fails, so a failed run does not leave a stale file behind.
        if os.path.exists(model_save_path):
            os.remove(model_save_path)
예제 #3
0
def test_train_save_load_rnn(datasets):
    """Check that an RNN-summarizer MatchingModel survives a save/load round trip.

    Trains a model for one epoch with ``best_save_path`` pointing at a local
    checkpoint file, evaluates it on ``datasets.test``, then loads that
    checkpoint into a freshly constructed ``attr_summarizer="rnn"`` model and
    asserts both evaluations are exactly equal.

    Args:
        datasets: object exposing ``train``, ``valid`` and ``test`` splits
            (presumably a pytest fixture defined elsewhere -- confirm).
    """
    model_save_path = "rnn_model.pth"
    try:
        model = MatchingModel(attr_summarizer="rnn")
        model.run_train(
            datasets.train,
            datasets.valid,
            epochs=1,
            batch_size=8,
            best_save_path=model_save_path,
            pos_neg_ratio=3,
        )
        s1 = model.run_eval(datasets.test)

        # A fresh model with the same architecture must reproduce the
        # evaluation result once the saved weights are loaded into it.
        model2 = MatchingModel(attr_summarizer="rnn")
        model2.load_state(model_save_path)
        s2 = model2.run_eval(datasets.test)

        assert s1 == s2
    finally:
        # Remove the checkpoint even if training, loading or the assertion
        # fails, so a failed run does not leave a stale file behind.
        if os.path.exists(model_save_path):
            os.remove(model_save_path)