コード例 #1
0
def read_graph_pyg(raw_dir,
                   add_inverse_edge=False,
                   additional_node_files=None,
                   additional_edge_files=None,
                   binary=False):
    """Read raw graphs from ``raw_dir`` and convert them into PyG ``Data`` objects.

    Every converted graph is additionally passed through ``add_order_info_01``
    (DAGNN ordering information) and given a float ``len_longest_path``
    attribute.

    Args:
        raw_dir: Directory containing the raw graph files.
        add_inverse_edge: Forwarded unchanged to the raw reader.
        additional_node_files: Keys of extra node-level arrays to copy onto
            each ``Data`` object. ``None`` (the default) means no extra keys.
        additional_edge_files: Keys of extra edge-level arrays to copy onto
            each ``Data`` object. ``None`` (the default) means no extra keys.
        binary: If True, read the ``npz`` format via ``read_binary_graph_raw``;
            otherwise read CSV files via ``read_csv_graph_raw``.

    Returns:
        list: One ``Data`` object per raw graph, in reader order.
    """
    # ``None`` defaults avoid one mutable list being shared by every call.
    if additional_node_files is None:
        additional_node_files = []
    if additional_edge_files is None:
        additional_edge_files = []

    if binary:
        # npz
        graph_list = read_binary_graph_raw(raw_dir, add_inverse_edge)
    else:
        # csv
        graph_list = read_csv_graph_raw(
            raw_dir,
            add_inverse_edge,
            additional_node_files=additional_node_files,
            additional_edge_files=additional_edge_files)

    pyg_graph_list = []

    print('Converting graphs into PyG objects...')

    for graph in tqdm(graph_list):
        g = Data()
        g.__num_nodes__ = graph['num_nodes']
        g.edge_index = torch.from_numpy(graph['edge_index'])

        # Entries are removed from the raw dict once they have been copied
        # onto ``g``.
        del graph['num_nodes']
        del graph['edge_index']

        if graph['edge_feat'] is not None:
            g.edge_attr = torch.from_numpy(graph['edge_feat'])
            del graph['edge_feat']

        if graph['node_feat'] is not None:
            g.x = torch.from_numpy(graph['node_feat'])
            del graph['node_feat']

        # Extra arrays are stored on ``g`` under their raw key names.
        for key in additional_node_files:
            g[key] = torch.from_numpy(graph[key])
            del graph[key]

        for key in additional_edge_files:
            g[key] = torch.from_numpy(graph[key])
            del graph[key]

        pyg_graph_list.append(g)

        # DAGNN ordering info; ``_bi_layer_idx0`` read below is assumed to be
        # set by this call -- TODO confirm in add_order_info_01.
        add_order_info_01(g)
        # length of longest path
        # layer ids start with 0 so max, gives actual path length and -1 is not necessary
        g.len_longest_path = float(torch.max(g._bi_layer_idx0).item())

    return pyg_graph_list
コード例 #2
0
def read_graph_pyg(raw_dir,
                   add_inverse_edge=False,
                   additional_node_files=None,
                   additional_edge_files=None,
                   binary=False):
    """Read raw graphs from ``raw_dir`` and convert them into PyG ``Data`` objects.

    Args:
        raw_dir: Directory containing the raw graph files.
        add_inverse_edge: Forwarded unchanged to the raw reader.
        additional_node_files: Keys of extra node-level arrays to copy onto
            each ``Data`` object. ``None`` (the default) means no extra keys.
        additional_edge_files: Keys of extra edge-level arrays to copy onto
            each ``Data`` object. ``None`` (the default) means no extra keys.
        binary: If True, read the ``npz`` format via ``read_binary_graph_raw``;
            otherwise read CSV files via ``read_csv_graph_raw``.

    Returns:
        list: One ``Data`` object per raw graph, in reader order.
    """
    # ``None`` defaults avoid one mutable list being shared by every call.
    if additional_node_files is None:
        additional_node_files = []
    if additional_edge_files is None:
        additional_edge_files = []

    if binary:
        # npz
        graph_list = read_binary_graph_raw(raw_dir, add_inverse_edge)
    else:
        # csv
        graph_list = read_csv_graph_raw(
            raw_dir,
            add_inverse_edge,
            additional_node_files=additional_node_files,
            additional_edge_files=additional_edge_files)

    pyg_graph_list = []

    print('Converting graphs into PyG objects...')

    for graph in tqdm(graph_list):
        g = Data()
        g.__num_nodes__ = graph['num_nodes']
        g.edge_index = torch.from_numpy(graph['edge_index'])

        # Entries are removed from the raw dict once they have been copied
        # onto ``g``.
        del graph['num_nodes']
        del graph['edge_index']

        if graph['edge_feat'] is not None:
            g.edge_attr = torch.from_numpy(graph['edge_feat'])
            del graph['edge_feat']

        if graph['node_feat'] is not None:
            g.x = torch.from_numpy(graph['node_feat'])
            del graph['node_feat']

        # Extra arrays are stored on ``g`` under their raw key names.
        for key in additional_node_files:
            g[key] = torch.from_numpy(graph[key])
            del graph[key]

        for key in additional_edge_files:
            g[key] = torch.from_numpy(graph[key])
            del graph[key]

        pyg_graph_list.append(g)

    return pyg_graph_list
