Example #1
def generate_query(question, question_type, entities, relations, ask_query=False, count_query=False):
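    # Build candidate WHERE clauses for the question over the knowledge base,
    # rank them with the TreeLSTM model, and attach a confidence to each walk.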
    sort_query = False
    h1_threshold = 9999999

    double_relation = False
    if double_relation_classifier is not None:
        double_relation = double_relation_classifier.predict([question])
        if double_relation == 1:
            double_relation = True

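    # Find the minimal subgraph of the KB that connects the linked entities and relations.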
    graph = Graph(kb)
    graph.find_minimal_subgraph(entities, relations, double_relation=double_relation, ask_query=ask_query,
                                sort_query=sort_query, h1_threshold=h1_threshold)

    query_builder = QueryBuilder()
    valid_walks = query_builder.to_where_statement(graph, answer_parser.parse_queryresult, ask_query=ask_query,
                                                   count_query=count_query, sort_query=sort_query)

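    # For type-0 questions with a single relation, retry the search with the
    # double-relation expansion and merge the extra candidate walks.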
    if question_type == 0 and len(relations) == 1:
        double_relation = True
        graph = Graph(kb)
        query_builder = QueryBuilder()
        graph.find_minimal_subgraph(entities, relations, double_relation=double_relation, ask_query=ask_query,
                                    sort_query=sort_query, h1_threshold=h1_threshold)
        valid_walks_new = query_builder.to_where_statement(graph, answer_parser.parse_queryresult, ask_query=ask_query,
                                                           count_query=count_query, sort_query=sort_query)
        valid_walks.extend(valid_walks_new)

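    # Configuration for the TreeLSTM ranker under ./learning/treelstm/.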
    args = Struct()
    base_path = "./learning/treelstm/"
    args.save = os.path.join(base_path, "checkpoints/")
    #args.expname = "lc_quad,epoch=5,train_loss=0.08340245485305786"
    args.expname = "lc_quad,epoch=15,train_loss=0.09691771119832993"
    args.mem_dim = 150
    args.hidden_dim = 50
    args.num_classes = 2
    args.input_dim = 300
    args.sparse = False
    args.lr = 0.01
    args.wd = 1e-4
    args.data = os.path.join(base_path, "data/lc_quad/")
    args.cuda = False
    # args.cuda = True
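    # Rank the candidate walks; fall back to uniform scores if the checkpoint file is missing.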
    try:
        scores = rank(args, question, valid_walks)
    except FileNotFoundError as error:
        print(error)
        scores = [1 for _ in valid_walks]
    for idx, item in enumerate(valid_walks):
        if idx >= len(scores):
            item["confidence"] = 0.3
        else:
            item["confidence"] = float(scores[idx] - 1)

    return valid_walks
Example #2
    def generate_query(self, question, entities, relations, h1_threshold=None):
        ask_query = False
        sort_query = False
        count_query = False

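        # Map the predicted question type onto the query form:
        # 2 -> COUNT query, 1 -> ASK query, otherwise a plain SELECT.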
        question_type = 0
        if self.question_classifier is not None:
            question_type = self.question_classifier.predict([question])
        if question_type == 2:
            count_query = True
        elif question_type == 1:
            ask_query = True

        double_relation = False
        if self.double_relation_classifer is not None:
            double_relation = self.double_relation_classifer.predict(
                [question])
            if double_relation == 1:
                double_relation = True

        graph = Graph(self.kb)
        query_builder = QueryBuilder()
        graph.find_minimal_subgraph(entities,
                                    relations,
                                    double_relation=double_relation,
                                    ask_query=ask_query,
                                    sort_query=sort_query,
                                    h1_threshold=h1_threshold)
        valid_walks = query_builder.to_where_statement(
            graph,
            self.parser.parse_queryresult,
            ask_query=ask_query,
            count_query=count_query,
            sort_query=sort_query)

        args = Struct()
        base_path = "./learning/treelstm/"
        args.save = os.path.join(base_path, "checkpoints/")
        args.expname = "lc_quad"
        args.mem_dim = 150
        args.hidden_dim = 50
        args.num_classes = 2
        args.input_dim = 300
        args.sparse = ""
        args.lr = 0.01
        args.wd = 1e-4
        args.data = os.path.join(base_path, "data/lc_quad/")
        args.cuda = False
        scores = self.rank(args, question, valid_walks)
        for idx, item in enumerate(valid_walks):
            item["confidence"] = scores[idx] - 1

        return valid_walks, question_type
Example #3
    def generate_query(self, question, entities, relations, h1_threshold=9999999, question_type=None):
        ask_query = False
        sort_query = False
        count_query = False

        if question_type is None:
            question_type = 0
            if self.question_classifier is not None:
                question_type = self.question_classifier.predict([question])
        if question_type == 2:
            count_query = True
        elif question_type == 1:
            ask_query = True

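        # Classifier confidence for the predicted type; predict_proba may return
        # a nested array, so unwrap a single value below.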
        type_confidence = self.question_classifier.predict_proba([question])[0][question_type]
        if isinstance(type_confidence, (np.ndarray, list)):
            type_confidence = type_confidence[0]

        double_relation = False
        if self.double_relation_classifer is not None:
            double_relation = self.double_relation_classifer.predict([question])
            if double_relation == 1:
                double_relation = True

        graph = Graph(self.kb)
        query_builder = QueryBuilder()
        graph.find_minimal_subgraph(entities, relations, double_relation=double_relation, ask_query=ask_query,
                                    sort_query=sort_query, h1_threshold=h1_threshold)

        valid_walks = query_builder.to_where_statement(graph, self.parser.parse_queryresult, ask_query=ask_query,
                                                       count_query=count_query, sort_query=sort_query)

        if question_type == 0 and len(relations) == 1:
            double_relation = True
            graph = Graph(self.kb)
            query_builder = QueryBuilder()
            graph.find_minimal_subgraph(entities, relations, double_relation=double_relation, ask_query=ask_query,
                                        sort_query=sort_query, h1_threshold=h1_threshold)
            valid_walks_new = query_builder.to_where_statement(graph, self.parser.parse_queryresult,
                                                               ask_query=ask_query,
                                                               count_query=count_query, sort_query=sort_query)
            valid_walks.extend(valid_walks_new)

        args = Struct()
        base_path = "./learning/treelstm/"
        args.save = os.path.join(base_path, "checkpoints/")
        args.expname = "lc_quad,epoch=5,train_loss=0.08340245485305786"
        args.mem_dim = 150
        args.hidden_dim = 50
        args.num_classes = 2
        args.input_dim = 300
        args.sparse = False
        args.lr = 0.01
        args.wd = 1e-4
        args.data = os.path.join(base_path, "data/lc_quad/")
        args.cuda = False
        # args.cuda = True
        try:
            scores = self.rank(args, question, valid_walks)
        except Exception:
            scores = [1 for _ in valid_walks]
        for idx, item in enumerate(valid_walks):
            if idx >= len(scores):
                item["confidence"] = 0.3
            else:
                item["confidence"] = float(scores[idx] - 1)

        return valid_walks, question_type, type_confidence
Example #4
def qg(linker, kb, parser, qapair, question_type_classifier, double_relation_classifier, force_gold=True):
    logger.info(qapair.sparql)
    logger.info(qapair.question.text)

    # Get Answer from KB online
    status, raw_answer_true = kb.query(qapair.sparql.query.replace("https", "http"))
    answerset_true = AnswerSet(raw_answer_true, parser.parse_queryresult)
    qapair.answerset = answerset_true

    ask_query = "ASK " in qapair.sparql.query
    count_query = "COUNT(" in qapair.sparql.query
    sort_query = "order by" in qapair.sparql.raw_query.lower()
    entities, ontologies = linker.do(qapair, force_gold=force_gold)

    if entities is None or ontologies is None:
        return "-Linker_failed", []

    graph = Graph(kb)
    queryBuilder = QueryBuilder()

    logger.info("start finding the minimal subgraph")

    graph.find_minimal_subgraph(entities, ontologies, ask_query=ask_query, sort_query=sort_query)
    logger.info(graph)
    wheres = queryBuilder.to_where_statement(graph, parser.parse_queryresult, ask_query=ask_query,
                                             count_query=count_query, sort_query=sort_query)

    output_where = [{"query": " .".join(item["where"]), "correct": False, "target_var": "?u_0"} for item in wheres]
    for item in list(output_where):
        logger.info(item["query"])
    if len(wheres) == 0:
        return "-without_path", output_where
    correct = False

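    # Execute every candidate WHERE clause against the KB and compare its
    # answer set with the gold answers.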
    for idx in range(len(wheres)):
        where = wheres[idx]

        if "answer" in where:
            answerset = where["answer"]
            target_var = where["target_var"]
        else:
            target_var = "?u_" + str(where["suggested_id"])
            raw_answer = kb.query_where(where["where"], target_var, count_query, ask_query)
            answerset = AnswerSet(raw_answer, parser.parse_queryresult)

        output_where[idx]["target_var"] = target_var
        sparql = SPARQL(kb.sparql_query(where["where"], target_var, count_query, ask_query), ds.parser.parse_sparql)
        if (answerset == qapair.answerset) != (sparql == qapair.sparql):
            print("error")

        if answerset == qapair.answerset:
            correct = True
            output_where[idx]["correct"] = True
            output_where[idx]["target_var"] = target_var
        else:
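            # The answers may be bound to the walk's other variable,
            # so swap ?u_0 <-> ?u_1 and query again.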
            if target_var == "?u_0":
                target_var = "?u_1"
            else:
                target_var = "?u_0"
            raw_answer = kb.query_where(where["where"], target_var, count_query, ask_query)
            print("Q_H ",)
            print(raw_answer)
            print("Q_")
            answerset = AnswerSet(raw_answer, parser.parse_queryresult)

            sparql = SPARQL(kb.sparql_query(where["where"], target_var, count_query, ask_query), ds.parser.parse_sparql)
            if (answerset == qapair.answerset) != (sparql == qapair.sparql):
                print("error")

            if answerset == qapair.answerset:
                correct = True
                output_where[idx]["correct"] = True
                output_where[idx]["target_var"] = target_var

    return "correct" if correct else "-incorrect", output_where
Example #5
ds_train = qapairs_to_triple(ds_train)
ds_test = qapairs_to_triple(ds_test)

model = DSSM(max_steps=10)
# questions, queries, ids = Preprocessor.qapair_to_hash(ds_train)
# model.train([questions, queries])

# questions, queries, ids = Preprocessor.qapair_to_hash(ds_test)
# model.test([questions, queries])

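# Rebuild the candidate queries for every test pair so the DSSM model can
# score each question against its generated SPARQL variants.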
new_ds_test = []
for item in ds_test:
    ask_query = "ASK " in item["query"]
    count_query = "COUNT(" in item["query"]
    sort_query = "order by" in item["query"].lower()

    entities, ontologies = [u for u in item["uris"] if u.is_entity()], \
                           [u for u in item["uris"] if u.is_ontology()]

    graph = Graph(kb)
    graph.find_minimal_subgraph(entities, ontologies, ask_query=ask_query, sort_query=sort_query)
    where = graph.to_where_statement()
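    # Keep only items that yield more than one candidate walk; each walk's
    # triple patterns (w[1]) are joined into a single query string.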
    if len(where) > 1:
        for w in where:
            new_ds_test.append({"id": item["id"], "question": item["question"], "query": " ".join(w[1]),
                                "uris": item["uris"]})

questions, sparqls = Preprocessor.qapair_to_hash(new_ds_test)
model.similarity(questions, sparqls)
Example #6
    def generate_query(self, question, entities, relations, h1_threshold=None, question_type=None):
        ask_query = False
        sort_query = False
        count_query = False
        print('orchestrator: generate_query')

        if question_type is None:
            question_type = 0
            if self.question_classifier is not None:
                question_type = self.question_classifier.predict([question])
                print('question_type predicted by classifier:', question_type)
        if question_type == 2:
            count_query = True
        elif question_type == 1:
            ask_query = True

        type_confidence = self.question_classifier.predict_proba([question])[0][question_type]
        if isinstance(type_confidence, (np.ndarray, list)):
            type_confidence = type_confidence[0]

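        # The double-relation classifier is disabled in this variant,
        # so double_relation always stays False.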
        double_relation = False
        # if self.double_relation_classifer is not None:
        #     double_relation = self.double_relation_classifer.predict([question])
        #     if double_relation == 1:
        #         double_relation = True

        graph = Graph(self.kb)
        query_builder = QueryBuilder()
        print('params to find minimal subgraph:')
        print('entities:', entities)
        print('relations:', relations)
        print('double_relation:', double_relation)
        print('h1_threshold:', h1_threshold)
        graph.find_minimal_subgraph(entities, relations, double_relation=double_relation, ask_query=ask_query,
                                    sort_query=sort_query, h1_threshold=h1_threshold)
        valid_walks = query_builder.to_where_statement(graph, self.parser.parse_queryresult, ask_query=ask_query,
                                                       count_query=count_query, sort_query=sort_query)
        # if question_type == 0 and len(relations) == 1:
        #     double_relation = True
        #     graph = Graph(self.kb)
        #     query_builder = QueryBuilder()
        #     graph.find_minimal_subgraph(entities, relations, double_relation=double_relation, ask_query=ask_query,
        #                                 sort_query=sort_query, h1_threshold=h1_threshold)
        #     valid_walks_new = query_builder.to_where_statement(graph, self.parser.parse_queryresult,
        #                                                        ask_query=ask_query,
        #                                                        count_query=count_query, sort_query=sort_query)
        #     valid_walks.extend(valid_walks_new)
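        # No candidate walks were generated: skip ranking and return zero confidence.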
        if len(valid_walks) == 0:
            return valid_walks, question_type, 0
        args = Struct()
        base_path = "./learning/treelstm/"
        args.save = os.path.join(base_path, "checkpoints/")
        args.expname = "lc_quad"
        args.mem_dim = 150
        args.hidden_dim = 50
        args.num_classes = 2
        args.input_dim = 300
        args.sparse = False
        args.lr = 0.01
        args.wd = 1e-4
        args.data = os.path.join(base_path, "data/lc_quad/")
        args.cuda = False
        try:
            scores = self.rank(args, question, valid_walks)
        except Exception as e:
            print('rank failed!')
            print(e)
            scores = [1 for _ in valid_walks]
        for idx, item in enumerate(valid_walks):
            if idx >= len(scores):
                item["confidence"] = 0.3
            else:
                item["confidence"] = float(scores[idx] - 1)

        return valid_walks, question_type, type_confidence