コード例 #1
0
ファイル: base.py プロジェクト: FudanSELab/kgdt
 def __init__(self):
     """Create an empty pipeline with a fresh GraphData and MultiFieldDocumentCollection."""
     # registered components keyed by name, plus the order in which they run
     self.__name2component, self.__component_order = {}, []
     # build state shared by the components
     self.__graph_data = GraphData()
     self.__doc_collection = MultiFieldDocumentCollection()
     # component name -> listeners to call before / after that component runs
     self.__before_run_component_listeners, self.__after_run_component_listeners = {}, {}
コード例 #2
0
    def train(cls, graph_data: "GraphData | str | Path", *properties):
        """
        Train a kg name searcher model from graph data by specifying the name properties to index.

        :param graph_data: a GraphData instance, or a path (str or pathlib.Path) of a saved GraphData
                           that is loaded with GraphData.load.
        :param properties: the properties that need to be searched on, could be more than one.
                           e.g., "name", "qualified_name", "labels_en"
        :return: a new searcher instance (created with ``cls()``) trained on the graph data.
        :raises Exception: if graph_data is None, is not a GraphData/str/Path, or cannot be loaded.
        """
        # todo: add some config arguments, to control whether lower the case, split the words.
        if graph_data is None:
            raise Exception("Input GraphData object not exist")

        # isinstance instead of `type(...) ==`: Path(...) actually returns a concrete subclass
        # (PosixPath / WindowsPath), which an exact type comparison with Path never matches.
        if isinstance(graph_data, GraphData):
            graph_data_source = graph_data
        elif isinstance(graph_data, (str, Path)):
            graph_data_source = GraphData.load(str(graph_data))
        else:
            graph_data_source = None

        if graph_data_source is None:
            raise Exception("can't find the graph data")

        searcher = cls()
        searcher.start_training(graph_data_source, *properties)
        return searcher
コード例 #3
0
ファイル: base.py プロジェクト: FudanSELab/kgdt
    def load_graph(self, graph_data_path):
        """
        Replace the pipeline's GraphData with one loaded by GraphData.load and hand it to every component.
        :param graph_data_path: the path passed to GraphData.load
        :return:
        """
        loaded_graph = GraphData.load(graph_data_path)
        self.__graph_data = loaded_graph
        # every registered component must now work on the freshly loaded graph
        for name in self.__component_order:
            self.__name2component[name].set_graph_data(loaded_graph)

        print("load graph")
コード例 #4
0
ファイル: neo4j.py プロジェクト: FudanSELab/kgdt
    def node_csv2graphdata(file, graph: GraphData = None, csv_id=GraphData.DEFAULT_KEY_NODE_ID, csv_labels=GraphData.DEFAULT_KEY_NODE_LABELS):
        '''
        Import the nodes stored in a CSV file into a GraphData.

        The id and labels columns are parsed as Python literals. Every other column becomes a node
        property: empty cells are skipped, cells starting with '[' are parsed as list literals
        (falling back to the raw string), and the remaining cells are converted to int when
        possible, otherwise kept as strings.

        :param file: full path of the node CSV file
        :param graph: the GraphData to import into; a new GraphData is created when None
        :param csv_id: name of the CSV column holding the node id
        :param csv_labels: name of the CSV column holding the node labels
        :return: the GraphData after the nodes are imported
        '''
        import ast

        # exceptions ast.literal_eval may raise on malformed input
        literal_errors = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)

        if graph is None:
            graph = GraphData()
        count = 0
        with open(file, 'r', encoding="utf-8") as csvfile:
            for row in csv.DictReader(csvfile):
                node_id = None
                node_labels = set()
                node_dic = {}
                for row_k, row_v in row.items():
                    if row_k == csv_id:
                        # literal_eval instead of eval: the CSV is external data and must not be
                        # able to execute arbitrary code.
                        node_id = ast.literal_eval(row_v)
                        continue
                    if row_k == csv_labels:
                        node_labels = ast.literal_eval(row_v)
                        continue
                    if row_v == '':
                        continue
                    if isinstance(row_v, str) and row_v.startswith('['):
                        try:
                            node_dic[row_k] = ast.literal_eval(row_v)
                        except literal_errors:
                            node_dic[row_k] = row_v
                        continue
                    try:
                        node_dic[row_k] = int(row_v)
                    except (ValueError, TypeError):
                        node_dic[row_k] = row_v
                result = graph.add_node(node_labels, node_dic, node_id)
                if result != -1:
                    count = count + 1
        print("从", file, "一共导入graphdata节点个数:   ", count)
        return graph
