예제 #1
0
def test_train_save_load_hybrid_self_attention(datasets):
    """Check a self-attention hybrid model evaluates the same after reload.

    Trains a ``MatchingModel`` built on ``attr_summarizers.Hybrid`` with the
    ``"self-attention"`` word contextualizer for one epoch, passing
    ``best_save_path`` so a checkpoint is written, then loads that checkpoint
    into a freshly constructed model and asserts both models return equal
    results from ``run_eval`` on the test split.

    Args:
        datasets: Fixture exposing ``train``, ``valid`` and ``test``
            attributes that are passed to ``run_train`` / ``run_eval``.
    """
    model_save_path = "self_att_hybrid_model.pth"
    try:
        model = MatchingModel(attr_summarizer=attr_summarizers.Hybrid(
            word_contextualizer="self-attention"))

        model.run_train(
            datasets.train,
            datasets.valid,
            epochs=1,
            batch_size=8,
            best_save_path=model_save_path,
            pos_neg_ratio=3,
        )

        s1 = model.run_eval(datasets.test)

        # A second, untrained model with the same architecture receives the
        # saved state so the two evaluation results can be compared.
        model2 = MatchingModel(attr_summarizer=attr_summarizers.Hybrid(
            word_contextualizer="self-attention"))
        model2.load_state(model_save_path)
        s2 = model2.run_eval(datasets.test)

        assert s1 == s2
    finally:
        # Clean up even when training, loading or the assertion fails, so a
        # stale checkpoint is never left in the working directory.
        if os.path.exists(model_save_path):
            os.remove(model_save_path)
예제 #2
0
    def test_process_unlabeled_1(self):
        """Check ``process_unlabeled`` yields the same text fields as ``process``.

        Builds ``FastText`` embeddings from the local sample archive, processes
        the train/validation/test CSVs, trains a ``'sif'`` model for one epoch,
        then runs ``process_unlabeled`` on the test CSV with the trained model
        and asserts its ``all_text_fields`` equals that of the processed test
        set.
        """
        vectors_cache_dir = '.cache'
        data_cache_path = os.path.join(test_dir_path, 'test_datasets',
                                       'cacheddata.pth')
        model_save_path = 'sif_model.pth'

        # Start from a clean slate: drop any vectors cache or cached dataset
        # already present on disk before processing.
        if os.path.exists(vectors_cache_dir):
            shutil.rmtree(vectors_cache_dir)
        if os.path.exists(data_cache_path):
            os.remove(data_cache_path)

        try:
            vec_dir = os.path.abspath(
                os.path.join(test_dir_path, 'test_datasets'))
            filename = 'fasttext_sample.vec.zip'
            # Point the embeddings loader at the local directory via a file:
            # URL instead of a remote download location.
            # NOTE(review): os.path.sep is '\\' on Windows, which may not form
            # a valid URL suffix there - confirm if Windows is supported.
            url_base = urljoin('file:', pathname2url(vec_dir)) + os.path.sep
            ft = FastText(filename, url_base=url_base, cache=vectors_cache_dir)

            train, valid, test = process(
                path=os.path.join(test_dir_path, 'test_datasets'),
                train='test_train.csv',
                validation='test_valid.csv',
                test='test_test.csv',
                id_attr='id',
                ignore_columns=('left_id', 'right_id'),
                embeddings=ft,
                embeddings_cache_path='',
                pca=True)

            model = MatchingModel(attr_summarizer='sif')
            model.run_train(train,
                            valid,
                            epochs=1,
                            batch_size=8,
                            best_save_path=model_save_path,
                            pos_neg_ratio=3)

            test_unlabeled = process_unlabeled(
                path=os.path.join(test_dir_path, 'test_datasets',
                                  'test_test.csv'),
                trained_model=model,
                ignore_columns=('left_id', 'right_id'))

            self.assertEqual(test_unlabeled.all_text_fields,
                             test.all_text_fields)
        finally:
            # Remove every artifact even if a step above raised, so a failed
            # run cannot leave files behind for later tests.
            if os.path.exists(model_save_path):
                os.remove(model_save_path)

            if os.path.exists(data_cache_path):
                os.remove(data_cache_path)

            if os.path.exists(vectors_cache_dir):
                shutil.rmtree(vectors_cache_dir)
