Code Example #1
File: test_models.py Project: zdqf/AmpliGraph
# Imports assumed for the test examples in this section (AmpliGraph 1.x):
import numpy as np
from ampligraph.datasets import load_wn18, load_wn18rr
from ampligraph.latent_features import TransE, ComplEx, set_entity_threshold, reset_entity_threshold
from ampligraph.evaluation import evaluate_performance


def test_output_sizes():
    ''' Test to check whether embedding matrix sizes match the input data (num rel/ent and k)
    '''
    def perform_test():
        X = load_wn18rr()
        k = 5
        unique_entities = np.unique(
            np.concatenate([X['train'][:, 0], X['train'][:, 2]], 0))
        unique_relations = np.unique(X['train'][:, 1])
        model = TransE(batches_count=100,
                       seed=555,
                       epochs=1,
                       k=k,
                       loss='multiclass_nll',
                       loss_params={'margin': 5},
                       verbose=True,
                       optimizer='sgd',
                       optimizer_params={'lr': 0.001})
        model.fit(X['train'])
        # verify ent and rel shapes
        assert (model.trained_model_params[0].shape[0] == len(unique_entities))
        assert (
            model.trained_model_params[1].shape[0] == len(unique_relations))
        # verify k
        assert (model.trained_model_params[0].shape[1] == k)
        assert (model.trained_model_params[1].shape[1] == k)

    # Normal mode
    perform_test()

    # Large graph mode
    set_entity_threshold(10)
    perform_test()
    reset_entity_threshold()
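
As a brief follow-up (not part of the original test), the same size check can also be done through the public get_embeddings API instead of reading trained_model_params directly. The snippet below is a minimal sketch assuming the AmpliGraph 1.x API, where get_embeddings(entities, embedding_type='entity') returns one k-dimensional vector per entity for TransE.

# Minimal sketch (assumption: AmpliGraph 1.x API); checks embedding sizes via get_embeddings.
import numpy as np
from ampligraph.datasets import load_wn18rr
from ampligraph.latent_features import TransE

X = load_wn18rr()
model = TransE(batches_count=100, seed=555, epochs=1, k=5)
model.fit(X['train'])

entities = np.unique(np.concatenate([X['train'][:, 0], X['train'][:, 2]]))
ent_emb = model.get_embeddings(entities, embedding_type='entity')
assert ent_emb.shape == (len(entities), 5)  # one k-dimensional vector per entity
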
Code Example #2
def test_evaluate_with_ent_subset_large_graph():
    set_entity_threshold(1)  # force large-graph training mode even for this small test graph
    X = load_wn18()
    model = ComplEx(batches_count=10, seed=0, epochs=2, k=10, eta=1,
                    optimizer='sgd', optimizer_params={'lr': 1e-5},
                    loss='pairwise', loss_params={'margin': 0.5},
                    regularizer='LP', regularizer_params={'p': 2, 'lambda': 1e-5},
                    verbose=True)

    model.fit(X['train'])

    X_filter = np.concatenate((X['train'], X['valid'], X['test']))
    all_nodes = set(X_filter[:, 0]).union(X_filter[:, 2])
    
    entities_subset = np.random.choice(list(all_nodes), 100, replace=False)
    
    ranks = evaluate_performance(X['test'][::10],
                                 model=model,
                                 filter_triples=X_filter,
                                 corrupt_side='o',
                                 use_default_protocol=False,
                                 entities_subset=list(entities_subset),
                                 verbose=True)
    # With a 100-entity corruption subset, the worst possible rank is 100 + 1 (the true entity itself).
    assert np.sum(ranks > (100 + 1)) == 0, "No ranks must be greater than 101"
    reset_entity_threshold()
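
As a follow-up sketch (not in the original test), the ranks array returned by evaluate_performance is typically summarized with AmpliGraph's mrr_score and hits_at_n_score helpers. The small synthetic ranks array below stands in for the output computed above.

# Sketch: summarizing ranks with AmpliGraph's metric helpers.
import numpy as np
from ampligraph.evaluation import mrr_score, hits_at_n_score

ranks = np.array([1, 3, 12, 101, 2])  # stand-in for the evaluate_performance output above
print("MRR:    ", mrr_score(ranks))   # mean reciprocal rank
print("Hits@1: ", hits_at_n_score(ranks, n=1))
print("Hits@10:", hits_at_n_score(ranks, n=10))
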
Code Example #3
File: test_models.py Project: zdqf/AmpliGraph
def test_large_graph_mode_adam():
    set_entity_threshold(10)
    X = load_wn18()
    model = ComplEx(batches_count=100,
                    seed=555,
                    epochs=1,
                    k=50,
                    loss='multiclass_nll',
                    loss_params={'margin': 5},
                    verbose=True,
                    optimizer='adam',
                    optimizer_params={'lr': 0.001})
    try:
        model.fit(X['train'])
    except Exception as e:
        # Large-graph mode in AmpliGraph 1.x supports only the SGD optimizer,
        # so fitting with 'adam' is expected to raise an exception here.
        print(str(e))

    reset_entity_threshold()
Code Example #4
File: test_models.py Project: zdqf/AmpliGraph
def test_large_graph_mode():
    set_entity_threshold(10)
    X = load_wn18()
    model = ComplEx(batches_count=100,
                    seed=555,
                    epochs=1,
                    k=50,
                    loss='multiclass_nll',
                    loss_params={'margin': 5},
                    verbose=True,
                    optimizer='sgd',
                    optimizer_params={'lr': 0.001})
    model.fit(X['train'])
    X_filter = np.concatenate((X['train'], X['valid'], X['test']), axis=0)
    evaluate_performance(X['test'][::1000],
                         model,
                         X_filter,
                         verbose=True,
                         corrupt_side='s,o')

    y = model.predict(X['test'][:1])
    print(y)
    reset_entity_threshold()
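
A brief, hedged note on the predict call above: model.predict returns raw, unbounded triple scores rather than probabilities. The AmpliGraph documentation suggests squashing them with a logistic sigmoid to obtain rough, uncalibrated plausibility values; the sketch below reuses the fitted model and X from the example above.

# Sketch: map raw scores to (0, 1); reuses `model` and `X` from the example above (uncalibrated).
from scipy.special import expit

scores = model.predict(X['test'][:5])  # raw, unbounded triple scores
print(expit(scores))                   # sigmoid-squashed; higher means more plausible
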
Code Example #5
File: main.py Project: mhmgad/ExCut
from ampligraph.latent_features import set_entity_threshold

from excut.explanations_mining.explaining_engines_extended import PathBasedClustersExplainerExtended
from excut.explanations_mining.simple_miner.description_miner_extended import ExplanationStructure
from excut.feedback import strategies
from excut.kg.kg_query_interface_extended import RdflibKGQueryInterfaceExtended, EndPointKGQueryInterfaceExtended
from excut.kg.kg_slicing import KGSlicer
from excut.utils.output_utils import write_triples

set_entity_threshold(10e6)  # switch to large-graph training mode only for KGs with more than ~10M entities
# set_entity_threshold(10e4)

import json
import time
import os, argparse
from excut.utils.logging import logger

from rdflib import Graph

from excut.pipeline.explainable_clustering import ExClusteringImpl
from excut.kg.kg_indexing import Indexer
from excut.kg.utils.Constants import DEFUALT_AUX_RELATION
from excut.kg.kg_triples_source import FileTriplesSource
from excut.clustering.target_entities import EntitiesLabelsFile
from excut.embedding.embedding_adapters import AmpligraphEmbeddingAdapter, UpdateDataMode, \
    UpdateMode


def initiate_problem(args, time_now):
    # objective quality
    objective_quality_measure = args.objective_quality