コード例 #5
0
ファイル: neo4j.py プロジェクト: FudanSELab/kgdt
    def import_all_graph_data(self, graph_data: GraphData, clear=True):
        """
        Import all nodes and relations of a GraphData into neo4j, creating an index on the primary key first.

        Nodes are imported one by one with import_one_entity. Each relation is then created between the
        neo4j nodes found by DEFAULT_LABEL / DEFAULT_PRIMARY_KEY. A relation whose start or end node cannot
        be found is skipped with a message. An exception while creating a relation is printed and skipped.

        :param graph_data: the GraphData to import
        :param clear: if True (the default), delete all existing relations and nodes before importing
        :return:
        """
        index_accessor = IndexGraphAccessor(self.graph_accessor)
        index_accessor.create_index(label=self.DEFAULT_LABEL, property_name=self.DEFAULT_PRIMARY_KEY)

        if clear:
            self.graph_accessor.delete_all_relations()
            self.graph_accessor.delete_all_nodes()

        def find_imported_node(node_id):
            # look up the neo4j node that was created for a GraphData node id
            return self.graph_accessor.find_node(primary_label=self.DEFAULT_LABEL,
                                                 primary_property=self.DEFAULT_PRIMARY_KEY,
                                                 primary_property_value=node_id)

        # todo: this is slow, need to speed up, maybe not commit on every step
        for node_id in graph_data.get_node_ids():
            ## todo: fix this by not using 'properties','labels'
            node_info_dict = graph_data.get_node_info_dict(node_id)
            self.import_one_entity(node_id, node_info_dict['properties'], node_info_dict['labels'])

        print("all entity imported")
        for start_node_id, r_name, end_node_id in graph_data.get_relations():
            start_node = find_imported_node(start_node_id)
            end_node = find_imported_node(end_node_id)
            if start_node is None or end_node is None:
                print("fail create relation because start node or end node is none.")
                continue
            try:
                self.graph_accessor.create_relation_without_duplicate(start_node, r_name, end_node)
            except Exception:
                # best effort: report the failure and keep importing the remaining relations
                traceback.print_exc()
        print("all relation imported")

        print("all graph data import finish")
コード例 #6
0
 def __init__(self, graph_data=None, doc_collection=None):
     """
     Hold the GraphData and MultiFieldDocumentCollection to work on.
     :param graph_data: an existing GraphData; a new empty GraphData is created when None
     :param doc_collection: an existing MultiFieldDocumentCollection; a new empty one is created when None
     """
     self.graph_data = GraphData() if graph_data is None else graph_data
     self.doc_collection = MultiFieldDocumentCollection() if doc_collection is None else doc_collection
     # listener lists start out empty
     self.__before_run_listeners, self.__after_run_listeners = [], []
コード例 #7
0
    def start_training(self, graph_data: GraphData, *properties):
        """
        Rebuild the searcher from scratch using the given name properties of every node.

        For every node and requested property, a list/set value contributes each of its elements,
        any other truthy value contributes itself, and missing or falsy values are skipped.

        :param graph_data: the GraphData instance to index
        :param properties: the property names to index, e.g., "name", "qualified_name", "labels_en"
        :return:
        """
        # todo: add some config arguments, to control whether lower the case, split the words.
        self.clear()
        for node_id in graph_data.get_node_ids():
            node_properties = graph_data.get_properties_for_node(node_id=node_id)
            for property_name in properties:
                value = node_properties.get(property_name, None)
                if not value:
                    continue
                # only exact list/set instances are expanded element by element
                values = value if type(value) in (list, set) else (value,)
                for single_value in values:
                    self.add_from_property_value(single_value, node_id)