コード例 #3
0
def read_graph_dgl(raw_dir,
                   add_inverse_edge=False,
                   additional_node_files=None,
                   additional_edge_files=None,
                   binary=False):
    """Read raw graphs from ``raw_dir`` and convert them into DGL graphs.

    Args:
        raw_dir: Directory containing the raw graph files.
        add_inverse_edge: Forwarded unchanged to the raw reader.
        additional_node_files: Keys of extra node-level arrays to attach as
            ``ndata``. ``None`` (the default) means no extra keys.
        additional_edge_files: Keys of extra edge-level arrays to attach as
            ``edata``. ``None`` (the default) means no extra keys.
        binary: If True, read the ``npz`` format via ``read_binary_graph_raw``;
            otherwise read CSV files via ``read_csv_graph_raw``.

    Returns:
        list: One DGL graph per raw graph, in reader order. Default node and
        edge features, when present, are stored under ``'feat'``.
    """
    # ``None`` defaults avoid one mutable list being shared by every call.
    if additional_node_files is None:
        additional_node_files = []
    if additional_edge_files is None:
        additional_edge_files = []

    if binary:
        # npz
        graph_list = read_binary_graph_raw(raw_dir, add_inverse_edge)
    else:
        # csv
        graph_list = read_csv_graph_raw(
            raw_dir,
            add_inverse_edge,
            additional_node_files=additional_node_files,
            additional_edge_files=additional_edge_files)

    dgl_graph_list = []

    print('Converting graphs into DGL objects...')

    for graph in tqdm(graph_list):
        # Row 0 of edge_index is the source list, row 1 the destination list.
        g = dgl.graph((graph['edge_index'][0], graph['edge_index'][1]),
                      num_nodes=graph['num_nodes'])

        if graph['edge_feat'] is not None:
            g.edata['feat'] = torch.from_numpy(graph['edge_feat'])

        if graph['node_feat'] is not None:
            g.ndata['feat'] = torch.from_numpy(graph['node_feat'])

        # key[5:] drops the first five characters of the key -- presumably
        # the 'node_' / 'edge_' prefix; confirm against the raw file naming.
        for key in additional_node_files:
            g.ndata[key[5:]] = torch.from_numpy(graph[key])

        for key in additional_edge_files:
            g.edata[key[5:]] = torch.from_numpy(graph[key])

        dgl_graph_list.append(g)

    return dgl_graph_list
