def test_evaluate_performance_too_many_entities_warning():
    """Check when ``evaluate_performance`` must emit a ``UserWarning``.

    A TransE model is fitted for a single epoch on YAGO3-10 and evaluated on
    every 100th test triple with object-side corruptions (``corrupt_side='o'``):

    * without ``entities_subset``: a ``UserWarning`` is required;
    * with an ``entities_subset`` capped at 50,000 entities: a ``UserWarning``
      is required;
    * with a 10-entity ``entities_subset``: the call only has to complete.

    The same evaluation is then repeated on WN18RR without ``entities_subset``;
    again the call only has to complete.
    """
    model_params = {'batches_count': 200, 'seed': 0, 'epochs': 1, 'k': 5, 'eta': 1, 'verbose': True}
    eval_kwargs = {'verbose': True, 'corrupt_side': 'o'}

    X = load_yago3_10()
    model = TransE(**model_params)
    model.fit(X['train'])
    test_triples = X['test'][::100]

    # No entity list declared: the warning is expected.
    with pytest.warns(UserWarning):
        evaluate_performance(test_triples, model, **eval_kwargs)

    # TOO_MANY_ENT_TH threshold is set to 50,000 entities. Using an explicit value to comply
    # with linting and thus avoid exporting an otherwise unused global variable.
    too_many_entities = 50000

    # Build the entity list *outside* the ``pytest.warns`` block so that only a warning raised
    # by ``evaluate_performance`` itself can satisfy the assertion below.
    subjects = np.unique(X['train'][:, 0])
    objects = np.unique(X['train'][:, 2])
    entities_subset = np.union1d(subjects, objects)[:too_many_entities]

    # Entity list capped at the threshold size: the warning is expected.
    with pytest.warns(UserWarning):
        evaluate_performance(test_triples, model, entities_subset=entities_subset, **eval_kwargs)

    # Small entity list (no exception expected).
    # NOTE: this call is not wrapped in a warnings check, so it only verifies that it completes.
    evaluate_performance(test_triples, model, entities_subset=entities_subset[:10], **eval_kwargs)

    # Smaller dataset, no entity list declared (no exception expected).
    # NOTE: as above, only successful completion is verified here.
    X_wn18rr = load_wn18rr()
    model_wn18 = TransE(**model_params)
    model_wn18.fit(X_wn18rr['train'])
    evaluate_performance(X_wn18rr['test'][::100], model_wn18, **eval_kwargs)
def test_yago_3_10():
    """Verify the number of triples in each split returned by ``load_yago3_10``."""
    dataset = load_yago3_10()

    # Expected number of triples per split, checked in this order.
    expected_sizes = (
        ('train', 1079040),
        ('valid', 5000 - 22),
        ('test', 5000 - 18),
    )

    for split, expected in expected_sizes:
        assert len(dataset[split]) == expected