コード例 #8
0
ファイル: neo4j.py プロジェクト: FudanSELab/kgdt
    def export_all_graph_data(self, graph, node_label):
        """
        Export the nodes with node_label, and the relations returned for that label, into a new GraphData.
        :param graph: the graph handed to DataExporterAccessor
        :param node_label: the label used to select nodes and relations
        :return: the new GraphData; each node's `identity` is used as its node id
        """
        accessor = DataExporterAccessor(graph=graph)
        nodes = accessor.get_all_nodes(node_label=node_label)
        graph_data = GraphData()

        for node in nodes:
            graph_data.add_node(node_id=node.identity,
                                node_labels=list(node.labels),
                                node_properties=dict(node))
        print("load entity complete, num=%d" % len(nodes))

        relations = accessor.get_all_relation(node_label=node_label)
        print("load relation complete,num=%d" % len(relations))
        graph_data.set_relations(relations=relations)

        return graph_data
コード例 #9
0
ファイル: neo4j.py プロジェクト: FudanSELab/kgdt
 def relation_csv2graphdata(file, graph=None, start_name=GraphData.DEFAULT_KEY_RELATION_START_ID,
                            relation_type_name=GraphData.DEFAULT_KEY_RELATION_TYPE, end_name=GraphData.DEFAULT_KEY_RELATION_END_ID):
     '''
     Import the relations stored in a CSV file into a GraphData.

     Rows whose start id, relation type or end id cell is empty are skipped; start/end ids are
     converted with int().

     :param file: full path of the relation CSV file
     :param graph: the GraphData to import into; a new GraphData is created when None
     :param start_name: name of the CSV column holding the start node id
     :param relation_type_name: name of the CSV column holding the relation type
     :param end_name: name of the CSV column holding the end node id
     :return: the GraphData after the relations are imported
     '''
     # create the target graph and still read the file, instead of returning an empty graph
     if graph is None:
         graph = GraphData()
     count = 0
     with open(file, 'r', encoding="utf-8") as csvfile:
         for row in csv.DictReader(csvfile):
             start_id, relation_type, end_id = row[start_name], row[relation_type_name], row[end_name]
             if start_id == '' or relation_type == '' or end_id == '':
                 continue
             if graph.add_relation(int(start_id), relation_type, int(end_id)):
                 count = count + 1
     print("从", file, "一共导入graphdata关系个数:   ", count)
     return graph
コード例 #10
0
ファイル: test_graph.py プロジェクト: FudanSELab/kgdt
    def test_merge(self):
        """Run merge_node for a node whose qualified_name is already in the graph (no assertions are made)."""
        graph_data = GraphData()
        graph_data.create_index_on_property("qualified_name", "alias")

        # (labels, properties) of the nodes present before the merge, added in this order
        initial_nodes = [
            ({"method"}, {"qualified_name": "ArrayList.add"}),
            ({"override method"}, {"qualified_name": "ArrayList.pop"}),
            ({"method"}, {"qualified_name": "ArrayList.remove"}),
            ({"method"}, {"qualified_name": "ArrayList.clear", "alias": ["clear"]}),
        ]
        for labels, properties in initial_nodes:
            graph_data.add_node(labels, properties)

        merged_properties = {"qualified_name": "ArrayList.clear", "alias": ["clear", "clear1"]}
        graph_data.merge_node(node_labels=["method", "merge"],
                              node_properties=merged_properties,
                              primary_property_name="qualified_name")
コード例 #11
0
ファイル: test_graph.py プロジェクト: FudanSELab/kgdt
    def test_get_graph_with_property(self):
        """Add relations carrying an extra_info_key property and print what the query APIs return."""
        graph_data = GraphData()

        for labels, qualified_name in [({"method"}, "ArrayList.add"),
                                       ({"override method"}, "ArrayList.pop"),
                                       ({"method"}, "ArrayList.remove"),
                                       ({"method"}, "ArrayList.clear")]:
            graph_data.add_node(labels, {"qualified_name": qualified_name})

        # (start id, end id, extra_info_key) of the "related to" relations carrying extra info
        for start_id, end_id, extra_info in [(1, 2, "as"), (1, 3, "ab"), (1, 4, "cs"), (2, 3, "ca")]:
            graph_data.add_relation_with_property(start_id, "related to", end_id,
                                                  extra_info_key=extra_info)
        graph_data.add_relation_with_property(3, "related to", 4)

        print(graph_data.get_relations(1, "related to"))
        print("get relation by type")
        print(graph_data.get_relations(relation_type="related to"))
        extra_info_value = graph_data.get_edge_extra_info(1, 2, "related to",
                                                          extra_key="extra_info_key")
        print(extra_info_value)
