Example #1
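The examples below use names that are never imported in the excerpts. A plausible preamble, assuming the AmpliGraph 1.x package layout (an assumption, not shown in the source), would be:

import numpy as np

# Assumed AmpliGraph 1.x import paths; adjust if the installed version differs.
from ampligraph.datasets import load_wn18, load_wn18rr
from ampligraph.latent_features import ComplEx, TransE, set_entity_threshold, reset_entity_threshold
from ampligraph.evaluation import evaluate_performance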
def test_output_sizes():
    '''Check that the embedding matrix sizes match the input data (number of
    unique entities/relations and embedding dimension k), in both normal and
    large-graph mode.
    '''
    def perform_test():
        X = load_wn18rr()
        k = 5
        unique_entities = np.unique(
            np.concatenate([X['train'][:, 0], X['train'][:, 2]], 0))
        unique_relations = np.unique(X['train'][:, 1])
        model = TransE(batches_count=100,
                       seed=555,
                       epochs=1,
                       k=k,
                       loss='multiclass_nll',
                       loss_params={'margin': 5},
                       verbose=True,
                       optimizer='sgd',
                       optimizer_params={'lr': 0.001})
        model.fit(X['train'])
        # verify ent and rel shapes
        assert model.trained_model_params[0].shape[0] == len(unique_entities)
        assert model.trained_model_params[1].shape[0] == len(unique_relations)
        # verify k
        assert model.trained_model_params[0].shape[1] == k
        assert model.trained_model_params[1].shape[1] == k

    # Normal mode
    perform_test()

    # Large graph mode
    set_entity_threshold(10)
    perform_test()
    reset_entity_threshold()
Example #2
def test_evaluate_with_ent_subset_large_graph():
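    '''Evaluate in large-graph mode against a random 100-entity subset used for
    object-side corruption and check that no rank exceeds the subset size + 1.
    '''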
    set_entity_threshold(1)
    X = load_wn18()
    model = ComplEx(batches_count=10, seed=0, epochs=2, k=10, eta=1,
                    optimizer='sgd', optimizer_params={'lr': 1e-5},
                    loss='pairwise', loss_params={'margin': 0.5},
                    regularizer='LP', regularizer_params={'p': 2, 'lambda': 1e-5},
                    verbose=True)

    model.fit(X['train'])

    X_filter = np.concatenate((X['train'], X['valid'], X['test']))
    all_nodes = set(X_filter[:, 0]).union(X_filter[:, 2])
    
    entities_subset = np.random.choice(list(all_nodes), 100, replace=False)
    
    ranks = evaluate_performance(X['test'][::10],
                                 model=model,
                                 filter_triples=X_filter,
                                 corrupt_side='o',
                                 use_default_protocol=False,
                                 entities_subset=list(entities_subset),
                                 verbose=True)
    assert np.sum(ranks > (100 + 1)) == 0, "no rank should exceed 101 (100-entity subset + the true object)"
    reset_entity_threshold()
Example #3
def test_large_graph_mode_adam():
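    '''Fit ComplEx in large-graph mode with the Adam optimizer; any exception
    raised by fit() is caught and printed.
    '''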
    set_entity_threshold(10)
    X = load_wn18()
    model = ComplEx(batches_count=100,
                    seed=555,
                    epochs=1,
                    k=50,
                    loss='multiclass_nll',
                    loss_params={'margin': 5},
                    verbose=True,
                    optimizer='adam',
                    optimizer_params={'lr': 0.001})
    try:
        model.fit(X['train'])
    except Exception as e:
        # Large-graph training mode is expected to support only the SGD optimizer,
        # so fit() may raise here; the exception is printed rather than re-raised.
        print(str(e))

    reset_entity_threshold()
Example #4
def test_large_graph_mode():
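    '''End-to-end large-graph mode run: fit ComplEx with SGD, evaluate a slice
    of the test set with subject/object corruption, and run predict().
    '''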
    set_entity_threshold(10)
    X = load_wn18()
    model = ComplEx(batches_count=100,
                    seed=555,
                    epochs=1,
                    k=50,
                    loss='multiclass_nll',
                    loss_params={'margin': 5},
                    verbose=True,
                    optimizer='sgd',
                    optimizer_params={'lr': 0.001})
    model.fit(X['train'])
    X_filter = np.concatenate((X['train'], X['valid'], X['test']), axis=0)
    evaluate_performance(X['test'][::1000],
                         model,
                         X_filter,
                         verbose=True,
                         corrupt_side='s,o')

    y = model.predict(X['test'][:1])
    print(y)
    reset_entity_threshold()