def test_qa_train_effectiveness():
    """Check that training lowers the loss reported by eval().

    The same CSV is used for training and for both evaluations, so a
    successful train() call should yield a strictly smaller loss afterwards.
    """
    data_path = "../data/qa/train-eval.csv"
    # Default-constructed model; the original intent is to start from a
    # non-fine-tuned model so that an improvement is guaranteed.
    model = HappyQuestionAnswering()

    loss_before = model.eval(data_path).loss
    model.train(data_path)
    loss_after = model.eval(data_path).loss

    assert loss_before > loss_after
def test_tc_with_dataclass():
    """Exercise train/eval/test of HappyQuestionAnswering with QA*Args dataclasses.

    NOTE(review): despite the ``tc`` prefix this test covers question
    answering; the name is left unchanged so existing test selections
    (``-k`` filters, CI configs) keep matching.
    """
    happy_qa = HappyQuestionAnswering()

    train_args = QATrainArgs(learning_rate=0.01, num_train_epochs=1)
    happy_qa.train("../data/qa/train-eval.csv", args=train_args)

    eval_args = QAEvalArgs()
    result_eval = happy_qa.eval("../data/qa/train-eval.csv", args=eval_args)
    assert isinstance(result_eval.loss, float)

    test_args = QATestArgs()
    result_test = happy_qa.test("../data/qa/test.csv", args=test_args)
    # Fail with a clear message rather than an IndexError if nothing came back.
    assert result_test, "test() returned no results"
    assert isinstance(result_test[0].answer, str)
def test_qa_with_dic():
    """Exercise train/eval/test of HappyQuestionAnswering with plain-dict args.

    Mirrors the dataclass-based test, but passes settings as dictionaries
    (empty dicts for eval and test).
    """
    happy_qa = HappyQuestionAnswering()

    train_args = {'learning_rate': 0.01, "num_train_epochs": 1}
    happy_qa.train("../data/qa/train-eval.csv", args=train_args)

    eval_args = {}
    result_eval = happy_qa.eval("../data/qa/train-eval.csv", args=eval_args)
    assert isinstance(result_eval.loss, float)

    test_args = {}
    result_test = happy_qa.test("../data/qa/test.csv", args=test_args)
    # Fail with a clear message rather than an IndexError if nothing came back.
    assert result_test, "test() returned no results"
    assert isinstance(result_test[0].answer, str)
def main():
    """Fine-tune BERT on a slice of SQuAD and print the eval loss before/after.

    Writes ``train.csv`` and ``eval.csv`` into the current working directory.
    Be careful not to commit these CSV files to the repo.
    """
    train_path, eval_path = "train.csv", "eval.csv"

    # Slice specs are passed straight to load_dataset; presumably the end
    # index is exclusive (499 / 99 rows) — confirm against datasets docs.
    squad_train = load_dataset('squad', split='train[0:499]')
    squad_eval = load_dataset('squad', split='validation[0:99]')
    for csv_path, dataset in ((train_path, squad_train), (eval_path, squad_eval)):
        generate_csv(csv_path, dataset)

    model = HappyQuestionAnswering(model_type="BERT", model_name="bert-base-uncased")
    baseline = model.eval(eval_path)
    model.train(train_path)
    tuned = model.eval(eval_path)

    for label, outcome in (("Before loss: ", baseline), ("After loss: ", tuned)):
        print(label, outcome.loss)
def test_qa_train():
    """Smoke test: training a DistilBERT SQuAD model completes without raising.

    No assertion is made on the return value of train(); the test passes as
    long as the call finishes.
    """
    happy_qa = HappyQuestionAnswering(
        model_type='DISTILBERT',
        model_name='distilbert-base-cased-distilled-squad',
    )
    happy_qa.train("../data/qa/train-eval.csv")