コード例 #12
0
ファイル: test_graph.py プロジェクト: FudanSELab/kgdt
 def test_exist_relation(self):
     """Check exist_any_relation and get_all_relations with two opposite relations between nodes 1 and 3."""
     graph_data = GraphData()
     node_specs = [
         ({"method", "entity"}, "ArrayList.remove"),
         ({"method", "entity"}, "LinkedList.add"),
         ({"class", "entity"}, "ArrayList"),
     ]
     for labels, qualified_name in node_specs:
         graph_data.add_node(labels, {"qualified_name": qualified_name, "version": "1.0"})

     graph_data.add_relation(3, "hasMethod", 1)
     graph_data.add_relation(1, "belongTo", 3)
     expected_relations = {(3, 'hasMethod', 1), (1, 'belongTo', 3)}

     self.assertEqual(graph_data.exist_any_relation(3, 1), True)
     self.assertEqual(graph_data.exist_any_relation(2, 1), False)
     self.assertEqual(graph_data.get_all_relations(1, 3), expected_relations)
コード例 #13
0
ファイル: test_graph.py プロジェクト: FudanSELab/kgdt
    def test_get_graph(self):
        """Build a small "related to" graph and print node ids, relation pairs and relation queries."""
        graph_data = GraphData()

        node_specs = (
            ({"method"}, "ArrayList.add"),
            ({"override method"}, "ArrayList.pop"),
            ({"method"}, "ArrayList.remove"),
            ({"method"}, "ArrayList.clear"),
        )
        for labels, qualified_name in node_specs:
            graph_data.add_node(labels, {"qualified_name": qualified_name})

        print(graph_data.get_node_ids())
        print(graph_data.get_relation_pairs_with_type())

        for start_id, end_id in ((1, 2), (1, 3), (1, 4), (2, 3), (3, 4)):
            graph_data.add_relation(start_id, "related to", end_id)

        print(graph_data.get_relations(1, "related to"))
        print("get relation by type")
        print(graph_data.get_relations(relation_type="related to"))
コード例 #14
0
ファイル: test_graph.py プロジェクト: FudanSELab/kgdt
    def test_update_node(self):
        """Check the node-update APIs: replace properties, set one property value, update labels and properties."""
        graph_data = GraphData()
        graph_data.add_node({"method", "entity"}, {
            "qualified_name": "ArrayList.remove",
            "version": "1.0"
        })

        def assert_node_1(expected_properties, expected_labels):
            # compare the stored properties and labels of node 1 with the expectation
            self.assertEqual(
                graph_data.get_node_info_dict(1)[GraphData.DEFAULT_KEY_NODE_PROPERTIES],
                expected_properties)
            self.assertEqual(
                graph_data.get_node_info_dict(1)[GraphData.DEFAULT_KEY_NODE_LABELS],
                expected_labels)

        graph_data.update_node_property_by_node_id(1, {
            "qualified_name": "ArrayList.remove",
            "version": "2.0",
            "parameter_num": 1
        })
        assert_node_1({"qualified_name": "ArrayList.remove", "version": "2.0", "parameter_num": 1},
                      {"method", "entity"})

        graph_data.update_node_property_value_by_node_id(1, "qualified_name", "ArrayList.add")
        assert_node_1({"qualified_name": "ArrayList.add", "version": "2.0", "parameter_num": 1},
                      {"method", "entity"})

        # the expected labels show the new label being added to the existing ones
        graph_data.update_node_by_node_id(1, {"class", "entity"}, {
            "qualified_name": "ArrayList.add",
            "version": "2.0",
            "parameter_num": 2
        })
        assert_node_1({"qualified_name": "ArrayList.add", "version": "2.0", "parameter_num": 2},
                      {"method", "entity", "class"})