コード例 #4
0
    def pre_process(self):
        """Load the cached graph, or download/parse the raw files and cache them.

        If ``<root>/processed/data_processed`` exists, ``self.graph`` is
        restored from it with ``load_pickle``. Otherwise the raw files are
        downloaded when missing (after confirmation via ``decide_download``),
        parsed with the reader matching ``self.is_hetero`` / ``self.binary``,
        the first (only) graph is stored in ``self.graph``, and it is written
        back to the cache path with ``dump_pickle``.

        Raises:
            SystemExit: With code -1 if the user declines the download.
        """
        import sys

        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir, 'data_processed')

        if osp.exists(pre_processed_file_path):
            self.graph = load_pickle(pre_processed_file_path)
            return

        raw_dir = osp.join(self.root, 'raw')

        ### check download: this file's presence marks a usable raw folder
        if self.binary:
            # npz format
            marker = 'edge_index_dict.npz' if self.is_hetero else 'data.npz'
        else:
            # csv file
            marker = ('triplet-type-list.csv.gz'
                      if self.is_hetero else 'edge.csv.gz')

        if not osp.exists(osp.join(raw_dir, marker)):
            url = self.meta_info['url']
            if not decide_download(url):
                print('Stop download.')
                sys.exit(-1)
            path = download_url(url, self.original_root)
            extract_zip(path, self.original_root)
            os.unlink(path)
            # Delete the folder if it exists; removal errors are ignored
            # (best effort) without also swallowing KeyboardInterrupt.
            shutil.rmtree(self.root, ignore_errors=True)
            shutil.move(osp.join(self.original_root, self.download_name),
                        self.root)

        ### pre-process and save
        add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True'

        # The literal string 'None' in meta_info means "no additional files".
        node_files = self.meta_info['additional node files']
        edge_files = self.meta_info['additional edge files']
        additional_node_files = ([] if node_files == 'None' else
                                 node_files.split(','))
        additional_edge_files = ([] if edge_files == 'None' else
                                 edge_files.split(','))

        # Each reader returns a list; only a single graph is kept.
        if self.is_hetero:
            if self.binary:
                self.graph = read_binary_heterograph_raw(
                    raw_dir, add_inverse_edge=add_inverse_edge)[0]
            else:
                self.graph = read_csv_heterograph_raw(
                    raw_dir,
                    add_inverse_edge=add_inverse_edge,
                    additional_node_files=additional_node_files,
                    additional_edge_files=additional_edge_files)[0]
        else:
            if self.binary:
                self.graph = read_binary_graph_raw(
                    raw_dir, add_inverse_edge=add_inverse_edge)[0]
            else:
                self.graph = read_csv_graph_raw(
                    raw_dir,
                    add_inverse_edge=add_inverse_edge,
                    additional_node_files=additional_node_files,
                    additional_edge_files=additional_edge_files)[0]

        print('Saving...')
        # Nothing in this method guarantees processed_dir exists; create it
        # (a no-op when the download already ships it).
        os.makedirs(processed_dir, exist_ok=True)
        dump_pickle(self.graph, pre_processed_file_path)
コード例 #5
0
    def pre_process(self):
        """Load cached graph and node labels, or build them from the raw files.

        If ``<root>/processed/data_processed`` exists, ``self.graph`` and
        ``self.labels`` are restored from it with ``load_pickle``. Otherwise
        the raw files are downloaded when missing (after confirmation via
        ``decide_download``), the graph is parsed with the reader matching
        ``self.is_hetero`` / ``self.binary``, node labels are read from
        ``node-label.npz`` / ``node-label.csv.gz`` (heterogeneous CSV labels
        via ``read_node_label_hetero``), and
        ``{'graph': ..., 'labels': ...}`` is written with ``dump_pickle``.

        Raises:
            SystemExit: With code -1 if the user declines the download.
        """
        import sys

        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir, 'data_processed')

        if osp.exists(pre_processed_file_path):
            loaded_dict = load_pickle(pre_processed_file_path)
            self.graph, self.labels = loaded_dict['graph'], loaded_dict[
                'labels']
            return

        raw_dir = osp.join(self.root, 'raw')

        ### check download: this file's presence marks a usable raw folder
        if self.binary:
            # npz format
            marker = 'edge_index_dict.npz' if self.is_hetero else 'data.npz'
        else:
            # csv file
            marker = ('triplet-type-list.csv.gz'
                      if self.is_hetero else 'edge.csv.gz')

        if not osp.exists(osp.join(raw_dir, marker)):
            url = self.meta_info['url']
            if not decide_download(url):
                print('Stop download.')
                sys.exit(-1)
            path = download_url(url, self.original_root)
            extract_zip(path, self.original_root)
            os.unlink(path)
            # Delete the folder if it exists; removal errors are ignored
            # (best effort) without also swallowing KeyboardInterrupt.
            shutil.rmtree(self.root, ignore_errors=True)
            shutil.move(osp.join(self.original_root, self.download_name),
                        self.root)

        ### pre-process and save
        add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True'

        # The literal string 'None' in meta_info means "no additional files".
        node_files = self.meta_info['additional node files']
        edge_files = self.meta_info['additional edge files']
        additional_node_files = ([] if node_files == 'None' else
                                 node_files.split(','))
        additional_edge_files = ([] if edge_files == 'None' else
                                 edge_files.split(','))

        # Each reader returns a list; only a single graph is kept.
        if self.is_hetero:
            if self.binary:
                self.graph = read_binary_heterograph_raw(
                    raw_dir, add_inverse_edge=add_inverse_edge)[0]
                # The NpzFile holds the archive open; ``with`` closes it.
                with np.load(osp.join(raw_dir,
                                      'node-label.npz')) as label_file:
                    self.labels = {
                        key: label_file[key]
                        for key in list(label_file.keys())
                    }
            else:
                self.graph = read_csv_heterograph_raw(
                    raw_dir,
                    add_inverse_edge=add_inverse_edge,
                    additional_node_files=additional_node_files,
                    additional_edge_files=additional_edge_files)[0]
                self.labels = read_node_label_hetero(raw_dir)
        else:
            if self.binary:
                self.graph = read_binary_graph_raw(
                    raw_dir, add_inverse_edge=add_inverse_edge)[0]
                # Indexing an NpzFile reads the array into memory, so it is
                # safe to close the archive afterwards.
                with np.load(osp.join(raw_dir,
                                      'node-label.npz')) as label_file:
                    self.labels = label_file['node_label']
            else:
                self.graph = read_csv_graph_raw(
                    raw_dir,
                    add_inverse_edge=add_inverse_edge,
                    additional_node_files=additional_node_files,
                    additional_edge_files=additional_edge_files)[0]
                self.labels = pd.read_csv(osp.join(raw_dir,
                                                   'node-label.csv.gz'),
                                          compression='gzip',
                                          header=None).values

        print('Saving...')
        # Nothing in this method guarantees processed_dir exists; create it
        # (a no-op when the download already ships it).
        os.makedirs(processed_dir, exist_ok=True)
        dump_pickle({
            'graph': self.graph,
            'labels': self.labels
        }, pre_processed_file_path)
