Пример #1
0
def test_evaluate_performance_too_many_entities_warning():
    """Exercise the "too many entities" warning of ``evaluate_performance``.

    A ``TransE`` model is fitted for one epoch on the YAGO3-10 training split
    and evaluated on every 100th test triple with ``corrupt_side='o'``.
    A ``UserWarning`` is required in two situations:

    * no ``entities_subset`` is passed;
    * an ``entities_subset`` of up to 50,000 entities is passed.

    Two further evaluations (a 10-entity subset on YAGO3-10, and WN18RR with
    no subset) are only run to completion; this test does not assert that
    they are warning-free.
    """
    # Same hyperparameters for both models; same evaluation settings for
    # every evaluate_performance call below.
    hyperparams = dict(batches_count=200, seed=0, epochs=1, k=5, eta=1,
                       verbose=True)
    eval_kwargs = dict(verbose=True, corrupt_side='o')

    yago = load_yago3_10()
    yago_model = TransE(**hyperparams)
    yago_train = yago['train']
    yago_model.fit(yago_train)
    yago_test_sample = yago['test'][::100]

    # Case 1: no entity list supplied -> warning expected.
    with pytest.warns(UserWarning):
        evaluate_performance(yago_test_sample, yago_model, **eval_kwargs)

    # Case 2: entity list at the threshold size -> warning expected.
    with pytest.warns(UserWarning):
        # TOO_MANY_ENT_TH threshold is set to 50,000 entities. The literal is
        # used on purpose so the test does not need to import that global
        # (keeps linting happy).
        subjects = np.unique(yago_train[:, 0])
        objects = np.unique(yago_train[:, 2])
        entities_subset = np.union1d(subjects, objects)[:50000]
        evaluate_performance(yago_test_sample,
                             yago_model,
                             entities_subset=entities_subset,
                             **eval_kwargs)

    # Case 3: small entity list -- smoke check only, nothing is asserted.
    evaluate_performance(yago_test_sample,
                         yago_model,
                         entities_subset=entities_subset[:10],
                         **eval_kwargs)

    # Case 4: smaller dataset without an entity list -- smoke check only.
    wn18rr = load_wn18rr()
    wn18rr_model = TransE(**hyperparams)
    wn18rr_model.fit(wn18rr['train'])
    wn18rr_test_sample = wn18rr['test'][::100]
    evaluate_performance(wn18rr_test_sample, wn18rr_model, **eval_kwargs)
Пример #2
0
def test_yago_3_10():
    """Check the number of triples in each split returned by ``load_yago3_10``."""
    splits = load_yago3_10()

    # Expected sizes per split, checked in train/valid/test order. The
    # valid/test values are kept as "5000 - offset"; presumably the offsets
    # are triples removed from the raw 5000-triple splits -- TODO confirm.
    expected_sizes = (
        ('train', 1079040),
        ('valid', 5000 - 22),
        ('test', 5000 - 18),
    )
    for split_name, expected_len in expected_sizes:
        assert len(splits[split_name]) == expected_len