Beispiel #1
0
def example_4_5():
    """Evaluate the default word-prediction model and print the outcome.

    Calls ``eval`` on the sample text file with
    ``WPEvalArgs(preprocessing_processes=2)`` and prints the result's type,
    the result itself and its ``loss`` attribute. The comments above each
    ``print`` show sample output.
    """
    predictor = HappyWordPrediction()
    eval_args = WPEvalArgs(preprocessing_processes=2)
    eval_result = predictor.eval("../../data/wp/train-eval.txt",
                                 args=eval_args)

    # <class 'happytransformer.happy_trainer.EvalResult'>
    print(type(eval_result))

    # EvalResult(eval_loss=0.459536075592041)
    # NOTE(review): the value is read via ``.loss`` below, so the field name
    # in this sample repr may be out of date -- confirm against the library.
    print(eval_result)

    # 0.459536075592041
    print(eval_result.loss)
Beispiel #2
0
def test_mwp_basic():
    """Check the first masked-word prediction for several models.

    Each case is ``(model_type, model_name, expected_token)``; the first
    result returned for the prompt must carry the expected token.
    """
    cases = [
        ('DISTILBERT', 'distilbert-base-uncased', 'pepper'),
        ('BERT', 'bert-base-uncased', '.'),
        ('ALBERT', 'albert-base-v2', 'garlic'),
    ]
    for family, checkpoint, expected_token in cases:
        predictor = HappyWordPrediction(family, checkpoint)
        predictions = predictor.predict_mask("Please pass the salt and [MASK]")
        assert predictions[0].token == expected_token
Beispiel #3
0
def example_4_2():
    """Request predictions with ``top_k=2`` from the default model.

    Prints the full result list, then its second entry and that entry's
    token. The comments above each ``print`` show sample output.
    """
    predictor = HappyWordPrediction()
    sentence = "To better the world I would invest in [MASK] and education."
    predictions = predictor.predict_mask(sentence, top_k=2)

    # [WordPredictionResult(token='health', score=0.1280556619167328), WordPredictionResult(token='science', score=0.07976455241441727)]
    print(predictions)

    # WordPredictionResult(token='science', score=0.07976455241441727)
    print(predictions[1])

    # science
    print(predictions[1].token)
Beispiel #4
0
def test_wp_eval_some_settings():
    """
    Check that eval works when only a subset of the potential settings is used.

    The args dict sets only 'line_by_line'; the loss on the returned result
    must be a float.
    """
    args = {
        'line_by_line': True,
    }
    happy_wp = HappyWordPrediction('', 'distilroberta-base')
    result = happy_wp.eval("../data/wp/train-eval.txt", args)
    # isinstance is the idiomatic type check; unlike an exact type() comparison
    # it also accepts float subclasses.
    assert isinstance(result.loss, float)
Beispiel #5
0
def test_mwp_top_k():
    """Check the result list returned by ``predict_mask`` with ``top_k=2``.

    The expected entries are built from (token, score) pairs; each score is
    wrapped in ``approx`` with 0.01 as its second argument.
    """
    predictor = HappyWordPrediction('DISTILBERT', 'distilbert-base-uncased')
    predictions = predictor.predict_mask("Please pass the salt and [MASK]",
                                         top_k=2)

    expected_pairs = (
        ('pepper', 0.2664579749107361),
        ('vinegar', 0.08760260790586472),
    )
    expected = [
        WordPredictionResult(token=token, score=approx(score, 0.01))
        for token, score in expected_pairs
    ]

    assert predictions == expected
Beispiel #6
0
def test_mwp_targets():
    """Check the result list returned by ``predict_mask`` with ``targets``.

    The prediction is requested for the targets "water" and "spices"; the
    expected entries wrap each score in ``approx`` with 0.01 as its second
    argument.
    """
    predictor = HappyWordPrediction('DISTILBERT', 'distilbert-base-uncased')
    candidate_words = ["water", "spices"]
    predictions = predictor.predict_mask("Please pass the salt and [MASK]",
                                         targets=candidate_words)

    expected_pairs = (
        ('water', 0.014856964349746704),
        ('spices', 0.009040987119078636),
    )
    expected = [
        WordPredictionResult(token=token, score=approx(score, 0.01))
        for token, score in expected_pairs
    ]
    assert predictions == expected