コード例 #6
0
    def pre_process(self):
        """Load cached graphs and graph labels, or build them from the raw files.

        If ``<root>/processed/data_processed`` exists, ``self.graphs`` and
        ``self.labels`` are restored from it with ``torch.load``. Otherwise the
        raw files are downloaded when missing (after confirmation via
        ``decide_download``), all graphs are parsed with the binary or CSV
        reader, graph labels are read (split into lists of subtokens when
        ``self.task_type == 'subtoken prediction'``), and
        ``{'graphs': ..., 'labels': ...}`` is written with ``torch.save``.

        Raises:
            SystemExit: With code -1 if the user declines the download.
        """
        import sys

        processed_dir = osp.join(self.root, 'processed')
        raw_dir = osp.join(self.root, 'raw')
        pre_processed_file_path = osp.join(processed_dir, 'data_processed')

        if osp.exists(pre_processed_file_path):
            # torch.load's second positional parameter is ``map_location``,
            # not a file mode, so no 'rb' argument is passed.
            # NOTE(review): torch >= 2.6 defaults to weights_only=True, which
            # may reject the non-tensor (numpy) payload saved below -- confirm
            # against the torch version in use.
            loaded_dict = torch.load(pre_processed_file_path)
            self.graphs, self.labels = loaded_dict['graphs'], loaded_dict[
                'labels']
            return

        ### check download: this file's presence marks a usable raw folder
        marker = 'data.npz' if self.binary else 'edge.csv.gz'

        ### download
        if not osp.exists(osp.join(raw_dir, marker)):
            url = self.meta_info['url']
            if not decide_download(url):
                print('Stop download.')
                sys.exit(-1)
            path = download_url(url, self.original_root)
            extract_zip(path, self.original_root)
            os.unlink(path)
            # Delete the folder if it exists; removal errors are ignored
            # (best effort) without also swallowing KeyboardInterrupt.
            shutil.rmtree(self.root, ignore_errors=True)
            shutil.move(osp.join(self.original_root, self.download_name),
                        self.root)

        ### preprocess
        add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True'

        # The literal string 'None' in meta_info means "no additional files".
        node_files = self.meta_info['additional node files']
        edge_files = self.meta_info['additional edge files']
        additional_node_files = ([] if node_files == 'None' else
                                 node_files.split(','))
        additional_edge_files = ([] if edge_files == 'None' else
                                 edge_files.split(','))

        if self.binary:
            self.graphs = read_binary_graph_raw(
                raw_dir, add_inverse_edge=add_inverse_edge)
        else:
            self.graphs = read_csv_graph_raw(
                raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files)

        label_csv_path = osp.join(raw_dir, 'graph-label.csv.gz')
        if self.task_type == 'subtoken prediction':
            labels_joined = pd.read_csv(label_csv_path,
                                        compression='gzip',
                                        header=None).values
            # need to split each element into subtokens
            self.labels = [str(row[0]).split(' ') for row in labels_joined]
        elif self.binary:
            # Indexing an NpzFile reads the array into memory, so it is safe
            # to close the archive afterwards.
            with np.load(osp.join(raw_dir, 'graph-label.npz')) as label_file:
                self.labels = label_file['graph_label']
        else:
            self.labels = pd.read_csv(label_csv_path,
                                      compression='gzip',
                                      header=None).values

        print('Saving...')
        # Nothing in this method guarantees processed_dir exists; create it
        # (a no-op when the download already ships it).
        os.makedirs(processed_dir, exist_ok=True)
        torch.save({
            'graphs': self.graphs,
            'labels': self.labels
        },
                   pre_processed_file_path,
                   pickle_protocol=4)