コード例 #15
0
ファイル: test_graph.py プロジェクト: FudanSELab/kgdt
    def test_save_and_load(self):
        """Saving a GraphData and loading it back keeps its two nodes."""
        import os
        import tempfile

        graph_data = GraphData()
        graph_data.create_index_on_property("qualified_name", "alias")

        graph_data.add_node({"method"}, {"qualified_name": "ArrayList.add"})
        graph_data.add_node({"override method"},
                            {"qualified_name": "ArrayList.pop"})

        # save into a temporary directory so the test leaves no "test.graph" file behind
        # in the current working directory
        with tempfile.TemporaryDirectory() as tmp_dir:
            graph_path = os.path.join(tmp_dir, "test.graph")
            graph_data.save(graph_path)
            loaded_graph_data: GraphData = GraphData.load(graph_path)
            self.assertEqual(loaded_graph_data.get_node_num(), 2)
コード例 #16
0
ファイル: test_graph.py プロジェクト: FudanSELab/kgdt
    def test_remove_node(self):
        """Removing node 1 returns a non-None result and the node can no longer be looked up."""
        graph_data = GraphData()
        graph_data.create_index_on_property("qualified_name", "alias")

        node_specs = (
            ({"method"}, {"qualified_name": "ArrayList.add"}),
            ({"override method"}, {"qualified_name": "ArrayList.pop"}),
            ({"method"}, {"qualified_name": "ArrayList.remove"}),
            ({"method"}, {"qualified_name": "ArrayList.clear", "alias": ["clear"]}),
            ({"method"}, {"qualified_name": "List.clear",
                          "alias": ["clear", "List.clear", "List clear"]}),
        )
        for labels, properties in node_specs:
            graph_data.add_node(labels, properties)

        for start_id, end_id in ((1, 2), (1, 3), (1, 4), (2, 3), (3, 4)):
            graph_data.add_relation(start_id, "related to", end_id)

        self.assertIsNotNone(graph_data.remove_node(node_id=1))
        self.assertIsNone(graph_data.get_node_info_dict(node_id=1))
コード例 #17
0
ファイル: test_graph.py プロジェクト: FudanSELab/kgdt
    def test_find_nodes_by_properties(self):
        """find_nodes_by_property matches a scalar property value and a value inside a list property."""
        graph_data = GraphData()
        graph_data.create_index_on_property("qualified_name", "alias")

        node_specs = (
            ({"method"}, {"qualified_name": "ArrayList.add"}),
            ({"override method"}, {"qualified_name": "ArrayList.pop"}),
            ({"method"}, {"qualified_name": "ArrayList.remove"}),
            ({"method"}, {"qualified_name": "ArrayList.clear", "alias": ["clear"]}),
            ({"method"}, {"qualified_name": "List.clear",
                          "alias": ["clear", "List.clear", "List clear"]}),
        )
        for labels, properties in node_specs:
            graph_data.add_node(labels, properties)

        by_name = graph_data.find_nodes_by_property(property_name="qualified_name",
                                                    property_value="List.clear")
        print(by_name)
        self.assertIsNotNone(by_name)
        self.assertEqual(len(by_name), 1)
        self.assertEqual(by_name[0][GraphData.DEFAULT_KEY_NODE_ID], 5)

        by_alias = graph_data.find_nodes_by_property(property_name="alias",
                                                     property_value="clear")
        print(by_alias)
        self.assertIsNotNone(by_alias)
        self.assertEqual(len(by_alias), 2)