Beispiel #7
0
def example_1_2():
    """Request predictions with ``top_k=10`` from albert-xxlarge-v2.

    Prints the full result list, then its second entry and that entry's
    token. The comments above each ``print`` show sample output.
    """
    predictor = HappyWordPrediction("ALBERT", "albert-xxlarge-v2")
    sentence = "To better the world I would invest in [MASK] and education."
    predictions = predictor.predict_mask(sentence, top_k=10)

    # [WordPredictionResult(token='infrastructure', score=0.09270179271697998), WordPredictionResult(token='healthcare', score=0.07219093292951584)]
    # NOTE(review): the sample above lists two entries although top_k=10 is
    # requested -- presumably abbreviated; confirm.
    print(predictions)

    # WordPredictionResult(token='healthcare', score=0.07219093292951584)
    print(predictions[1])

    # healthcare
    print(predictions[1].token)
Beispiel #8
0
def test_wp_train_eval_with_dataclass():
    """Train and then evaluate using dataclass argument objects.

    Training uses ``WPTrainArgs`` and evaluation uses ``WPEvalArgs``, both
    with ``line_by_line=True``. The evaluation result is checked so that it
    is not merely computed and discarded.
    """
    happy_wp = HappyWordPrediction('', 'distilroberta-base')
    train_args = WPTrainArgs(learning_rate=0.01,
                             line_by_line=True,
                             num_train_epochs=1)

    happy_wp.train("../data/wp/train-eval.txt", args=train_args)

    eval_args = WPEvalArgs(line_by_line=True)

    after_result = happy_wp.eval("../data/wp/train-eval.txt", args=eval_args)
    # Same sanity check the plain eval tests apply to the returned loss.
    assert isinstance(after_result.loss, float)
Beispiel #9
0
def example_4_3():
    """Request predictions restricted to explicit ``targets``.

    Uses the default model with the targets "technology" and "healthcare",
    then prints the full result list, its second entry and that entry's
    token. The comments above each ``print`` show sample output.
    """
    predictor = HappyWordPrediction()
    candidate_words = ["technology", "healthcare"]
    sentence = "To better the world I would invest in [MASK] and education."
    predictions = predictor.predict_mask(sentence, targets=candidate_words)

    # [WordPredictionResult(token='healthcare', score=0.07380751520395279), WordPredictionResult(token='technology', score=0.009395276196300983)]
    print(predictions)

    # WordPredictionResult(token='technology', score=0.009395276196300983)
    print(predictions[1])

    # technology
    print(predictions[1].token)
Beispiel #10
0
def example_4_1():
    """Predict a masked word with the default model and inspect the result.

    Prints the type and value of the returned list, then the type, value,
    token and score of its first entry. The comments above each ``print``
    show sample output.
    """
    # default uses distilbert-base-uncased
    predictor = HappyWordPrediction()
    predictions = predictor.predict_mask("I think therefore I [MASK]")

    # <class 'list'>
    print(type(predictions))

    # [WordPredictionResult(token='am', score=0.10172799974679947)]
    print(predictions)

    # <class 'happytransformer.happy_word_prediction.WordPredictionResult'>
    print(type(predictions[0]))

    # WordPredictionResult(token='am', score=0.10172799974679947)
    print(predictions[0])

    # am
    print(predictions[0].token)

    # 0.10172799974679947
    print(predictions[0].score)
Beispiel #11
0
def test_wp_train_eval_with_dic():
    """Train and then evaluate using plain dict arguments.

    Both the training and the evaluation settings are passed as dicts with
    ``line_by_line`` enabled. The evaluation result is checked so that it is
    not merely computed and discarded.
    """
    happy_wp = HappyWordPrediction('', 'distilroberta-base')
    train_args = {
        'learning_rate': 0.01,
        'line_by_line': True,
        "num_train_epochs": 1
    }

    happy_wp.train("../data/wp/train-eval.txt", args=train_args)
    eval_args = {'line_by_line': True}

    after_result = happy_wp.eval("../data/wp/train-eval.txt", args=eval_args)
    # Same sanity check the plain eval tests apply to the returned loss.
    assert isinstance(after_result.loss, float)
Beispiel #12
0
def example_1_3():
    """Request predictions restricted to ``targets`` from albert-xxlarge-v2.

    Uses the targets "technology" and "healthcare", then prints the full
    result list, its second entry and that entry's token. The comments above
    each ``print`` show sample output.
    """
    predictor = HappyWordPrediction("ALBERT", "albert-xxlarge-v2")
    candidate_words = ["technology", "healthcare"]
    sentence = "To better the world I would invest in [MASK] and education."
    predictions = predictor.predict_mask(sentence, targets=candidate_words)

    # [WordPredictionResult(token='healthcare', score=0.07219093292951584), WordPredictionResult(token='technology', score=0.032044216990470886)]
    print(predictions)

    # WordPredictionResult(token='technology', score=0.032044216990470886)
    print(predictions[1])

    # technology
    print(predictions[1].token)
