Example #1
 def __init__(self, kg_query_interface: KGQueryInterface, relation=DEFUALT_AUX_RELATION, quality='x_coverage', quality_aggregation=max):
     """
     :param kg_query_interface: interface for the KG.
     :param relation: the relation used in the predicted triple (optional).
     :param quality: objective quality measure for ranking the predictions (optional); by default
             the exclusive coverage of the rules is used.
     :param quality_aggregation: the method used for aggregating the scores if multiple rules infer the same fact
            (optional); by default max is used.
     """
     super(SparqlBasedDeductionEngine, self).__init__()
     self.relation = relation
     self.query_executer = kg_query_interface
     self.quality = quality
     self.quality_aggregation = quality_aggregation
     self.labels_indexer = Indexer(store=kg_query_interface.type,
                                   endpoint=kg_query_interface.endpoint,
                                   graph=kg_query_interface.labels_graph,
                                   identifier=kg_query_interface.labels_identifier)
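A minimal construction sketch (not from the project): it assumes the KG is already indexed behind an EndPointKGQueryInterfaceExtended, as in the later examples; the endpoint URL and graph identifiers below are placeholders.

query_interface = EndPointKGQueryInterfaceExtended(
    endpoint='http://localhost:8890/sparql',  # placeholder endpoint
    identifiers=['http://example.org/kg', 'http://example.org/kg.types'],
    labels_identifier='http://example.org/kg.labels')

# Relies on the defaults shown above: DEFUALT_AUX_RELATION, quality='x_coverage', aggregation=max.
deduction_engine = SparqlBasedDeductionEngine(query_interface)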
Example #2
 def __init__(self,
              kg_query_interface: KGQueryInterface,
              labels_indexer=None,
              relation=DEFUALT_AUX_RELATION,
              min_coverage=0.4,
              per_pattern_binding_limit=20,
              top=10,
              quality_method='x_coverage',
              with_constants=True,
              language_bias=None):
     super().__init__()
     self.language_bias = language_bias if language_bias else {
         'max_length': 2
     }
     self.with_constants = with_constants
     self.quality_method = quality_method
     self.top = top
     self.relation = relation
     self.indexer = labels_indexer if labels_indexer else \
         Indexer(store=kg_query_interface.type, endpoint=kg_query_interface.endpoint,
                graph=kg_query_interface.labels_graph, identifier=kg_query_interface.labels_identifier)
     self.description_miner = DescriptionMiner(
         kg_query_interface,
         per_pattern_binding_limit,
         with_constants=self.with_constants)
     self.query_interface = kg_query_interface
     self.min_coverage = min_coverage
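A hedged instantiation sketch: judging from Examples #4 and #5, this constructor appears to belong to PathBasedClustersExplainerExtended; the query interface is assumed to be built as in the sketch under Example #1, and the parameter values below are illustrative only.

explainer = PathBasedClustersExplainerExtended(
    query_interface,
    quality_method='x_coverage',
    relation=DEFUALT_AUX_RELATION,
    min_coverage=0.5,                 # illustrative threshold (the default above is 0.4)
    top=10,                           # keep the 10 best descriptions
    with_constants=True,
    language_bias={'max_length': 3})  # allow longer path descriptions than the default of 2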
Example #3
class SparqlBasedDeductionEngine(DeductionEngine):
    """
    Deduction engine that converts the rules to SPARQL queries and fires them over the KG.

    The rule-based deduction takes care of consolidating similar predictions.
    """

    def __init__(self, kg_query_interface: KGQueryInterface, relation=DEFUALT_AUX_RELATION, quality='x_coverage', quality_aggregation=max):
        """
        :param kg_query_interface: interface for the KG.
        :param relation: the relation used in the predicted triple (optional).
        :param quality: objective quality measure for ranking the predictions (optional); by default
                the exclusive coverage of the rules is used.
        :param quality_aggregation: the method used for aggregating the scores if multiple rules infer the same fact
               (optional); by default max is used.
        """
        super(SparqlBasedDeductionEngine, self).__init__()
        self.relation = relation
        self.query_executer = kg_query_interface
        self.quality = quality
        self.quality_aggregation = quality_aggregation
        self.labels_indexer = Indexer(store=kg_query_interface.type,
                                      endpoint=kg_query_interface.endpoint,
                                      graph=kg_query_interface.labels_graph,
                                      identifier=kg_query_interface.labels_identifier)

    def infer(self, descriptions_list, target_entities=None, min_quality=0, topk=-1, output_filepath=None,
              clear_target_entities=True):
        """
        Infer new facts for a given set of descriptions

        :param descriptions_list: list of explanation/description rules
        :param target_entities: entities and their labels for which predictions are generated
        :param min_quality: minimum aggregated quality for the predictions
        :param topk: k *distinct* highest-quality predictions per entity
        :param output_filepath: predictions output file.
        :param clear_target_entities: clear the indexed target entities after inference is done
        :return: dictionary of predicted entity-clusters assignments
        """

        if target_entities:
            self.labels_indexer.index_triples(target_entities)
            self.relation = target_entities.get_relation()

        predictions = list(map(self._infer_single, descriptions_list))

        per_entity_predictions = self.consolidate(predictions)

        per_entity_predictions = self._merge_and_sort_cut(per_entity_predictions, min_quality, topk=topk)

        if output_filepath:
            dump_predictions_map(per_entity_predictions, output_filepath, triple_format=True, topk=topk, with_weight=True,
                                 with_description=False, quality=self.quality)

        if target_entities and clear_target_entities:
            self.labels_indexer.drop()

        return per_entity_predictions

    def consolidate(self, predictions):
        """
        Combine predictions from different rules

        :param predictions: list of generated predictions
        :return: combined single prediction with several sources for equivalent predictions
        :rtype: dict
        """

        # per_var_predictions = defaultdict(lambda: defaultdict(list))
        # for p in chain.from_iterable(predictions):
        #     per_var_predictions[p.get_subject()][p.get_object()].append(p)

        per_entity_predictions = defaultdict(lambda: defaultdict(Prediction))
        for p in list(chain.from_iterable(predictions)):
            cons_pred = per_entity_predictions[p.get_subject()][p.get_object()]
            cons_pred.triple = p.triple
            cons_pred.all_sources += p.all_sources

        return per_entity_predictions

    def _merge_and_sort_cut(self, per_entity_prediction, threshold=0, topk=-1):
        """
        Merge the inferred facts in case of functional predicates

        :param per_entity_prediction: per-entity predictions grouped by object
        :return: per-entity predictions filtered by the quality threshold and cut to the top-k
        """

        def quality_method(p):
            return p.get_quality(self.quality, self.quality_aggregation)

        per_entity_prediction_filtered = defaultdict(list)
        for sub, per_obj_predictions in per_entity_prediction.items():
            # print([(k, p.triple[2], quality_method(p)) for k, p in per_obj_predictions.items()])
            merged_predictions = list(
                filter(lambda p: quality_method(p) > threshold, list(per_obj_predictions.values())))

            merged_predictions.sort(key=quality_method, reverse=True)

            include = topk if topk > 0 else len(merged_predictions)
            per_entity_prediction_filtered[sub] = merged_predictions[:include]

        return per_entity_prediction_filtered

    def _infer_single(self, description: Description):
        """
        Infer new facts for the given Description
        :param description: the description/rule to evaluate over the KG
        :return: list of Prediction objects, one per binding of the head variable
        """
        bindings = self.query_executer.get_arguments_bindings(description,
                                                              restriction_pattern=Description([self.relation],
                                                                                              ['?x', '?z'], [True]))
        head = description.head

        # only supports p(?x,CONSTANT)
        predictions = [Prediction((b, head[1], head[2]), [description]) for b in bindings]

        return predictions
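A hedged end-to-end sketch of infer(): 'descriptions' stands for a list of mined Description rules, 'entities.tsv' is a placeholder path for an EntitiesLabelsFile, and deduction_engine is constructed as in the sketch under Example #1.

# Sketch only: descriptions would normally come from the explanation/description miner.
target_entities = EntitiesLabelsFile('entities.tsv')   # placeholder labels file

per_entity_predictions = deduction_engine.infer(
    descriptions,                       # list of Description objects
    target_entities=target_entities,
    min_quality=0.1,                    # keep only predictions with aggregated quality above 0.1
    topk=1,                             # keep only the best prediction per entity
    output_filepath='predictions.tsv')  # placeholder output path

for entity, predictions in per_entity_predictions.items():
    for p in predictions:
        print(entity, p.triple)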
Example #4
File: cli_explain.py    Project: mhmgad/ExCut
    parser.add_argument("-ll", "--max_length", help="maximum length of description", default=2, type=int)
    parser.add_argument("-ls", "--language_structure", help="Structure of the learned description", default="PATH")

    parser.add_argument("-host", "--host", help="SPARQL endpoint host and ip host_ip:port", default="http://halimede:8890/sparql")


    args = parser.parse_args()

    idnt = args.kg_identifier

    if args.kg:
        kg_triples = FileTriplesSource(args.kg,
                                        prefix='http://exp-data.org/', safe_urls=True)

        kg_indexer = Indexer(endpoint='http://halimede:8890/sparql', identifier=idnt)
        if not kg_indexer.graph_exists():
            kg_indexer.index_triples(kg_triples)

    vos_executer = EndPointKGQueryInterfaceExtended(endpoint='http://halimede:8890/sparql',
                                                    identifiers=[idnt, idnt + '.alltypes'],
                                                    labels_identifier=idnt + '.labels.gt')

    t_entities = EntitiesLabelsFile(args.target_entities) #,
                                  # prefix='http://exp-data.org/', safe_urls=True)

    # cls = ['http://clusteringtype#UndergraduateCourse','http://clusteringtype#GraduateCourse']
    # cls=['http://clusteringtype#PublicationByProfessor','http://clusteringtype#PublicationByGraduateStudent']
    # cls=t_entities.get_uniq_labels()
    # labels_indexer = Indexer(endpoint='http://halimede:8890/sparql', identifier=idnt + '.labels.gt')
    cd = PathBasedClustersExplainerExtended(vos_executer, relation=t_entities.get_relation(),
Example #5
File: main.py    Project: mhmgad/ExCut
def initiate_problem(args, time_now):
    # objective quality
    objective_quality_measure = args.objective_quality
    # kg file
    kg_filepath = args.kg
    # initialize output_dir
    # output_dir = args.output_folder if args.output_folder else os.path.join(args.output_folder, "run_%s" % time_now)
    output_dir = os.path.join(args.output_folder, "run_%s" % time_now)
    embedding_output_dir = os.path.join(output_dir, 'embedding')
    base_embedding_dir = args.embedding_dir
    # encoding_dict_dir = args.encoding_dict_dir
    # embedding_base_model = os.path.join(embedding_dir, 'base')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    logger.info('Output Dir: %s' % output_dir)
    # Target entities
    data_prefix = args.data_prefix
    target_entities_filepath = args.target_entities
    target_entities = EntitiesLabelsFile(target_entities_filepath,
                                         prefix=data_prefix,
                                         safe_urls=args.data_safe_urls)
    number_of_clusters = args.number_of_clusters if args.number_of_clusters else target_entities.get_num_clusters()
    logger.info('Number of clusters %i' % number_of_clusters)
    # file source to read the kg
    kg_file_source = FileTriplesSource(kg_filepath,
                                       prefix=data_prefix,
                                       safe_urls=args.data_safe_urls)
    # place_holder_triples = PlaceHolderTriplesSource(10, 10, prefix=data_prefix, safe_urls=args.data_safe_urls)
    kg_and_place_holders = kg_file_source  #JointTriplesSource(kg_file_source, place_holder_triples)
    kg_identfier = args.kg_identifier
    labels_identifier = kg_identfier + '_%s.labels' % time_now
    logger.info("KG Identifiers: %s\t%s" % (kg_identfier, labels_identifier))
    # index data if required
    host = args.host
    graph = None
    labels_indexer = None
    if args.index is not None:
        if args.index == 'remote':
            indexer = Indexer(endpoint=host, identifier=kg_identfier)
            labels_indexer = Indexer(endpoint=host,
                                     identifier=labels_identifier)
        else:
            indexer = Indexer(store='memory', identifier=kg_identfier)
            labels_indexer = Indexer(store='memory',
                                     identifier=labels_identifier)

        if args.drop_index or not indexer.graph_exists():
            logger.info("KG will be indexed to %s (%s %s)" %
                        (args.index, args.drop_index, indexer.graph_exists()))
            graph = indexer.index_triples(kg_file_source,
                                          drop_old=args.drop_index)
            logger.info("Done indexing!")
    logger.info("Embedding adapter chosen: %s" % args.embedding_adapter)
    update_data_mode = UpdateDataMode[args.update_data_mode.upper()]
    update_mode = UpdateMode[args.update_mode.upper()]
    iterations_history = args.update_triples_history
    update_lr = args.update_learning_rate
    # executor
    if args.index == 'remote':
        kg_query_interface = EndPointKGQueryInterfaceExtended(
            endpoint=host,
            identifiers=[kg_identfier, kg_identfier + '.types'],
            labels_identifier=labels_identifier)
    else:
        kg_query_interface = RdflibKGQueryInterfaceExtended(
            graphs=[graph, Graph(identifier=labels_identifier)])
    update_subKG = None
    if update_data_mode == UpdateDataMode.ASSIGNMENTS_SUBGRAPH or args.sub_kg or \
            update_data_mode == UpdateDataMode.SUBGRAPH_ONLY:
        logger.info("Generating SubKG for target entities!")
        # if a context subgraph file is provided and already exists, reuse it
        if args.context_filepath and os.path.exists(args.context_filepath):
            update_subKG = FileTriplesSource(args.context_filepath,
                                             prefix=data_prefix,
                                             safe_urls=args.data_safe_urls)
        else:
            kg_slicer = KGSlicer(kg_query_interface)
            update_subKG = kg_slicer.subgraph(target_entities.get_entities(),
                                              args.update_context_depth)
            # write_triples(update_subKG, os.path.join(output_dir, 'subkg.tsv'))
            if args.context_filepath:
                write_triples(update_subKG, args.context_filepath)

        logger.info("Done Generating SubKG for target entities!")
        if args.sub_kg:  # only use the related subset of the graph
            kg_and_place_holders = update_subKG  #JointTriplesSource(update_subKG, place_holder_triples)
    # Init embedding
    # if args.embedding_adapter == 'ampligragh':

    embedding_adapter = AmpligraphEmbeddingAdapter(
        embedding_output_dir,
        kg_and_place_holders,
        context_subgraph=update_subKG,
        base_model_folder=base_embedding_dir,
        model_name=args.embedding_method,
        update_mode=update_mode,
        update_data_mode=update_data_mode,
        update_params={'lr': update_lr},
        iterations_history=iterations_history,
        seed=args.seed)
    # elif args.embedding_adapter == 'openke_api':
    #     embedding_adapter = OpenKETFEmbeddingAdapter(embedding_output_dir, kg_and_place_holders,
    #                                                  base_model_folder=base_embedding_dir,
    #                                                  kg_encoding_folder=encoding_dict_dir,
    #                                                  model_name=args.embedding_method)
    # else:
    #     raise Exception("Adapter %s not supported!" % args.embedding_adapter)
    aug_relation = DEFUALT_AUX_RELATION
    # relation_full_url('http://execute_aux.org/auxBelongsTo', data_prefix)
    # clusters explaining engine
    clusters_explainer = PathBasedClustersExplainerExtended(
        kg_query_interface,
        labels_indexer=labels_indexer,
        quality_method=objective_quality_measure,
        relation=aug_relation,
        min_coverage=args.expl_c_coverage,
        with_constants=True,
        language_bias={
            'max_length': args.max_length,
            'structure': ExplanationStructure[args.language_structure]
        })
    # Augmentation strategy
    aug_strategy = strategies.get_strategy(
        args.update_strategy,
        kg_query_interface=kg_query_interface,
        quality_method=objective_quality_measure,
        predictions_min_quality=args.prediction_min_q,
        aux_relation=aug_relation)

    explainable_clustering_engine = ExClusteringImpl(
        target_entities,
        embedding_adapter=embedding_adapter,
        clusters_explainer=clusters_explainer,
        augmentation_strategy=aug_strategy,
        clustering_method=args.clustering_method,
        clustering_params={
            'k': number_of_clusters,
            'distance_metric': args.clustering_distance,
            'p': args.cut_prob
        },
        out_dir=output_dir,
        max_iterations=args.max_iterations,
        save_steps=True,
        objective_quality_measure=objective_quality_measure,
        seed=args.seed)
    with open(os.path.join(output_dir, 'cli_args.txt'), 'w') as f:
        json.dump(args.__dict__, f, indent=2)
    return explainable_clustering_engine.run()
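A small driver sketch, assuming initiate_problem is invoked from a CLI entry point in main.py with the parsed argparse namespace and a timestamp; get_arg_parser below is a hypothetical helper standing in for however main.py actually builds its parser.

if __name__ == '__main__':
    import time
    args = get_arg_parser().parse_args()       # hypothetical: returns the argparse parser for the options above
    time_now = time.strftime('%d%m%Y_%H%M%S')  # used to name the run_<time_now> output folder
    results = initiate_problem(args, time_now)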
Example #6
# to save clustering results as triples
write_triples(clustering_results_as_triples,
              os.path.join(experiment_dir, 'clustering.tsv'))

# evaluate clustering using normal measures and add them to methods results
current_method_result.update(
    clms.evaluate(target_entities.get_labels(),
                  clustering_results_as_triples.get_labels(),
                  verbose=True))

########################## Explain #############################

##### RUN ONLY ONCE ######
kg_indexer = Indexer(store='remote',
                     endpoint='http://badr:8890/sparql',
                     identifier='http://imdb_example.org')
kg_indexer.index_triples(kg_triples, drop_old=False)
############################ End

# Interfaces to locations where the data is indexed (virtuoso).
query_interface = EndPointKGQueryInterfaceExtended(
    endpoint='http://badr:8890/sparql',
    identifiers=['http://imdb_example.org', 'http://imdb_example.org.types'],
    labels_identifier='http://imdb_example.org.labels.m')
# labels_indexer = Indexer(endpoint='http://tracy:8890/sparql', identifier='http://yago-expr.org.labels.m')

# Explaining engine
quality_method = 'x_coverage'
clusters_explainer = PathBasedClustersExplainerExtended(
    query_interface,