예제 #3
0
def test_process_unlabeled_1():
    """Check ``process_unlabeled`` yields the same text fields as ``process``.

    Processes the train/validation/test CSVs with the ``embeddings`` object
    defined outside this function, trains a ``"sif"`` model for one epoch,
    then runs ``process_unlabeled`` on the test CSV with the trained model and
    asserts its ``all_text_fields`` equals that of the processed test set.
    """
    vectors_cache_dir = ".cache"
    data_cache_path = os.path.join(test_dir_path, "test_datasets",
                                   "cacheddata.pth")
    model_save_path = "sif_model.pth"

    # Start from a clean slate: drop any vectors cache or cached dataset
    # already present on disk before processing.
    if os.path.exists(vectors_cache_dir):
        shutil.rmtree(vectors_cache_dir)
    if os.path.exists(data_cache_path):
        os.remove(data_cache_path)

    try:
        train, valid, test = process(
            path=os.path.join(test_dir_path, "test_datasets"),
            train="test_train.csv",
            validation="test_valid.csv",
            test="test_test.csv",
            id_attr="id",
            ignore_columns=("left_id", "right_id"),
            embeddings=embeddings,
            embeddings_cache_path="",
            pca=True,
        )

        model = MatchingModel(attr_summarizer="sif")
        model.run_train(
            train,
            valid,
            epochs=1,
            batch_size=8,
            best_save_path=model_save_path,
            pos_neg_ratio=3,
        )

        test_unlabeled = process_unlabeled(
            path=os.path.join(test_dir_path, "test_datasets", "test_test.csv"),
            trained_model=model,
            ignore_columns=("left_id", "right_id"),
        )

        assert test_unlabeled.all_text_fields == test.all_text_fields
    finally:
        # Remove every artifact even if a step above raised, so a failed run
        # cannot leave files behind for later tests.
        if os.path.exists(model_save_path):
            os.remove(model_save_path)

        if os.path.exists(data_cache_path):
            os.remove(data_cache_path)

        if os.path.exists(vectors_cache_dir):
            shutil.rmtree(vectors_cache_dir)
    def test_rnn(self):
        """Check an ``'rnn'`` model evaluates the same after save and reload.

        Trains on ``self.train`` / ``self.valid`` for one epoch with
        ``best_save_path`` set, loads the saved state into a fresh ``'rnn'``
        model, and asserts both models return equal ``run_eval`` results on
        ``self.test``.
        """
        model_save_path = 'rnn_model.pth'
        try:
            model = MatchingModel(attr_summarizer='rnn')
            model.run_train(self.train,
                            self.valid,
                            epochs=1,
                            batch_size=8,
                            best_save_path=model_save_path,
                            pos_neg_ratio=3)
            s1 = model.run_eval(self.test)

            # Fresh model of the same kind, populated only from the checkpoint.
            model2 = MatchingModel(attr_summarizer='rnn')
            model2.load_state(model_save_path)
            s2 = model2.run_eval(self.test)

            self.assertEqual(s1, s2)
        finally:
            # Delete the checkpoint even when the test body fails.
            if os.path.exists(model_save_path):
                os.remove(model_save_path)
예제 #5
0
def test_train_save_load_rnn(datasets):
    """Check an ``"rnn"`` model evaluates the same after save and reload.

    Trains for one epoch with ``best_save_path`` set, loads the saved state
    into a fresh ``"rnn"`` model, and asserts both models return equal
    ``run_eval`` results on the test split.

    Args:
        datasets: Fixture exposing ``train``, ``valid`` and ``test``
            attributes that are passed to ``run_train`` / ``run_eval``.
    """
    model_save_path = "rnn_model.pth"
    try:
        model = MatchingModel(attr_summarizer="rnn")
        model.run_train(
            datasets.train,
            datasets.valid,
            epochs=1,
            batch_size=8,
            best_save_path=model_save_path,
            pos_neg_ratio=3,
        )
        s1 = model.run_eval(datasets.test)

        # Fresh model of the same kind, populated only from the checkpoint.
        model2 = MatchingModel(attr_summarizer="rnn")
        model2.load_state(model_save_path)
        s2 = model2.run_eval(datasets.test)

        assert s1 == s2
    finally:
        # Delete the checkpoint even when the test body fails.
        if os.path.exists(model_save_path):
            os.remove(model_save_path)
