예제 #1
0
파일: neo4j.py 프로젝트: FudanSELab/kgdt
    def import_all_graph_data(self, graph_data: GraphData, clear=True):
        """
        Import every node and every relation of one GraphData into neo4j.

        An index on (DEFAULT_LABEL, DEFAULT_PRIMARY_KEY) is requested first, then the
        nodes are imported one by one, and finally the relations are created between
        the nodes looked up by their primary key.

        :param graph_data: the GraphData whose nodes and relations are imported
        :param clear: when true (the default), delete all existing relations and
                      nodes from the graph before importing
        :return: None
        """
        IndexGraphAccessor(self.graph_accessor).create_index(label=self.DEFAULT_LABEL,
                                                             property_name=self.DEFAULT_PRIMARY_KEY)

        if clear:
            # Relations are removed before the nodes they connect.
            self.graph_accessor.delete_all_relations()
            self.graph_accessor.delete_all_nodes()

        def lookup_node(primary_key_value):
            # Locate an already imported node through its primary key.
            return self.graph_accessor.find_node(primary_label=self.DEFAULT_LABEL,
                                                 primary_property=self.DEFAULT_PRIMARY_KEY,
                                                 primary_property_value=primary_key_value)

        # todo: this is slow, need to speed up, maybe not commit on every step
        for node_id in graph_data.get_node_ids():
            ## todo: fix this by not using 'properties','labels'
            info = graph_data.get_node_info_dict(node_id)
            node_properties = info['properties']
            node_labels = info['labels']
            self.import_one_entity(node_id, node_properties, node_labels)
        print("all entity imported")

        for relation in graph_data.get_relations():
            start_node_id, r_name, end_node_id = relation
            start_node = lookup_node(start_node_id)
            end_node = lookup_node(end_node_id)

            if start_node is None or end_node is None:
                print("fail create relation because start node or end node is none.")
                continue

            try:
                self.graph_accessor.create_relation_without_duplicate(start_node, r_name, end_node)
            except Exception:
                # Best effort: report the failure and keep importing the remaining relations.
                traceback.print_exc()
        print("all relation imported")

        print("all graph data import finish")
예제 #2
0
    def test_get_graph(self):
        """
        Build a small GraphData with four nodes and five "related to" relations,
        printing the node ids and relations along the way.
        """
        graph_data = GraphData()

        # (labels, qualified_name) of each node, added in this order.
        node_specs = (
            ({"method"}, "ArrayList.add"),
            ({"override method"}, "ArrayList.pop"),
            ({"method"}, "ArrayList.remove"),
            ({"method"}, "ArrayList.clear"),
        )
        for node_labels, qualified_name in node_specs:
            graph_data.add_node(node_labels, {"qualified_name": qualified_name})

        print(graph_data.get_node_ids())
        print(graph_data.get_relation_pairs_with_type())

        # (start id, end id) of each relation, added in this order.
        relation_type = "related to"
        for start_id, end_id in ((1, 2), (1, 3), (1, 4), (2, 3), (3, 4)):
            graph_data.add_relation(start_id, relation_type, end_id)

        print(graph_data.get_relations(1, "related to"))
        print("get relation by type")
        print(graph_data.get_relations(relation_type="related to"))
예제 #3
0
    def start_training(self, graph_data: GraphData, *properties):
        """
        Train the kg name searcher model from a graph data object by specifying name properties.

        The searcher is cleared first. Then, for every node of the graph, each requested
        property is read; a list or set value (including instances of their subclasses)
        contributes every one of its elements, any other value is added as a single value.
        Missing or falsy property values are skipped.

        :param graph_data: the GraphData instance
        :param properties: the properties that need to be searched on. e.g., "name","qualified_name","labels_en"
        :return: None
        """
        # todo: add some config arguments, to control whether lower the case, split the words.
        self.clear()
        for node_id in graph_data.get_node_ids():
            node_properties = graph_data.get_properties_for_node(node_id=node_id)
            for property_name in properties:
                property_value = node_properties.get(property_name)
                if not property_value:
                    continue
                # isinstance (rather than an exact type comparison) so that list/set
                # subclasses are expanded element by element as well.
                if isinstance(property_value, (list, set)):
                    for single_value in property_value:
                        self.add_from_property_value(single_value, node_id)
                else:
                    self.add_from_property_value(property_value, node_id)
