class TestNodeCleaner(TestCase):
    def test_clean_labels(self):
        self.graphClient = DefaultGraphAccessor(GraphClient(server_number=1))

        node = self.graphClient.find_node_by_id(16)
        self.assertEqual(NodeCleaner.clean_labels(node), [u'software', u'background knowledge', u'WikiData'])
        node = self.graphClient.find_node_by_id(177777)
        print(node)
        self.assertEqual(NodeCleaner.clean_labels(node), [u'background knowledge', u'WikiData'])

    def test_construct_property_set(self):
        node_list = [{"aaa": 1, "ccc": 6}, {"aaa": 2, "bbb": 3, "ccc": 6}, {"ccc": 4, "ddd": 5}]
        result = {"aaa", "ccc"}
        self.assertEqual(construct_property_set(node_list), result)
        print(construct_property_set(node_list))

    def test_rename_property(self):
        self.graphClient = DefaultGraphAccessor(GraphClient(server_number=0))
        node_list = self.graphClient.find_by_name_property("awesome item", "acl9")
        result = rename_property(node_list)
        print(result)

    def test_public_labels_name(self):
        t = PUBLIC_LABELS
        self.assertIsNotNone(t)
        print(t)
    def start_import_for_api_entity(self, linking_result_file):
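        '''
        rebuild the api entity to wikipedia "may link" relations from a linking result file:
        delete all the old relations, then create one entity-to-general-concept relation
        for every pair whose api entity node and wikipedia node both exist
        :param linking_result_file: path to the JSON file holding the linking result
        '''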
        graph_client = GraphClient(server_number=4)
        default_graph_client = DefaultGraphAccessor(graph_client)
        api_entity_graph_client = DomainEntityAccessor(graph_client)
        api_entity_graph_client.delete_all_api_entity_to_wikipedia_relation()
        print("delete all old may link relation complete")

        with open(linking_result_file, 'r') as f:
            link_relation_list = json.load(f)

        for each in link_relation_list:
            api_entity_id = each['api_entity_id']
            wikipedia_entity_id = each['wikipedia_entity_id']
            if api_entity_id is None or wikipedia_entity_id is None:
                continue
            api_entity = api_entity_graph_client.find_api_entity_node_by_id(api_entity_id)
            if api_entity is None:
                continue
            wikipedia_entity = default_graph_client.find_node_by_id(wikipedia_entity_id)
            if wikipedia_entity is None:
                continue
            api_entity_graph_client.create_entity_to_general_concept_relation(api_entity, wikipedia_entity)
class TestGraphClient(TestCase):
    graphClient = None

    def setUp(self):
        self.graphClient = DefaultGraphAccessor(GraphClient())
        self.nodeCleaner = NodeCleaner()

    def test_get_max_id_for_node(self):
        self.assertEqual(self.graphClient.get_max_id_for_node(), 697753)

    def test_get_adjacent_node_id_list(self):
        self.assertEqual(self.graphClient.get_adjacent_node_id_list(66666666),
                         [])

        correct = [64289, 52628, 62565]

        self.assertEqual(self.graphClient.get_adjacent_node_id_list(7899),
                         correct)

    def test_get_node_name_by_id(self):
        self.assertEqual(self.graphClient.get_node_name_by_id(66666666), None)
        self.assertEqual(self.graphClient.get_node_name_by_id(3444),
                         "Adobe Device Central")

    def test_expand_node_for_directly_adjacent_nodes_to_subgraph(self):
        # self.assertEqual(self.graphClient.expand_node_for_adjacent_nodes_to_subgraph(3444),
        #                  "Adobe Device Central")
        pass

    def test_find_by_alias_name_property_exactly_match_from_label_limit_one(
            self):
        self.assertEqual(
            self.graphClient.find_one_by_alias_name_property(
                "entity", "Adobe Device Central"), None)
        interface = self.graphClient.find_one_by_alias_name_property(
            "api", "Interface PrintGraphics")
        self.assertEqual(93008, self.graphClient.get_id_for_node(interface))

    def test_find_by_alias_name_property(self):
        self.assertEqual(
            self.graphClient.find_by_alias_name_property(
                "entity", "Adobe Device Central"), [])
        interfaces = self.graphClient.find_by_alias_name_property(
            "api", "Interface PrintGraphics")
        self.assertEqual(len(interfaces), 1)
        self.assertEqual(93008,
                         self.graphClient.get_id_for_node(interfaces[0]))

    def test_get_relation_by_relation_id(self):
        relation = self.graphClient.get_relation_by_relation_id(470129)
        self.assertIsNone(relation)

        relation = self.graphClient.get_relation_by_relation_id(122211)

        self.assertEqual(122211, self.graphClient.get_id_for_node(relation))
        self.assertEqual(
            91, self.graphClient.get_id_for_node(relation.start_node()))
        self.assertEqual(29390,
                         self.graphClient.get_id_for_node(relation.end_node()))

        subgraph = self.graphClient.get_relations_between_two_nodes_in_subgraph(
            246029, 246030)
        relations_json = []
        for r in subgraph.relationships():
            relation_json = {
                "id": self.graphClient.get_id_for_node(r),
                "name": r.type(),
                "start_id": self.graphClient.get_start_id_for_relation(r),
                "end_id": self.graphClient.get_end_id_for_relation(r)
            }
            relations_json.append(relation_json)
            print(relation_json)
        subgraph = self.graphClient.get_relations_between_two_nodes_in_subgraph(
            246029, 246033)
        self.assertEqual(subgraph, None)

    def test_find_node_by_id(self):
        node = self.graphClient.find_node_by_id(5444)
        self.assertEqual(5444, self.graphClient.get_id_for_node(node))

    def test_search_nodes_by_name(self):
        nodes = self.graphClient.search_nodes_by_name("java")
        count = 0
        for n in nodes:
            count = count + 1
        self.assertEqual(10, count)

        nodes = self.graphClient.search_nodes_by_name("String buffer()")

        count = 0
        for n in nodes:
            count = count + 1
        self.assertEqual(10, count)

    def test_search_nodes_by_name_in_subgraph(self):
        subgraph = self.graphClient.search_nodes_by_name_in_subgraph("java")
        count = 0
        for n in subgraph.nodes():
            count = count + 1
        self.assertEqual(10, count)

        subgraph = self.graphClient.search_nodes_by_name_in_subgraph(
            "String buffer()")
        count = 0
        if subgraph is not None:
            for n in subgraph.nodes():
                count = count + 1
        self.assertEqual(10, count)

    def test_get_relations_between_two_nodes_in_subgraph(self):
        subgraph = self.graphClient.get_relations_between_two_nodes_in_subgraph(
            48, 3600)
        self.assertEqual(None, subgraph)

        subgraph = self.graphClient.get_relations_between_two_nodes_in_subgraph(
            48, 3643)

        self.assertEqual(2, len(subgraph.nodes()))
        self.assertEqual(1, len(subgraph.relationships()))

    def test_get_relations_between_two_nodes(self):
        record_list = self.graphClient.get_relations_between_two_nodes(
            48, 3600)
        count = 0
        for n in record_list:
            count = count + 1
        self.assertEqual(0, count)

        record_list = self.graphClient.get_relations_between_two_nodes(
            48, 3643)

        count = 0
        for n in record_list:
            count = count + 1

        self.assertEqual(1, count)

    def test_cleaner(self):
        node = self.graphClient.find_node_by_id(444)
        self.assertEqual(self.nodeCleaner.get_clean_node_name(node),
                         "fake news")

        node = self.graphClient.find_node_by_id(4444)
        self.assertEqual(self.nodeCleaner.get_clean_node_name(node), "")

        self.assertEqual(self.graphClient.get_id_for_node(Node("lll", a=3)),
                         -1)
        self.assertEqual(self.graphClient.get_id_for_node(node), 4444)

    def test_get_shortest_path_to_name(self):
        name = self.graphClient.get_node_name_by_id(8000)
        subgraph = self.graphClient.get_shortest_path_to_name_in_subgraph(
            444, name)
        print(subgraph)

    def test_get_shortest_path(self):
        record_list = self.graphClient.get_shortest_path(444,
                                                         8000,
                                                         max_degree=2)
        self.assertEqual(0, count_record_list(record_list))

        record_list = self.graphClient.get_shortest_path(444, 8000)
        self.assertNotEqual(None, record_list)
        self.assertEqual(1, count_record_list(record_list))

        subgraph = self.graphClient.get_shortest_path_in_subgraph(444,
                                                                  8000,
                                                                  max_degree=2)
        self.assertEqual(None, subgraph)

        subgraph = self.graphClient.get_shortest_path_in_subgraph(444,
                                                                  8000,
                                                                  max_degree=6)
        self.assertNotEqual(None, subgraph)
        self.assertEqual(len(subgraph.nodes()), 7)
        self.assertEqual(len(subgraph.relationships()), 6)
        print(subgraph)

    def test_get_newest_nodes(self):
        node_list = self.graphClient.get_newest_nodes(10)
        self.assertEqual(10, len(node_list))
        print(node_list)
        graphJsonParser = GraphJsonParser()
        returns = graphJsonParser.parse_node_list_to_json(node_list)
        print(returns)
class OldQuestionAnswerSystem:
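    '''
    answer questions over the knowledge graph: preprocess and analyze the question,
    generate a candidate answer set, then rank it into the final answer set
    '''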
    def __init__(self):
        self.question_preprossor = QuestionPreprossor()
        self.question_analyzer = QuestionAnalyzer()
        self.candidate_answer_generator = CandidateAnswerSetGenerator()
        # self.answer_generator = AnswerGenerator()
        self.answer_generator = None
        self.client = DefaultGraphAccessor(GraphClient())

    def simple_answer(self, question_text):
        '''
        get a simple answer string
        :param question_text: question text
        :return: a simple answer string
        '''
        # todo complete this answer method
        return "I can't answer this question"

    def full_answer(self, question_text, top_number=10):
        '''
        get a full answer set for the question
        :param question_text: question text
        :return:  a full answer set
        '''
        question = self.question_preprossor.preprosse_question(
            question_text=question_text)
        question = self.question_analyzer.analyze_question(question=question)
        question.print_to_console()
        candidate_answer_set = self.candidate_answer_generator.generate_candidate_answer_set(
            question=question)
        answer_set = self.answer_generator.generate_answer_set(
            question=question, candidate_answer_set=candidate_answer_set)
        return answer_set

    def fake_full_answer(self, question_text, top_number=10):
        '''
        get a full answer set for the question
        :param question_text: question text
        :return:  a full answer set
        '''
        # TODO: fix this fake full answer; the early return below makes the
        # hard-coded demo answers after it unreachable
        answer_set = AnswerSet(answer_list=[])
        return answer_set

        answer_list = []
        if 'what is java' in question_text.lower():
            node = self.client.find_node_by_id(678253)
            answer = Answer(self.get_entity_name(node), [node], 0.9)
            answer_list.append(answer)
            node = self.client.find_node_by_id(676793)
            answer = Answer(self.get_entity_name(node), [node], 0.7)
            answer_list.append(answer)
            node = self.client.find_node_by_id(7686)
            answer = Answer(self.get_entity_name(node), [node], 0.6)
            answer_list.append(answer)
            node = self.client.find_node_by_id(19204)
            answer = Answer(self.get_entity_name(node), [node], 0.4)
            answer_list.append(answer)
            node = self.client.find_node_by_id(821488)
            answer = Answer(self.get_entity_name(node), [node], 0.2)
            answer_list.append(answer)
            time.sleep(4)
        elif 'which library can process json' in question_text.lower():
            node = self.client.find_node_by_id(1518)
            answer = Answer(self.get_entity_name(node), [node], 0.8)
            answer_list.append(answer)
            node = self.client.find_node_by_id(698753)
            answer = Answer(self.get_entity_name(node), [node], 0.7)
            answer_list.append(answer)
            node = self.client.find_node_by_id(699419)
            answer = Answer(self.get_entity_name(node), [node], 0.6)
            answer_list.append(answer)
            node = self.client.find_node_by_id(700229)
            answer = Answer(self.get_entity_name(node), [node], 0.5)
            answer_list.append(answer)
            time.sleep(5)
        elif 'what string building api is thread safe' in question_text.lower(
        ):
            node = self.client.find_node_by_id(659611)
            answer = Answer(self.get_entity_name(node), [node], 0.8)
            answer_list.append(answer)
            node = self.client.find_node_by_id(52741)
            answer = Answer(self.get_entity_name(node), [node], 0.6)
            answer_list.append(answer)
            node = self.client.find_node_by_id(659612)
            answer = Answer(self.get_entity_name(node), [node], 0.1)
            answer_list.append(answer)
            time.sleep(7)
        elif 'where can i find apache' in question_text.lower():
            node = self.client.find_node_by_id(692324)
            answer = Answer(self.get_entity_name(node), [node], 0.8)
            answer_list.append(answer)
            node = self.client.find_node_by_id(1465147)
            answer = Answer(self.get_entity_name(node), [node], 0.6)
            answer_list.append(answer)
            node = self.client.find_node_by_id(2217)
            answer = Answer(self.get_entity_name(node), [node], 0.1)
            answer_list.append(answer)
            node = self.client.find_node_by_id(2805)
            answer = Answer(self.get_entity_name(node), [node], 0.1)
            answer_list.append(answer)
            time.sleep(3)
        answer_set = AnswerSet(answer_list=answer_list)
        return answer_set

    def get_entity_name(self, node):
        # prefer the plain name, then fall back to the English label
        if 'name' in dict(node):
            name = node.get('name')
        elif 'labels_en' in dict(node):
            name = node.get('labels_en')
        else:
            name = ''
        return name
class WikiAliasDBImporter:
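    '''
    import the alias names of "wikipedia" nodes from the graph into the database,
    then map each entity name to its Wikipedia document
    '''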
    def __init__(self):
        self.graphClient = None
        self.session = None

    def init(self):
        self.graphClient = DefaultGraphAccessor(GraphClient(server_number=4))
        self.session = EngineFactory.create_session()
        print("init complete")

    def clean_table(self):
        WikipediaEntityName.delete_all(self.session)
        WikipediaEntityNameToWikipediaMapping.delete_all(self.session)

        print("delete all exist table")

    def start_import_wiki_aliases_to_db(self):
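        '''
        collect candidate names for every node labelled "wikipedia" from its
        name, site:enwiki, labels_en and aliases_en properties and store each
        one as a WikipediaEntityName row
        '''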
        label = "wikipedia"
        wiki_nodes = self.graphClient.get_all_nodes_by_label(label)

        for node in wiki_nodes:
            node_id = self.graphClient.get_id_for_node(node)
            # print ('node_id: %r', node_id)
            # name, site_enwiki, labels_ = ''
            name_set = set([])
            if 'name' in dict(node):
                # print ("name: %r", node['name'])
                if isinstance(node['name'], list):
                    for each in node['name']:
                        name_set.add(each)
                else:
                    name_set.add(node['name'])
            if 'site:enwiki' in dict(node):
                # print ('site_enwiki: %s', node['site:enwiki'])
                if isinstance(node['site:enwiki'], list):
                    for each in node['site:enwiki']:
                        title = URLUtil.parse_url_to_title(each)
                        # print ('site_name: %r', title)
                        name_set.add(title)
                else:
                    title = URLUtil.parse_url_to_title(node["site:enwiki"])
                    # print ('site_name: %r', title)
                    name_set.add(title)
            if 'labels_en' in dict(node):
                # print( "labels_en: ", node['labels_en'])
                if isinstance(node['labels_en'], list):
                    for each in node['labels_en']:
                        name_set.add(each)
                else:
                    name_set.add(node['labels_en'])
            if 'aliases_en' in dict(node):
                # print("aliases_en: ", node['aliases_en'])
                for each in node['aliases_en']:
                    name_set.add(each)
            # print (name_set)
            for name in name_set:
                try:
                    wikipedia_entity_name = WikipediaEntityName(
                        node_id, str(name))
                    wikipedia_entity_name.find_or_create(self.session,
                                                         autocommit=True)
                except Exception:
                    traceback.print_exc()
            # self.session.commit()
        self.session.commit()

    def start_generate_wiki_entity_text_map(self):
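        '''
        for every knowledge graph id that has stored alias names, resolve its
        site:enwiki URL to a Wikipedia title and save a mapping from the graph
        node to the matching WikipediaDocument
        '''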
        wikipedia_entity_name_data = WikipediaEntityName.get_all_wikipedia_names(
            self.session)
        kg_id_list = set([])
        for each in wikipedia_entity_name_data:
            if each is not None:
                kg_id_list.add(each.kg_id)
        # print kg_id_list
        for kg_id in kg_id_list:
            node = self.graphClient.find_node_by_id(kg_id)
            if node is not None:
                if "site:enwiki" in dict(node):
                    title = URLUtil.parse_url_to_title(node["site:enwiki"])
                    wikipedia_doc = WikipediaDocument.get_document_by_wikipedia_title(
                        self.session, title)
                    if wikipedia_doc is not None:
                        wikipedia_id = wikipedia_doc.id
                        wiki_name_to_wikipedia_mapping = WikipediaEntityNameToWikipediaMapping(
                            kg_id, wikipedia_id)
                        wiki_name_to_wikipedia_mapping.find_or_create(
                            self.session, autocommit=False)
        self.session.commit()

    def start_import(self):
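        '''
        run the full import: init the clients, clean the old tables, import the
        aliases, then build the entity-to-Wikipedia-document mapping
        '''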
        self.init()
        self.clean_table()
        self.start_import_wiki_aliases_to_db()
        self.start_generate_wiki_entity_text_map()
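

# minimal usage sketch (assumes the project modules above are importable and that
# the graph server and database configured in init() are reachable)
if __name__ == '__main__':
    importer = WikiAliasDBImporter()
    importer.start_import()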