コード例 #18
0
ファイル: neo4j.py プロジェクト: FudanSELab/kgdt
    def graphdata2csv(csv_folder, graph: GraphData, node_file_id=GraphData.DEFAULT_KEY_NODE_ID,
                      csv_labels=GraphData.DEFAULT_KEY_NODE_LABELS,
                      relation_file_start_id=GraphData.DEFAULT_KEY_RELATION_START_ID,
                      relation_file_end_id=GraphData.DEFAULT_KEY_RELATION_END_ID,
                      relation_file_type=GraphData.DEFAULT_KEY_RELATION_TYPE,
                      only_one_relation_file=True):
        '''
        Export a GraphData into CSV files.

        One node file is written per distinct label set, named after the sorted labels joined
        with '_' (spaces removed). Each node file has an id column, a labels column and one column
        per property seen on nodes with that label set. Relations go either into a single
        relations.csv or into one file per relation type.

        :param csv_folder: path of the folder the CSV files are written into
        :param graph: the GraphData to export
        :param node_file_id: name of the id column in the node files
        :param csv_labels: name of the labels column in the node files
        :param relation_file_start_id: name of the start node id column in the relation files
        :param relation_file_end_id: name of the end node id column in the relation files
        :param relation_file_type: name of the relation type column in the relation files
        :param only_one_relation_file: True to write all relations into relations.csv,
                                       False to write one file per relation type
        :return: None
        '''
        # node file name -> ids of the nodes written to that file
        csvfilename2ids = {}
        # node file name -> property names, kept in a dict used as an insertion-ordered set
        csvfilename2property_name = {}

        for node_id in graph.get_node_ids():
            node = graph.get_node_info_dict(node_id)
            labels = node.get(GraphData.DEFAULT_KEY_NODE_LABELS)
            # Equal label sets are not guaranteed to iterate in the same order, so sort them to
            # always map the same label set to the same file name.
            csvfilename = '_'.join(sorted(labels))
            if csvfilename not in csvfilename2ids:
                csvfilename2ids[csvfilename] = []
                csvfilename2property_name[csvfilename] = {}
            csvfilename2ids[csvfilename].append(node_id)
            csvfilename2property_name[csvfilename].update(
                dict.fromkeys(node.get(GraphData.DEFAULT_KEY_NODE_PROPERTIES).keys()))

        node_count = 0
        for csvfilename, node_ids in csvfilename2ids.items():
            property_names = list(csvfilename2property_name[csvfilename])
            # dict.fromkeys mirrors the per-row dict below: a property named like the id/labels
            # column shares that column instead of producing a duplicate header entry.
            header = list(dict.fromkeys([node_file_id, csv_labels] + property_names))
            node_file = os.path.join(csv_folder, '{}.{}'.format(csvfilename.replace(' ', ''), 'csv'))
            with open(node_file, 'w', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile, delimiter=',')
                writer.writerow(header)
                for node_id in node_ids:
                    node = graph.get_node_info_dict(node_id)
                    if not node:
                        continue
                    node_properties = node.get(GraphData.DEFAULT_KEY_NODE_PROPERTIES)
                    node_dic = {node_file_id: node.get(GraphData.DEFAULT_KEY_NODE_ID),
                                csv_labels: node.get(GraphData.DEFAULT_KEY_NODE_LABELS)}
                    for property_name in property_names:
                        node_dic[property_name] = node_properties.get(property_name)
                    writer.writerow(node_dic.values())
                    node_count = node_count + 1
        print("一共导入csv的节点个数:   ", node_count)

        relation_header = [relation_file_start_id, relation_file_type, relation_file_end_id]

        def write_relation_file(path, types):
            # Write the header plus every relation of the given types; return the number of rows.
            written = 0
            with open(path, 'w', newline='', encoding='utf-8') as relation_csvfile:
                relation_writer = csv.writer(relation_csvfile, delimiter=',')
                relation_writer.writerow(relation_header)
                for relation_type in types:
                    for relation_pair in graph.get_relations(relation_type=relation_type):
                        relation_writer.writerow(
                            [int(relation_pair[0]), relation_pair[1], int(relation_pair[2])])
                        written += 1
            return written

        relation_types = graph.get_all_relation_types()
        if only_one_relation_file:
            relation_count = write_relation_file(os.path.join(csv_folder, 'relations.csv'), relation_types)
            print("一共导入csv的关系个数:   ", relation_count)
            print("一共生成", len(csvfilename2ids), "个csv节点文件和", 1, "个csv关系文件")
        else:
            relation_count = 0
            for relation_type in relation_types:
                relation_file = os.path.join(csv_folder, '{}.{}'.format(relation_type.replace(' ', ''), 'csv'))
                relation_count += write_relation_file(relation_file, [relation_type])
            print("一共导入csv的关系个数:   ", relation_count)
            print("一共生成", len(csvfilename2ids), "个csv节点文件和", len(relation_types), "个csv关系文件")
