def test_process_unlabeled_1(self):
    """Check that ``process_unlabeled`` exposes the same text fields as ``process``.

    Trains a one-epoch ``sif`` model on the sample CSVs, runs
    ``process_unlabeled`` on ``test_test.csv`` with that model, and asserts
    that ``all_text_fields`` equals that of the test split returned by
    ``process``.

    The vector cache directory, the data cache file and the saved model file
    are removed in a ``finally`` clause, so a failing run does not leave
    them behind for later tests.
    """
    datasets_dir = os.path.join(test_dir_path, 'test_datasets')
    vectors_cache_dir = '.cache'
    data_cache_path = os.path.join(datasets_dir, 'cacheddata.pth')
    model_save_path = 'sif_model.pth'

    # Start from a clean slate in case an earlier run left caches around.
    if os.path.exists(vectors_cache_dir):
        shutil.rmtree(vectors_cache_dir)
    if os.path.exists(data_cache_path):
        os.remove(data_cache_path)

    try:
        # Point the embeddings loader at the local sample archive through a
        # file: URL built from the absolute test_datasets directory.
        vec_dir = os.path.abspath(datasets_dir)
        filename = 'fasttext_sample.vec.zip'
        url_base = urljoin('file:', pathname2url(vec_dir)) + os.path.sep
        ft = FastText(filename, url_base=url_base, cache=vectors_cache_dir)

        train, valid, test = process(
            path=datasets_dir,
            train='test_train.csv',
            validation='test_valid.csv',
            test='test_test.csv',
            id_attr='id',
            ignore_columns=('left_id', 'right_id'),
            embeddings=ft,
            embeddings_cache_path='',
            pca=True)

        model = MatchingModel(attr_summarizer='sif')
        model.run_train(
            train,
            valid,
            epochs=1,
            batch_size=8,
            best_save_path=model_save_path,
            pos_neg_ratio=3)

        test_unlabeled = process_unlabeled(
            path=os.path.join(datasets_dir, 'test_test.csv'),
            trained_model=model,
            ignore_columns=('left_id', 'right_id'))

        self.assertEqual(test_unlabeled.all_text_fields, test.all_text_fields)
    finally:
        # Always clean up, even when training or the assertion fails.
        if os.path.exists(model_save_path):
            os.remove(model_save_path)
        if os.path.exists(data_cache_path):
            os.remove(data_cache_path)
        if os.path.exists(vectors_cache_dir):
            shutil.rmtree(vectors_cache_dir)
def test_process_unlabeled_1():
    """Check that ``process_unlabeled`` exposes the same text fields as ``process``.

    Trains a one-epoch ``sif`` model on the sample CSVs using the ``embeddings``
    object defined outside this function, runs ``process_unlabeled`` on
    ``test_test.csv`` with that model, and asserts that ``all_text_fields``
    equals that of the test split returned by ``process``.

    The vector cache directory, the data cache file and the saved model file
    are removed in a ``finally`` clause, so a failing run does not leave
    them behind for later tests.
    """
    datasets_dir = os.path.join(test_dir_path, "test_datasets")
    vectors_cache_dir = ".cache"
    data_cache_path = os.path.join(datasets_dir, "cacheddata.pth")
    model_save_path = "sif_model.pth"

    # Start from a clean slate in case an earlier run left caches around.
    if os.path.exists(vectors_cache_dir):
        shutil.rmtree(vectors_cache_dir)
    if os.path.exists(data_cache_path):
        os.remove(data_cache_path)

    try:
        train, valid, test = process(
            path=datasets_dir,
            train="test_train.csv",
            validation="test_valid.csv",
            test="test_test.csv",
            id_attr="id",
            ignore_columns=("left_id", "right_id"),
            embeddings=embeddings,
            embeddings_cache_path="",
            pca=True,
        )

        model = MatchingModel(attr_summarizer="sif")
        model.run_train(
            train,
            valid,
            epochs=1,
            batch_size=8,
            best_save_path=model_save_path,
            pos_neg_ratio=3,
        )

        test_unlabeled = process_unlabeled(
            path=os.path.join(datasets_dir, "test_test.csv"),
            trained_model=model,
            ignore_columns=("left_id", "right_id"),
        )

        assert test_unlabeled.all_text_fields == test.all_text_fields
    finally:
        # Always clean up, even when training or the assertion fails.
        if os.path.exists(model_save_path):
            os.remove(model_save_path)
        if os.path.exists(data_cache_path):
            os.remove(data_cache_path)
        if os.path.exists(vectors_cache_dir):
            shutil.rmtree(vectors_cache_dir)
def test_hybrid(self):
    """Check that a ``hybrid`` model scores labeled and unlabeled data alike.

    Trains a one-epoch ``hybrid`` model on ``self.train`` / ``self.valid``,
    then compares the sorted second element of each tuple returned by
    ``run_eval`` on ``self.test`` with the sorted ``match_score`` column
    returned by ``run_prediction`` on ``test_unlabeled.csv``.

    The saved model file is removed in a ``finally`` clause so that a failing
    run does not leave it behind.
    """
    model_save_path = 'hybrid_model.pth'
    try:
        model = MatchingModel(attr_summarizer='hybrid')
        model.run_train(
            self.train,
            self.valid,
            epochs=1,
            batch_size=8,
            best_save_path=model_save_path,
            pos_neg_ratio=3)

        unlabeled = process_unlabeled(
            path=os.path.join(test_dir_path, 'test_datasets', 'test_unlabeled.csv'),
            trained_model=model,
            ignore_columns=('left_id', 'right_id'))

        pred_test = model.run_eval(self.test, return_predictions=True)
        pred_unlabeled = model.run_prediction(unlabeled)

        # Order is not compared: both score collections are sorted first.
        self.assertEqual(
            sorted(tup[1] for tup in pred_test),
            sorted(pred_unlabeled['match_score']))
    finally:
        # Always clean up, even when training or the assertion fails.
        if os.path.exists(model_save_path):
            os.remove(model_save_path)