Beispiel #13
0
def test_wp_save_load_train():
    """Run ``run_save_load`` in "train" mode with line_by_line enabled.

    The arguments start from the shared ``ARGS_WP_TRAIN`` settings and add
    ``line_by_line=True`` on top.
    """
    happy_wp = HappyWordPrediction('', 'distilroberta-base')
    output_path = "data/wp-train.json"
    data_path = "../data/wp/train-eval.txt"
    # Work on a copy: assigning into ARGS_WP_TRAIN itself would leak
    # line_by_line=True into every other user of the shared constant.
    # Assumes ARGS_WP_TRAIN is a dict-like mapping -- TODO confirm.
    args = dict(ARGS_WP_TRAIN)
    args["line_by_line"] = True
    run_save_load(happy_wp, output_path, args, data_path, "train")
Beispiel #14
0
def test_wp_save():
    """Check that a saved model predicts the same first token once reloaded.

    The model is saved to "model/", queried, then a new instance is created
    from that path and queried with the same prompt.
    """
    prompt = "I think therefore I [MASK]"

    happy = HappyWordPrediction("BERT", "prajjwal1/bert-tiny")
    happy.save("model/")
    predictions_before = happy.predict_mask(prompt)

    # Rebind the same name so the original instance is released, as before.
    happy = HappyWordPrediction(load_path="model/")
    predictions_after = happy.predict_mask(prompt)

    assert predictions_before[0].token == predictions_after[0].token
Beispiel #15
0
def test_wp_train_effectiveness_multi():
    """Check that training lowers the eval loss on the training file.

    The same file is evaluated before and after a default ``train`` call;
    the later loss must be strictly smaller.
    """
    model = HappyWordPrediction('', 'distilroberta-base')

    loss_report_before = model.eval("../data/wp/train-eval.txt")
    model.train("../data/wp/train-eval.txt")
    loss_report_after = model.eval("../data/wp/train-eval.txt")

    assert loss_report_after.loss < loss_report_before.loss
Beispiel #16
0
def test_wp_train_default():
    """Smoke test: ``train`` runs on the sample file with no args given."""
    model = HappyWordPrediction('', 'distilroberta-base')
    model.train("../data/wp/train-eval.txt")
Beispiel #17
0
def test_wp_train_line_by_line():
    """Smoke test: ``train`` runs with ``WPTrainArgs(line_by_line=True)``."""
    model = HappyWordPrediction('', 'distilroberta-base')
    line_args = WPTrainArgs(line_by_line=True)
    model.train("../data/wp/train-eval.txt", args=line_args)
Beispiel #18
0
def example_4_4():
    """Train the default model on the sample file with num_train_epochs=1."""
    predictor = HappyWordPrediction()
    single_epoch_args = WPTrainArgs(num_train_epochs=1)
    predictor.train("../../data/wp/train-eval.txt", args=single_epoch_args)
Beispiel #19
0
def example_4_0():
    """Show four ways to construct a HappyWordPrediction object.

    The first call passes no arguments; the others pass a model type and a
    model name. The objects are only created, not used.
    """
    # default
    default_model = HappyWordPrediction()
    albert_model = HappyWordPrediction("ALBERT", "albert-base-v2")
    bert_model = HappyWordPrediction("BERT", "bert-base-uncased")
    roberta_model = HappyWordPrediction("ROBERTA", "roberta-base")
Beispiel #20
0
def test_wp_high_k():
    """Check the first prediction when a large ``top_k`` (3000) is requested."""
    predictor = HappyWordPrediction("ALBERT", "albert-base-v2")
    predictions = predictor.predict_mask(
        "Please pass the salt and [MASK]",
        top_k=3000,
    )
    first = predictions[0]
    assert first.token == "garlic"
Beispiel #21
0
def test_wp_eval_basic():
    """Check that a default ``eval`` call returns a result with a float loss."""
    happy_wp = HappyWordPrediction('', 'distilroberta-base')
    result = happy_wp.eval("../data/wp/train-eval.txt")
    # isinstance is the idiomatic type check; unlike an exact type() comparison
    # it also accepts float subclasses.
    assert isinstance(result.loss, float)