예제 #6
0
def test_predict_unlabeled_hybrid(datasets):
    """Check predictions on unlabeled data match evaluation predictions.

    Trains a ``"hybrid"`` model for one epoch, processes
    ``test_unlabeled.csv`` with ``process_unlabeled``, and asserts that the
    sorted second elements of the tuples returned by
    ``run_eval(..., return_predictions=True)`` equal the sorted
    ``"match_score"`` values returned by ``run_prediction``.

    Args:
        datasets: Fixture exposing ``train``, ``valid`` and ``test``
            attributes that are passed to ``run_train`` / ``run_eval``.
    """
    model_save_path = "hybrid_model.pth"
    try:
        model = MatchingModel(attr_summarizer="hybrid")
        model.run_train(
            datasets.train,
            datasets.valid,
            epochs=1,
            batch_size=8,
            best_save_path=model_save_path,
            pos_neg_ratio=3,
        )

        unlabeled = process_unlabeled(
            path=os.path.join(test_dir_path, "test_datasets",
                              "test_unlabeled.csv"),
            trained_model=model,
            ignore_columns=("left_id", "right_id"),
        )

        pred_test = model.run_eval(datasets.test, return_predictions=True)
        pred_unlabeled = model.run_prediction(unlabeled)

        # Both sides are sorted so the comparison ignores row order.
        # sorted() accepts any iterable, so no intermediate list is needed.
        assert sorted(tup[1] for tup in pred_test) == sorted(
            pred_unlabeled["match_score"])
    finally:
        # Delete the checkpoint even when the test body fails.
        if os.path.exists(model_save_path):
            os.remove(model_save_path)
    def test_hybrid_self_attention(self):
        """Check a self-attention hybrid model evaluates the same after reload.

        Trains a ``MatchingModel`` built on ``attr_summarizers.Hybrid`` with
        the ``'self-attention'`` word contextualizer for one epoch with
        ``best_save_path`` set, loads the saved state into a freshly
        constructed model, and asserts both models return equal ``run_eval``
        results on ``self.test``.
        """
        model_save_path = 'self_att_hybrid_model.pth'
        try:
            model = MatchingModel(attr_summarizer=attr_summarizers.Hybrid(
                word_contextualizer='self-attention'))

            model.run_train(self.train,
                            self.valid,
                            epochs=1,
                            batch_size=8,
                            best_save_path=model_save_path,
                            pos_neg_ratio=3)

            s1 = model.run_eval(self.test)

            # Fresh model with the same architecture, populated only from the
            # checkpoint.
            model2 = MatchingModel(attr_summarizer=attr_summarizers.Hybrid(
                word_contextualizer='self-attention'))
            model2.load_state(model_save_path)
            s2 = model2.run_eval(self.test)

            self.assertEqual(s1, s2)
        finally:
            # Delete the checkpoint even when the test body fails.
            if os.path.exists(model_save_path):
                os.remove(model_save_path)
    def test_hybrid(self):
        """Check predictions on unlabeled data match evaluation predictions.

        Trains a ``'hybrid'`` model for one epoch, processes
        ``test_unlabeled.csv`` with ``process_unlabeled``, and asserts that
        the sorted second elements of the tuples returned by
        ``run_eval(..., return_predictions=True)`` equal the sorted
        ``'match_score'`` values returned by ``run_prediction``.
        """
        model_save_path = 'hybrid_model.pth'
        try:
            model = MatchingModel(attr_summarizer='hybrid')
            model.run_train(self.train,
                            self.valid,
                            epochs=1,
                            batch_size=8,
                            best_save_path=model_save_path,
                            pos_neg_ratio=3)

            unlabeled = process_unlabeled(
                path=os.path.join(test_dir_path, 'test_datasets',
                                  'test_unlabeled.csv'),
                trained_model=model,
                ignore_columns=('left_id', 'right_id'))

            pred_test = model.run_eval(self.test, return_predictions=True)
            pred_unlabeled = model.run_prediction(unlabeled)

            # Both sides are sorted so the comparison ignores row order.
            # sorted() accepts any iterable, so no intermediate list is
            # needed.
            self.assertEqual(sorted(tup[1] for tup in pred_test),
                             sorted(pred_unlabeled['match_score']))
        finally:
            # Delete the checkpoint even when the test body fails.
            if os.path.exists(model_save_path):
                os.remove(model_save_path)