def test_predict_unlabeled_hybrid(datasets):
    """End-to-end check: unlabeled-CSV predictions match labeled-eval predictions.

    Trains a hybrid ``MatchingModel`` for one epoch, scores the labeled test
    split via ``run_eval`` and the same rows loaded as unlabeled data via
    ``run_prediction``, and asserts both produce the same set of match scores.

    Args:
        datasets: pytest fixture exposing ``train``/``valid``/``test`` splits
            (project-defined — exact type not visible from here).
    """
    model_save_path = "hybrid_model.pth"
    try:
        model = MatchingModel(attr_summarizer="hybrid")
        # One epoch keeps the test fast; pos_neg_ratio mirrors the other
        # training tests in this suite.
        model.run_train(
            datasets.train,
            datasets.valid,
            epochs=1,
            batch_size=8,
            best_save_path=model_save_path,
            pos_neg_ratio=3,
        )
        unlabeled = process_unlabeled(
            path=os.path.join(test_dir_path, "test_datasets", "test_unlabeled.csv"),
            trained_model=model,
            ignore_columns=("left_id", "right_id"),
        )
        pred_test = model.run_eval(datasets.test, return_predictions=True)
        pred_unlabeled = model.run_prediction(unlabeled)
        # Compare sorted score lists: the two code paths may emit rows in
        # different orders, so only the multiset of scores must agree.
        assert sorted(tup[1] for tup in pred_test) == sorted(
            pred_unlabeled["match_score"]
        )
    finally:
        # Run cleanup even when the assertion above fails, so a failing run
        # never leaves a stale checkpoint file behind (the original only
        # removed it on success).
        if os.path.exists(model_save_path):
            os.remove(model_save_path)
def test_hybrid(self):
    """End-to-end check: unlabeled-CSV predictions match labeled-eval predictions.

    unittest twin of the pytest variant: trains a hybrid ``MatchingModel``
    for one epoch, scores ``self.test`` via ``run_eval`` and the same rows
    loaded as unlabeled data via ``run_prediction``, and asserts both yield
    the same set of match scores.
    """
    model_save_path = 'hybrid_model.pth'
    try:
        model = MatchingModel(attr_summarizer='hybrid')
        # One epoch keeps the test fast; pos_neg_ratio mirrors the other
        # training tests in this suite.
        model.run_train(self.train, self.valid, epochs=1, batch_size=8,
                        best_save_path=model_save_path, pos_neg_ratio=3)
        unlabeled = process_unlabeled(
            path=os.path.join(test_dir_path, 'test_datasets',
                              'test_unlabeled.csv'),
            trained_model=model,
            ignore_columns=('left_id', 'right_id'))
        pred_test = model.run_eval(self.test, return_predictions=True)
        pred_unlabeled = model.run_prediction(unlabeled)
        # Compare sorted score lists: the two code paths may emit rows in
        # different orders, so only the multiset of scores must agree.
        self.assertEqual(sorted(tup[1] for tup in pred_test),
                         sorted(pred_unlabeled['match_score']))
    finally:
        # Run cleanup even when the assertion above fails, so a failing run
        # never leaves a stale checkpoint file behind (the original only
        # removed it on success).
        if os.path.exists(model_save_path):
            os.remove(model_save_path)