コード例 #19
0
ファイル: base.py プロジェクト: FudanSELab/kgdt
class KGBuildPipeline:
    """
    An ordered pipeline of named Components sharing one GraphData and one MultiFieldDocumentCollection.

    run() executes the components in order; listeners can be attached to run before or after a
    specific component.
    """

    def __init__(self):
        self.__name2component = {}
        self.__component_order = []
        self.__graph_data = GraphData()
        self.__doc_collection = MultiFieldDocumentCollection()
        # component name -> list of PipelineListener
        self.__before_run_component_listeners = {}
        self.__after_run_component_listeners = {}

    def __repr__(self):
        return str(self.__component_order)

    def exist_component(self, component_name):
        """
        check whether the component exist in the Pipeline
        :param component_name: the name of component
        :return: True, exist, False, not exist.
        """
        return component_name in self.__component_order

    def add_before_listener(self, component_name, listener: PipelineListener):
        """
        add a new PipelineListener running before a specific component
        :param component_name: the name of the component
        :param listener: the PipelineListener
        :return:
        :raises ComponentNotExistError: if the component is not in the pipeline
        """
        if not self.exist_component(component_name):
            raise ComponentNotExistError(component_name)
        self.__before_run_component_listeners.setdefault(component_name, []).append(listener)

    def add_after_listener(self, component_name, listener: PipelineListener):
        """
        add a new PipelineListener running after a specific component
        :param component_name: the name of the component
        :param listener: the PipelineListener
        :return:
        :raises ComponentNotExistError: if the component is not in the pipeline
        """
        # same validation as add_before_listener
        if not self.exist_component(component_name):
            raise ComponentNotExistError(component_name)
        self.__after_run_component_listeners.setdefault(component_name, []).append(listener)

    def __get_component_order(self, name):
        """
        get the order of the specific component
        :param name: the specific component
        :return: the order, from 0 to num_of_components() - 1; -1 if the component does not exist
        """
        # read-only lookup: the component order must not be modified here
        try:
            return self.__component_order.index(name)
        except ValueError:
            return -1

    def __allocate_order_for_new_component(self, before=None, after=None):
        """
        try to allocate the right position for the new component.
        A `before`/`after` name that is not in the pipeline is ignored.
        :param before: the component this new component must run before
        :param after: the component this new component must run after
        :return: the insert position (just before `before`, otherwise the end of the pipeline),
                 -1 if no position satisfies both constraints.
        """
        min_order = 0
        max_order = self.num_of_components()

        if before is not None:
            before_order = self.__get_component_order(before)
            if before_order != -1:
                max_order = before_order

        if after is not None:
            after_order = self.__get_component_order(after)
            if after_order != -1:
                min_order = after_order + 1

        if min_order > max_order:
            return -1
        return max_order

    def add_component(self,
                      name,
                      component: Component,
                      before=None,
                      after=None,
                      **config):
        """
        add a new component to this pipeline with given name. In a pipeline, the component name should be unique.
        NOTE(review): uniqueness is not enforced here; adding an existing name inserts it a second time.
        :param name: the name of this new component
        :param component: the component instance
        :param before: the component name this new component must run before
        :param after: the component name this new component must run after
        :param config: the other config, save for update
        :return:
        :raises ComponentOrderError: if no position satisfies both `before` and `after`
        """
        order = self.__allocate_order_for_new_component(before=before,
                                                        after=after)
        if order == -1:
            raise ComponentOrderError("Can't not find a right order for %s" %
                                      name)

        component.set_graph_data(self.__graph_data)
        component.set_doc_collection(self.__doc_collection)
        self.__name2component[name] = component

        self.__component_order.insert(order, name)

    def check(self):
        """
        check whether the components in the pipeline are set up correctly: every component's dependent
        entities, relations and document fields must be provided by the current data or by a component
        that runs earlier.
        :return: True the pipeline is correct.
        :raises ComponentDependencyError: for the first component with a missing dependency.
        """
        current_entities = self.get_provided_entities()
        current_relations = self.get_provided_relations()
        current_document_fields = self.get_provided_document_fields()

        for component_name, component in self.get_component_name_with_component_pair_by_order():
            missing_entities = component.dependent_entities() - current_entities
            if missing_entities:
                raise ComponentDependencyError(component_name, missing_entities)
            current_entities.update(component.provided_entities())

            missing_relations = component.dependent_relations() - current_relations
            if missing_relations:
                raise ComponentDependencyError(component_name, missing_relations)
            current_relations.update(component.provided_relations())

            missing_fields = component.dependent_document_fields() - current_document_fields
            if missing_fields:
                raise ComponentDependencyError(component_name, missing_fields)
            current_document_fields.update(component.provided_document_fields())

        return True

    def get_provided_document_fields(self):
        """
        get the provided document field set for the pipeline from the current DocumentCollection.
        If the pipeline start from empty state, this method will return empty set
        :return: a new set, safe for the caller to modify
        """
        return set(self.__doc_collection.get_field_set())

    def get_provided_relations(self):
        """
        get the provided relation type set for the pipeline from the current GraphData.
        If the pipeline start from empty state, this method will return empty set
        :return: a new set, safe for the caller to modify
        """
        # copy: check() updates the returned set, which must not touch the GraphData's own collection
        return set(self.__graph_data.get_all_relation_types())

    def get_provided_entities(self):
        """
        get the provided entity type set for the pipeline from the current GraphData.
        If the pipeline start from empty state, this method will return empty set
        :return: a new set, safe for the caller to modify
        """
        return set(self.__graph_data.get_all_labels())

    def get_components_by_order(self):
        """:return: the component instances in running order"""
        return [self.__name2component[name] for name in self.__component_order]

    def get_component_name_with_component_pair_by_order(self):
        """:return: (name, component) pairs in running order"""
        return [(name, self.__name2component[name]) for name in self.__component_order]

    def run(self, **config):
        """
        check the pipeline, then run every component in order, calling its listeners around it.
        :param config: passed to the listeners
        """
        self.check()
        print("start running the pipeline")
        for component_name in self.__component_order:
            component: Component = self.__name2component[component_name]
            self.before_run_component(component_name, **config)
            component.before_run()
            component.run()
            component.after_run()
            self.after_run_component(component_name, **config)

        print("finish running the pipeline")

    def before_run_component(self, component_name, **config):
        print("start running with name=%r in the pipeline" % component_name)
        for listener in self.__before_run_component_listeners.get(
                component_name, []):
            listener.on_before_run_component(component_name, self, **config)

    def after_run_component(self, component_name, **config):
        print("finish running with name=%r in the pipeline\n" % component_name)
        for listener in self.__after_run_component_listeners.get(
                component_name, []):
            listener.on_after_run_component(component_name, self, **config)

    def save(self, graph_path=None, doc_path=None):
        """
        save the graph data and the document collection after all the components have run.
        A None path skips saving that part.
        :param graph_path: the path to save the GraphData
        :param doc_path: the path to save the DocumentCollection
        :return:
        """
        self.save_graph(path=graph_path)
        self.save_doc(path=doc_path)

    def save_graph(self, path):
        if path is None:
            return
        self.__graph_data.save(path)

    def save_doc(self, path):
        if path is None:
            return
        self.__doc_collection.save(path)

    def load_graph(self, graph_data_path):
        self.__graph_data = GraphData.load(graph_data_path)
        # update component graph data
        for component_name in self.__component_order:
            component: Component = self.__name2component[component_name]
            component.set_graph_data(self.__graph_data)

        print("load graph")

    def load_doc(self, document_collection_path):
        self.__doc_collection = MultiFieldDocumentCollection.load(
            document_collection_path)
        # update component doc_collection
        for component_name in self.__component_order:
            component: Component = self.__name2component[component_name]
            component.set_doc_collection(self.__doc_collection)

        print("load doc collection")

    def num_of_components(self):
        return len(self.__component_order)
コード例 #20
0
    def get_graph(self):
        """
        Build the fixture GraphData: two "method" nodes sharing the name ArrayList.add (with aliases),
        followed by ArrayList.pop, ArrayList.remove and ArrayList.clear.
        """
        graph_data = GraphData()
        node_specs = [
            ({"method"}, {"qualified_name": "ArrayList.add()",
                          "name": "ArrayList.add",
                          "alias": ["ArrayList.add1", "add()", "add"]}),
            ({"method"}, {"qualified_name": "ArrayList.add(int)",
                          "name": "ArrayList.add",
                          "alias": ["ArrayList.add2", "add()", "add"]}),
            ({"override method"}, {"qualified_name": "ArrayList.pop"}),
            ({"method"}, {"qualified_name": "ArrayList.remove"}),
            ({"method"}, {"qualified_name": "ArrayList.clear"}),
        ]
        for labels, properties in node_specs:
            graph_data.add_node(labels, properties)

        return graph_data