コード例 #7
0
    def _save_graph_list_homo(self, graph_list):
        dict_keys = graph_list[0].keys()
        # check necessary keys
        if not 'edge_index' in dict_keys:
            raise RuntimeError(
                'edge_index needs to be provided in graph objects')
        if not 'num_nodes' in dict_keys:
            raise RuntimeError(
                'num_nodes needs to be provided in graph objects')

        print(dict_keys)

        data_dict = {}
        # Store the following keys
        # - edge_index (necessary)
        # - num_nodes_list (necessary)
        # - num_edges_list (necessary)
        # - node_** (optional, node_feat is the default node features)
        # - edge_** (optional, edge_feat is the default edge features)

        # saving num_nodes_list
        num_nodes_list = np.array([graph['num_nodes']
                                   for graph in graph_list]).astype(np.int64)
        data_dict['num_nodes_list'] = num_nodes_list

        # saving edge_index and num_edges_list
        print('Saving edge_index')
        edge_index = np.concatenate(
            [graph['edge_index'] for graph in graph_list],
            axis=1).astype(np.int64)
        num_edges_list = np.array([
            graph['edge_index'].shape[1] for graph in graph_list
        ]).astype(np.int64)

        if edge_index.shape[0] != 2:
            raise RuntimeError('edge_index must have shape (2, num_edges)')

        data_dict['edge_index'] = edge_index
        data_dict['num_edges_list'] = num_edges_list

        for key in dict_keys:
            if key == 'edge_index' or key == 'num_nodes':
                continue
            if graph_list[0][key] is None:
                continue

            if 'node_' == key[:5]:
                # make sure saved in np.int64 or np.float32
                dtype = np.int64 if 'int' in str(
                    graph_list[0][key].dtype) else np.float32
                # check num_nodes
                for i in range(len(graph_list)):
                    if len(graph_list[i][key]) != num_nodes_list[i]:
                        raise RuntimeError(f'num_nodes mistmatches with {key}')

                cat_feat = np.concatenate([graph[key] for graph in graph_list],
                                          axis=0).astype(dtype)
                data_dict[key] = cat_feat

            elif 'edge_' == key[:5]:
                # make sure saved in np.int64 or np.float32
                dtype = np.int64 if 'int' in str(
                    graph_list[0][key].dtype) else np.float32
                # check num_edges
                for i in range(len(graph_list)):
                    if len(graph_list[i][key]) != num_edges_list[i]:
                        raise RuntimeError(f'num_edges mistmatches with {key}')

                cat_feat = np.concatenate([graph[key] for graph in graph_list],
                                          axis=0).astype(dtype)
                data_dict[key] = cat_feat

            else:
                raise RuntimeError(
                    f'Keys in graph object should start from either \'node_\' or \'edge_\', but \'{key}\' given.'
                )

        print('Saving all the files!')
        np.savez_compressed(osp.join(self.raw_dir, 'data.npz'), **data_dict)
        print('Validating...')
        # testing
        print('Reading saved files')
        graph_list_read = read_binary_graph_raw(self.raw_dir, False)

        print('Checking read graphs and given graphs are the same')
        for i in tqdm(range(len(graph_list))):
            # assert(graph_list[i].keys() == graph_list_read[i].keys())
            for key in graph_list[i].keys():
                if graph_list[i][key] is not None:
                    if isinstance(graph_list[i][key], np.ndarray):
                        assert (np.allclose(graph_list[i][key],
                                            graph_list_read[i][key],
                                            rtol=1e-4,
                                            atol=1e-4,
                                            equal_nan=True))
                    else:
                        assert (graph_list[i][key] == graph_list_read[i][key])

        del graph_list_read