コード例 #1
0
 def create_multigraph(self, mg_dict):
     """Build a MultiGraph from ``mg_dict``, sort it and pickle it to disk.

     The graph is created under the fixed name ``"dblp"``, ordered with
     ``mg.sort_by(self.sort_gnx)`` and written to
     ``./dataset/<dataset_name>/pkl/mg_<dataset_name>.pkl``.

     Args:
         mg_dict: passed to ``MultiGraph`` as ``graphs_source``; presumably a
             mapping of graph id -> edge list (TODO confirm against caller).

     Returns:
         The sorted MultiGraph that was pickled.
     """
     # Convert once so both path components agree; the file name previously
     # concatenated ``self.dataset_name`` without ``str()``, which raises
     # TypeError for non-str names even though the directory part used str().
     dataset_name = str(self.dataset_name)
     pkl_dir = "./dataset/" + dataset_name + "/pkl/"
     mg = MultiGraph("dblp", graphs_source=mg_dict)
     mg.sort_by(self.sort_gnx)
     # The logger is switched off while pickling (presumably it is not
     # picklable); ``finally`` re-enables it even if the dump fails.
     mg.suspend_logger()
     try:
         check(pkl_dir)  # presumably ensures the directory exists - TODO confirm
         # ``with`` closes (and flushes) the file instead of leaking the handle.
         with open(pkl_dir + "mg_" + dataset_name + ".pkl", "wb") as pkl_file:
             pickle.dump(mg, pkl_file)
     finally:
         mg.wake_logger()
     return mg
コード例 #2
0
    def multi_graph_by_window(self, window_size=None, start_time=0):
        """Yield deep-copied snapshots of a MultiGraph over a sliding window of intervals.

        Args:
            window_size: number of consecutive intervals kept in the graph. A falsy
                value (None or 0) means ``self._number_of_times``, i.e. no eviction.
            start_time: index into ``self._times`` of the first step.

        Yields:
            One deep copy per step ``i`` in ``range(start_time, self._number_of_times)``.
            For ``i >= 1`` the copy holds the edges of intervals
            ``max(0, i - window_size) .. i - 1``; for ``i == 0`` it holds interval 0.

        If ``start_time`` is negative or greater than ``self._number_of_times`` an error
        is logged and nothing is yielded (the check runs on first iteration, since this
        is a generator).
        """
        if start_time < 0 or start_time > self._number_of_times:
            self._logger.error("invalid start time = " + str(start_time) + ", total intervals = " + str(self._number_of_times))
            return
        window_size = window_size if window_size else self._number_of_times

        # Build the base graph only from the intervals inside the window that ends just
        # before ``start_time``. Starting from interval 0 regardless (as before) left
        # intervals older than ``start_time - window_size`` in the graph forever, because
        # the loop below evicts just one interval per step.
        first = max(0, start_time - window_size)
        mg = MultiGraph(self._database_name + "window", graphs_source=self._edge_list_dict[self._times[first]],
                        directed=self._directed)
        for i in range(first + 1, start_time):
            mg.add_edges(self._edge_list_dict[self._times[i]])

        for i in range(start_time, self._number_of_times):
            # The logger is suspended around the copy (presumably it cannot be
            # deep-copied); ``finally`` re-enables it even if the copy fails.
            mg.suspend_logger()
            try:
                snapshot = copy.deepcopy(mg)
            finally:
                mg.wake_logger()
            yield snapshot

            # Slide the window: evict the oldest interval, then add interval ``i``.
            to_remove = i - window_size
            if to_remove >= 0:
                mg.remove_edges(self._edge_list_dict[self._times[to_remove]])
            # NOTE(review): when start_time == 0 this re-adds interval 0, which the base
            # graph already contains - TODO confirm that is intended.
            mg.add_edges(self._edge_list_dict[self._times[i]])
コード例 #3
0
ファイル: dataset.py プロジェクト: itayl13/bilinear-learning
    def _build_multi_graph(self):
        """Build the dataset's MultiGraph and label mappings, or load them from cache.

        Reads the CSV at ``self._src_file_path`` and keeps only the rows whose index is
        below ``row_count * self._params.PERCENTAGE``. Edges (as ``(src, dst)`` string
        pairs) are grouped by the graph-id column. Labels are numbered in order of
        first appearance. The result is pickled to
        ``<base_dir>/<PKL_DIR>/<DATASET_NAME>_split_<PERCENTAGE>_mg.pkl``, and an
        existing file at that path is returned instead of rebuilding.

        Returns:
            tuple: ``(mg, labels, label_to_idx, idx_to_label)``.

            - ``labels`` maps graph id -> label index, taken from the graph's last
              kept row.
            - ``label_to_idx`` maps label -> index.
            - ``idx_to_label`` is the inverse list.
        """
        path_pkl = os.path.join(
            self._base_dir, PKL_DIR, self._params.DATASET_NAME + "_split_" +
            str(self._params.PERCENTAGE) + "_mg.pkl")
        if os.path.exists(path_pkl):
            # NOTE(review): unpickling can execute arbitrary code; this is only safe
            # while the cache file is written locally by this method.
            with open(path_pkl, "rb") as pkl_file:
                return pickle.load(pkl_file)

        multi_graph_dict = {}
        labels = {}
        label_to_idx = {}
        # Basic data csv holding all edges of all times.
        data_df = pd.read_csv(self._src_file_path)
        stop = data_df.shape[0] * self._params.PERCENTAGE

        for index, edge in data_df.iterrows():
            # ``index`` is 0-based (assumes read_csv's default RangeIndex), so rows with
            # ``index < stop`` form the requested fraction; the previous ``index > stop``
            # kept one extra row whenever ``stop`` was a whole number.
            if index >= stop:
                break
            graph_id = str(edge[self._params.GRAPH_NAME_COL])
            src = str(edge[self._params.SRC_COL])
            dst = str(edge[self._params.DST_COL])
            # Append in place; ``get(...) + [...]`` copied the whole list per edge.
            multi_graph_dict.setdefault(graph_id, []).append((src, dst))
            # A new label receives the next free index (len is evaluated before insert).
            label = edge[self._params.LABEL_COL]
            labels[graph_id] = label_to_idx.setdefault(label, len(label_to_idx))

        mg = MultiGraph(self._params.DATASET_NAME,
                        graphs_source=multi_graph_dict,
                        directed=self._params.DIRECTED,
                        logger=self._logger)
        idx_to_label = sorted(label_to_idx, key=label_to_idx.get)
        # The logger is suspended while pickling (presumably it is not picklable);
        # ``finally`` re-enables it even if the dump fails.
        mg.suspend_logger()
        try:
            with open(path_pkl, "wb") as pkl_file:
                pickle.dump((mg, labels, label_to_idx, idx_to_label), pkl_file)
        finally:
            mg.wake_logger()
        return mg, labels, label_to_idx, idx_to_label