예제 #4
0
파일: neo4j.py 프로젝트: FudanSELab/kgdt
    def graphdata2csv(csv_folder, graph: GraphData, node_file_id=GraphData.DEFAULT_KEY_NODE_ID,
                      csv_labels=GraphData.DEFAULT_KEY_NODE_LABELS,
                      relation_file_start_id=GraphData.DEFAULT_KEY_RELATION_START_ID,
                      relation_file_end_id=GraphData.DEFAULT_KEY_RELATION_END_ID,
                      relation_file_type=GraphData.DEFAULT_KEY_RELATION_TYPE,
                      only_one_relation_file=True):
        """
        Export a GraphData into CSV files inside ``csv_folder``.

        Nodes are grouped by their label collection: one CSV file per distinct group,
        named after the sorted labels joined with '_' (spaces removed). Each node file
        has an id column, a labels column and one column per property name seen in
        that group. Relations go either into a single 'relations.csv' or into one
        file per relation type (type name with spaces removed).

        :param csv_folder: path of the folder that receives the generated csv files
        :param graph: the GraphData to export
        :param node_file_id: header of the id column in the node files
        :param csv_labels: header of the labels column in the node files
        :param relation_file_start_id: header of the start-node-id column in the relation files
        :param relation_file_end_id: header of the end-node-id column in the relation files
        :param relation_file_type: header of the relation-type column in the relation files
        :param only_one_relation_file: when true, write all relations into 'relations.csv';
                                       otherwise write one file per relation type
        :return: None
        """

        def export_relations(file_stem, types_to_export):
            # Write all relations of the given types into '<file_stem>.csv' (header
            # first, emitted only when there is at least one row); return the row count.
            written = 0
            csv_path = os.path.join(csv_folder, '{}.{}'.format(file_stem, 'csv'))
            with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile, delimiter=',')
                for exported_type in types_to_export:
                    for start_id, type_name, end_id in graph.get_relations(relation_type=exported_type):
                        row = {}
                        row[relation_file_start_id] = int(start_id)
                        row[relation_file_type] = type_name
                        row[relation_file_end_id] = int(end_id)
                        if written == 0:
                            writer.writerow(list(row))
                        writer.writerow(list(row.values()))
                        written += 1
            return written

        csvfilename2label = {}
        csvfilename2ids = {}
        # The inner dicts are used as insertion-ordered sets of property names, so the
        # column order of a file follows the order in which properties are first seen.
        csvfilename2property_name = {}

        for node_id in graph.get_node_ids():
            node = graph.get_node_info_dict(node_id)
            labels = node.get(GraphData.DEFAULT_KEY_NODE_LABELS)
            # Sort before joining so that equal label sets always map to the same file,
            # whatever their iteration order happens to be.
            csvfilename = '_'.join(sorted(labels))
            # Key the grouping on the file name itself: testing the labels against the
            # stored values could skip the initialisation below and fail on the lookups.
            if csvfilename not in csvfilename2label:
                csvfilename2label[csvfilename] = labels
                csvfilename2ids[csvfilename] = set()
                csvfilename2property_name[csvfilename] = {}
            csvfilename2ids[csvfilename].add(node_id)
            for property_name in node.get(GraphData.DEFAULT_KEY_NODE_PROPERTIES).keys():
                csvfilename2property_name[csvfilename].setdefault(property_name, None)

        node_count = 0
        for csvfilename, node_ids in csvfilename2ids.items():
            csv_path = os.path.join(csv_folder, '{}.{}'.format(csvfilename.replace(' ', ''), 'csv'))
            property_names = list(csvfilename2property_name[csvfilename])
            with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile, delimiter=',')
                header_written = False
                for node_id in node_ids:
                    node = graph.get_node_info_dict(node_id)
                    if not node:
                        continue
                    node_properties = node.get(GraphData.DEFAULT_KEY_NODE_PROPERTIES)
                    row = {}
                    row[node_file_id] = node.get(GraphData.DEFAULT_KEY_NODE_ID)
                    row[csv_labels] = node.get(GraphData.DEFAULT_KEY_NODE_LABELS)
                    # Properties a node does not have are written as empty cells.
                    for property_name in property_names:
                        row[property_name] = node_properties.get(property_name)
                    if not header_written:
                        writer.writerow(list(row))
                        header_written = True
                    writer.writerow(list(row.values()))
                    node_count += 1
        print("一共导入csv的节点个数:   ", node_count)

        relation_types = graph.get_all_relation_types()
        if only_one_relation_file:
            relation_count = export_relations('relations', relation_types)
            relation_file_count = str(1)
        else:
            relation_count = 0
            for relation_type in relation_types:
                relation_count += export_relations(relation_type.replace(' ', ''), [relation_type])
            relation_file_count = len(relation_types)
        print("一共导入csv的关系个数:   ", relation_count)
        print("一共生成", len(csvfilename2label), "个csv节点文件和", relation_file_count, "个csv关系文件")