Example #1
def read_heterograph_dgl(raw_dir,
                         add_inverse_edge=False,
                         additional_node_files=[],
                         additional_edge_files=[],
                         binary=False):

    if binary:
        # npz
        graph_list = read_binary_heterograph_raw(raw_dir, add_inverse_edge)
    else:
        # csv
        graph_list = read_csv_heterograph_raw(
            raw_dir,
            add_inverse_edge,
            additional_node_files=additional_node_files,
            additional_edge_files=additional_edge_files)

    dgl_graph_list = []

    print('Converting graphs into DGL objects...')

    for graph in tqdm(graph_list):
        g_dict = {}

        # add edge connectivity
        for triplet, edge_index in graph['edge_index_dict'].items():
            # edge_index has shape (2, num_edges); turn it into a list of (src, dst) pairs
            edge_tuple = [(i, j) for i, j in zip(edge_index[0], edge_index[1])]
            g_dict[triplet] = edge_tuple

        dgl_hetero_graph = dgl.heterograph(
            g_dict, num_nodes_dict=graph['num_nodes_dict'])

        if graph['edge_feat_dict'] is not None:
            for triplet in graph['edge_feat_dict'].keys():
                dgl_hetero_graph.edges[triplet].data[
                    'feat'] = torch.from_numpy(
                        graph['edge_feat_dict'][triplet])

        if graph['node_feat_dict'] is not None:
            for nodetype in graph['node_feat_dict'].keys():
                dgl_hetero_graph.nodes[nodetype].data[
                    'feat'] = torch.from_numpy(
                        graph['node_feat_dict'][nodetype])

        for key in additional_node_files:
            # key has the form 'node_<name>'; key[5:] strips the prefix for the feature name
            for nodetype in graph[key].keys():
                dgl_hetero_graph.nodes[nodetype].data[
                    key[5:]] = torch.from_numpy(graph[key][nodetype])

        for key in additional_edge_files:
            # key has the form 'edge_<name>'; key[5:] strips the prefix for the feature name
            for triplet in graph[key].keys():
                dgl_hetero_graph.edges[triplet].data[
                    key[5:]] = torch.from_numpy(graph[key][triplet])

        dgl_graph_list.append(dgl_hetero_graph)

    return dgl_graph_list
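
A minimal usage sketch, assuming an OGB-style raw directory; the path and the extra node file name below are hypothetical examples, not values taken from this code.

# Hedged usage sketch (raw_dir and 'node_year' are hypothetical examples).
dgl_graphs = read_heterograph_dgl(
    'dataset/ogbn_mag/raw',                 # directory holding the raw csv.gz / npz files
    add_inverse_edge=False,                 # do not add reverse edges
    additional_node_files=['node_year'],    # stored as .nodes[ntype].data['year'] per the prefix stripping above
    binary=False)                           # read csv.gz rather than npz
g = dgl_graphs[0]
print(g.ntypes, g.etypes)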
Example #2
def read_heterograph_pyg(raw_dir,
                         add_inverse_edge=False,
                         additional_node_files=[],
                         additional_edge_files=[],
                         binary=False):

    if binary:
        # npz
        graph_list = read_binary_heterograph_raw(raw_dir, add_inverse_edge)
    else:
        # csv
        graph_list = read_csv_heterograph_raw(
            raw_dir,
            add_inverse_edge,
            additional_node_files=additional_node_files,
            additional_edge_files=additional_edge_files)

    pyg_graph_list = []

    print('Converting graphs into PyG objects...')

    for graph in tqdm(graph_list):
        g = Data()

        g.__num_nodes__ = graph['num_nodes_dict']
        g.num_nodes_dict = graph['num_nodes_dict']

        # add edge connectivity
        g.edge_index_dict = {}
        for triplet, edge_index in graph['edge_index_dict'].items():
            g.edge_index_dict[triplet] = torch.from_numpy(edge_index)

        del graph['edge_index_dict']

        if graph['edge_feat_dict'] is not None:
            g.edge_attr_dict = {}
            for triplet in graph['edge_feat_dict'].keys():
                g.edge_attr_dict[triplet] = torch.from_numpy(
                    graph['edge_feat_dict'][triplet])

            del graph['edge_feat_dict']

        if graph['node_feat_dict'] is not None:
            g.x_dict = {}
            for nodetype in graph['node_feat_dict'].keys():
                g.x_dict[nodetype] = torch.from_numpy(
                    graph['node_feat_dict'][nodetype])

            del graph['node_feat_dict']

        for key in additional_node_files:
            g[key] = {}
            for nodetype in graph[key].keys():
                g[key][nodetype] = torch.from_numpy(graph[key][nodetype])

            del graph[key]

        for key in additional_edge_files:
            g[key] = {}
            for triplet in graph[key].keys():
                g[key][triplet] = torch.from_numpy(graph[key][triplet])

            del graph[key]

        pyg_graph_list.append(g)

    return pyg_graph_list
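
The same kind of usage sketch for the PyG variant; the path is again a hypothetical example.

# Hedged usage sketch (raw_dir is a hypothetical example).
pyg_graphs = read_heterograph_pyg('dataset/ogbn_mag/raw', binary=True)
g = pyg_graphs[0]
# The Data object carries per-type dictionaries rather than flat tensors.
print(g.num_nodes_dict)
print({triplet: idx.shape for triplet, idx in g.edge_index_dict.items()})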
Example #3
    def pre_process(self):
        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir, 'data_processed')

        if osp.exists(pre_processed_file_path):
            # self.graph = torch.load(pre_processed_file_path, 'rb')
            self.graph = load_pickle(pre_processed_file_path)

        else:
            ### check download
            if self.binary:
                # npz format
                has_necessary_file_simple = osp.exists(
                    osp.join(self.root, 'raw',
                             'data.npz')) and (not self.is_hetero)
                has_necessary_file_hetero = osp.exists(
                    osp.join(self.root, 'raw',
                             'edge_index_dict.npz')) and self.is_hetero
            else:
                # csv file
                has_necessary_file_simple = osp.exists(
                    osp.join(self.root, 'raw',
                             'edge.csv.gz')) and (not self.is_hetero)
                has_necessary_file_hetero = osp.exists(
                    osp.join(self.root, 'raw',
                             'triplet-type-list.csv.gz')) and self.is_hetero

            has_necessary_file = has_necessary_file_simple or has_necessary_file_hetero

            if not has_necessary_file:
                url = self.meta_info['url']
                if decide_download(url):
                    path = download_url(url, self.original_root)
                    extract_zip(path, self.original_root)
                    os.unlink(path)
                    # delete folder if it already exists
                    try:
                        shutil.rmtree(self.root)
                    except:
                        pass
                    shutil.move(
                        osp.join(self.original_root, self.download_name),
                        self.root)
                else:
                    print('Stop download.')
                    exit(-1)

            raw_dir = osp.join(self.root, 'raw')

            ### pre-process and save
            add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True'

            if self.meta_info['additional node files'] == 'None':
                additional_node_files = []
            else:
                additional_node_files = self.meta_info[
                    'additional node files'].split(',')

            if self.meta_info['additional edge files'] == 'None':
                additional_edge_files = []
            else:
                additional_edge_files = self.meta_info[
                    'additional edge files'].split(',')

            if self.is_hetero:
                if self.binary:
                    self.graph = read_binary_heterograph_raw(
                        raw_dir, add_inverse_edge=add_inverse_edge)[
                            0]  # only a single graph
                else:
                    self.graph = read_csv_heterograph_raw(
                        raw_dir,
                        add_inverse_edge=add_inverse_edge,
                        additional_node_files=additional_node_files,
                        additional_edge_files=additional_edge_files)[
                            0]  # only a single graph

            else:
                if self.binary:
                    self.graph = read_binary_graph_raw(
                        raw_dir, add_inverse_edge=add_inverse_edge)[
                            0]  # only a single graph
                else:
                    self.graph = read_csv_graph_raw(
                        raw_dir,
                        add_inverse_edge=add_inverse_edge,
                        additional_node_files=additional_node_files,
                        additional_edge_files=additional_edge_files)[
                            0]  # only a single graph

            print('Saving...')

            # torch.save(self.graph, pre_processed_file_path, pickle_protocol=4)
            dump_pickle(self.graph, pre_processed_file_path)
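
For orientation, here is a minimal sketch of the meta_info entries that pre_process() reads; the concrete values are hypothetical placeholders, not taken from any real dataset master file.

# Hypothetical meta_info entries consumed by pre_process() (values are placeholders).
meta_info = {
    'url': 'https://example.com/dataset.zip',   # zip archive to download if raw files are missing
    'add_inverse_edge': 'True',                 # stored as a string, compared against 'True'
    'additional node files': 'node_year',       # comma-separated names, or the string 'None'
    'additional edge files': 'None',
}

# The same parsing as in the method above:
additional_node_files = ([] if meta_info['additional node files'] == 'None'
                         else meta_info['additional node files'].split(','))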
Example #4
    def pre_process(self):
        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir, 'data_processed')

        if osp.exists(pre_processed_file_path):
            # loaded_dict = torch.load(pre_processed_file_path)
            loaded_dict = load_pickle(pre_processed_file_path)
            self.graph, self.labels = loaded_dict['graph'], loaded_dict[
                'labels']

        else:
            ### check download
            if self.binary:
                # npz format
                has_necessary_file_simple = osp.exists(
                    osp.join(self.root, 'raw',
                             'data.npz')) and (not self.is_hetero)
                has_necessary_file_hetero = osp.exists(
                    osp.join(self.root, 'raw',
                             'edge_index_dict.npz')) and self.is_hetero
            else:
                # csv file
                has_necessary_file_simple = osp.exists(
                    osp.join(self.root, 'raw',
                             'edge.csv.gz')) and (not self.is_hetero)
                has_necessary_file_hetero = osp.exists(
                    osp.join(self.root, 'raw',
                             'triplet-type-list.csv.gz')) and self.is_hetero

            has_necessary_file = has_necessary_file_simple or has_necessary_file_hetero

            if not has_necessary_file:
                url = self.meta_info['url']
                if decide_download(url):
                    path = download_url(url, self.original_root)
                    extract_zip(path, self.original_root)
                    os.unlink(path)
                    # delete folder if it already exists
                    try:
                        shutil.rmtree(self.root)
                    except:
                        pass
                    shutil.move(
                        osp.join(self.original_root, self.download_name),
                        self.root)
                else:
                    print('Stop download.')
                    exit(-1)

            raw_dir = osp.join(self.root, 'raw')

            ### pre-process and save
            add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True'

            if self.meta_info['additional node files'] == 'None':
                additional_node_files = []
            else:
                additional_node_files = self.meta_info[
                    'additional node files'].split(',')

            if self.meta_info['additional edge files'] == 'None':
                additional_edge_files = []
            else:
                additional_edge_files = self.meta_info[
                    'additional edge files'].split(',')

            if self.is_hetero:
                if self.binary:
                    self.graph = read_binary_heterograph_raw(
                        raw_dir, add_inverse_edge=add_inverse_edge)[
                            0]  # only a single graph

                    tmp = np.load(osp.join(raw_dir, 'node-label.npz'))
                    self.labels = {}
                    for key in list(tmp.keys()):
                        self.labels[key] = tmp[key]
                    del tmp
                else:
                    self.graph = read_csv_heterograph_raw(
                        raw_dir,
                        add_inverse_edge=add_inverse_edge,
                        additional_node_files=additional_node_files,
                        additional_edge_files=additional_edge_files)[
                            0]  # only a single graph
                    self.labels = read_node_label_hetero(raw_dir)

            else:
                if self.binary:
                    self.graph = read_binary_graph_raw(
                        raw_dir, add_inverse_edge=add_inverse_edge)[
                            0]  # only a single graph
                    self.labels = np.load(osp.join(
                        raw_dir, 'node-label.npz'))['node_label']
                else:
                    self.graph = read_csv_graph_raw(
                        raw_dir,
                        add_inverse_edge=add_inverse_edge,
                        additional_node_files=additional_node_files,
                        additional_edge_files=additional_edge_files)[
                            0]  # only a single graph
                    self.labels = pd.read_csv(osp.join(raw_dir,
                                                       'node-label.csv.gz'),
                                              compression='gzip',
                                              header=None).values

            print('Saving...')
            # torch.save({'graph': self.graph, 'labels': self.labels}, pre_processed_file_path, pickle_protocol=4)
            dump_pickle({
                'graph': self.graph,
                'labels': self.labels
            }, pre_processed_file_path)
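
A minimal sketch of reading the cache back, assuming load_pickle mirrors the dump_pickle helper used above and pre_processed_file_path points at the file just written.

# Hedged sketch: reload the cached dict written by pre_process() above.
loaded_dict = load_pickle(pre_processed_file_path)
graph, labels = loaded_dict['graph'], loaded_dict['labels']
# In the heterogeneous case, labels is a dict mapping node type -> np.ndarray;
# in the homogeneous case it is a single (num_nodes, num_tasks) array.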
Example #5
    def _save_graph_list_hetero(self, graph_list):
        dict_keys = graph_list[0].keys()
        # check necessary keys
        if 'edge_index_dict' not in dict_keys:
            raise RuntimeError(
                'edge_index_dict needs to be provided in graph objects')
        if 'num_nodes_dict' not in dict_keys:
            raise RuntimeError(
                'num_nodes_dict needs to be provided in graph objects')

        print(dict_keys)

        # Store the following files
        # - edge_index_dict.npz (necessary)
        #   edge_index_dict
        # - num_nodes_dict.npz (necessary)
        #   num_nodes_dict
        # - num_edges_dict.npz (necessary)
        #   num_edges_dict
        # - node_**.npz (optional, node_feat_dict is the default node features)
        # - edge_**.npz (optional, edge_feat_dict is the default edge features)

        # extract entity types
        ent_type_list = sorted(
            [e for e in graph_list[0]['num_nodes_dict'].keys()])

        # saving num_nodes_dict
        print('Saving num_nodes_dict')
        num_nodes_dict = {}
        for ent_type in ent_type_list:
            num_nodes_dict[ent_type] = np.array([
                graph['num_nodes_dict'][ent_type] for graph in graph_list
            ]).astype(np.int64)
        np.savez_compressed(osp.join(self.raw_dir, 'num_nodes_dict.npz'),
                            **num_nodes_dict)

        print(num_nodes_dict)

        # extract triplet types
        triplet_type_list = sorted([
            (h, r, t) for (h, r, t) in graph_list[0]['edge_index_dict'].keys()
        ])
        print(triplet_type_list)

        # saving edge_index_dict
        print('Saving edge_index_dict')
        num_edges_dict = {}
        edge_index_dict = {}
        for triplet in triplet_type_list:
            # representing triplet (head, rel, tail) as a single string 'head___rel___tail'
            triplet_cat = '___'.join(triplet)
            edge_index = np.concatenate(
                [graph['edge_index_dict'][triplet] for graph in graph_list],
                axis=1).astype(np.int64)
            if edge_index.shape[0] != 2:
                raise RuntimeError('edge_index must have shape (2, num_edges)')

            num_edges = np.array([
                graph['edge_index_dict'][triplet].shape[1]
                for graph in graph_list
            ]).astype(np.int64)
            num_edges_dict[triplet_cat] = num_edges
            edge_index_dict[triplet_cat] = edge_index

        print(edge_index_dict)
        print(num_edges_dict)

        np.savez_compressed(osp.join(self.raw_dir, 'edge_index_dict.npz'),
                            **edge_index_dict)
        np.savez_compressed(osp.join(self.raw_dir, 'num_edges_dict.npz'),
                            **num_edges_dict)

        for key in dict_keys:
            if key == 'edge_index_dict' or key == 'num_nodes_dict':
                continue
            if graph_list[0][key] is None:
                continue

            print(f'Saving {key}')

            feat_dict = {}

            if 'node_' in key:
                # node feature dictionary
                for ent_type in graph_list[0][key].keys():
                    if ent_type not in num_nodes_dict:
                        raise RuntimeError(
                            f'Encountered unknown entity type called {ent_type}.'
                        )

                    # check num_nodes
                    for i in range(len(graph_list)):
                        if len(graph_list[i][key]
                               [ent_type]) != num_nodes_dict[ent_type][i]:
                            raise RuntimeError(
                                f'num_nodes mismatches with {key}[{ent_type}]'
                            )

                    # make sure saved in np.int64 or np.float32
                    dtype = np.int64 if 'int' in str(
                        graph_list[0][key][ent_type].dtype) else np.float32
                    cat_feat = np.concatenate(
                        [graph[key][ent_type] for graph in graph_list],
                        axis=0).astype(dtype)
                    feat_dict[ent_type] = cat_feat

            elif 'edge_' in key:
                # edge feature dictionary
                for triplet in graph_list[0][key].keys():
                    # representing triplet (head, rel, tail) as a single string 'head___rel___tail'
                    triplet_cat = '___'.join(triplet)
                    if triplet_cat not in num_edges_dict:
                        raise RuntimeError(
                            f"Encountered unknown triplet type called ({','.join(triplet)})."
                        )

                    # check num_edges
                    for i in range(len(graph_list)):
                        if len(graph_list[i][key]
                               [triplet]) != num_edges_dict[triplet_cat][i]:
                            raise RuntimeError(
                                f"num_edges mismatches with {key}[({','.join(triplet)})]"
                            )

                    # make sure saved in np.int64 or np.float32
                    dtype = np.int64 if 'int' in str(
                        graph_list[0][key][triplet].dtype) else np.float32
                    cat_feat = np.concatenate(
                        [graph[key][triplet] for graph in graph_list],
                        axis=0).astype(dtype)
                    feat_dict[triplet_cat] = cat_feat

            else:
                raise RuntimeError(
                    f'Keys in graph object should start from either \'node_\' or \'edge_\', but \'{key}\' given.'
                )

            np.savez_compressed(osp.join(self.raw_dir, f'{key}.npz'),
                                **feat_dict)

        print('Validating...')
        # testing
        print('Reading saved files')
        graph_list_read = read_binary_heterograph_raw(self.raw_dir, False)

        print('Checking read graphs and given graphs are the same')
        for i in tqdm(range(len(graph_list))):
            for key0, value0 in graph_list[i].items():
                if value0 is not None:
                    for key1, value1 in value0.items():
                        if isinstance(graph_list[i][key0][key1], np.ndarray):
                            assert (np.allclose(graph_list[i][key0][key1],
                                                graph_list_read[i][key0][key1],
                                                rtol=1e-04,
                                                atol=1e-04,
                                                equal_nan=True))
                        else:
                            assert (graph_list[i][key0][key1] ==
                                    graph_list_read[i][key0][key1])

        del graph_list_read
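
Since npz archives only accept string keys, the method flattens each (head, rel, tail) triplet into a single string. A small sketch of this round trip, using a hypothetical triplet:

# Hypothetical triplet, flattened and recovered with the '___' separator used above.
triplet = ('author', 'writes', 'paper')
triplet_cat = '___'.join(triplet)                 # 'author___writes___paper'
head, rel, tail = triplet_cat.split('___')        # back to ('author', 'writes', 'paper')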