Beispiel #1
0
def test_qa_train_effectiveness():
    """
    Check that HappyQuestionAnswering.train() lowers the evaluation loss.

    The same CSV is used for the "before" evaluation, for training, and for
    the "after" evaluation, so the loss reported by eval() after training is
    expected to be strictly lower than the loss reported before it.
    """
    data_path = "../data/qa/train-eval.csv"

    # NOTE(review): this relies on the default model not already being tuned
    # on this data, so that training yields a measurable improvement — confirm
    # which checkpoint HappyQuestionAnswering() loads by default.
    model = HappyQuestionAnswering()

    loss_before = model.eval(data_path).loss
    model.train(data_path)
    loss_after = model.eval(data_path).loss

    assert loss_before > loss_after


def main():
    """
    Build small SQuAD train/eval CSVs and report the eval loss of a
    bert-base-uncased question-answering model before and after training.

    Writes "train.csv" and "eval.csv" into the current working directory and
    prints the two loss values to stdout.
    """
    # The generated CSVs are scratch files: be careful not to commit them to the repo.
    train_csv_path = "train.csv"
    eval_csv_path = "eval.csv"

    # NOTE(review): split slices are end-exclusive, so these select 499 train
    # and 99 validation rows; use [0:500] / [0:100] if round numbers were meant.
    # Both datasets are loaded (train first) before either CSV is written.
    splits = (
        (train_csv_path, load_dataset('squad', split='train[0:499]')),
        (eval_csv_path, load_dataset('squad', split='validation[0:99]')),
    )
    for csv_path, dataset in splits:
        generate_csv(csv_path, dataset)

    model = HappyQuestionAnswering(model_type="BERT", model_name="bert-base-uncased")
    before_result = model.eval(eval_csv_path)
    model.train(train_csv_path)
    after_result = model.eval(eval_csv_path)

    print("Before loss: ", before_result.loss)
    print("After loss: ", after_result.loss)
Beispiel #3
0
def test_tc_with_dataclass():
    """
    Exercise train/eval/test of HappyQuestionAnswering using dataclass args.

    Trains for one epoch with QATrainArgs. Then checks that eval() with
    QAEvalArgs reports a float loss, and that test() with QATestArgs returns
    at least one result whose first entry has a string answer.
    """
    # NOTE(review): the "tc" in this test's name looks like a copy-paste from
    # the text-classification tests; it covers question answering. The name is
    # kept so that selecting the test by name keeps working.
    happy_qa = HappyQuestionAnswering()
    train_args = QATrainArgs(learning_rate=0.01, num_train_epochs=1)

    happy_qa.train("../data/qa/train-eval.csv", args=train_args)

    eval_args = QAEvalArgs()

    result_eval = happy_qa.eval("../data/qa/train-eval.csv", args=eval_args)
    # isinstance is the idiomatic type check (type(x) == T rejects subclasses).
    assert isinstance(result_eval.loss, float)

    test_args = QATestArgs()

    result_test = happy_qa.test("../data/qa/test.csv", args=test_args)
    # Fail with a clear message instead of an IndexError if nothing came back.
    assert result_test, "test() returned no results"
    assert isinstance(result_test[0].answer, str)
Beispiel #4
0
def test_qa_with_dic():
    """
    Exercise train/eval/test of HappyQuestionAnswering using plain-dict args.

    Trains for one epoch with a dict of training settings. Then checks that
    eval() with an empty dict reports a float loss, and that test() with an
    empty dict returns at least one result whose first entry has a string
    answer.
    """
    happy_qa = HappyQuestionAnswering()
    train_args = {'learning_rate': 0.01, "num_train_epochs": 1}

    happy_qa.train("../data/qa/train-eval.csv", args=train_args)

    # Empty dicts: use the library's default eval/test settings.
    eval_args = {}

    result_eval = happy_qa.eval("../data/qa/train-eval.csv", args=eval_args)
    # isinstance is the idiomatic type check (type(x) == T rejects subclasses).
    assert isinstance(result_eval.loss, float)

    test_args = {}

    result_test = happy_qa.test("../data/qa/test.csv", args=test_args)
    # Fail with a clear message instead of an IndexError if nothing came back.
    assert result_test, "test() returned no results"
    assert isinstance(result_test[0].answer, str)
Beispiel #5
0
def test_qa_eval():
    """
    Pin the eval loss of the distilbert-base-cased-distilled-squad checkpoint
    on the shared QA CSV, within a relative tolerance of 0.001.
    """
    expected_loss = 0.11738169193267822

    model = HappyQuestionAnswering(
        model_type='DISTILBERT',
        model_name='distilbert-base-cased-distilled-squad',
    )
    eval_result = model.eval("../data/qa/train-eval.csv")

    # approx's second positional parameter is `rel`; spelled out here for clarity.
    assert eval_result.loss == approx(expected_loss, rel=0.001)