示例#1
0
    def process(self):
        """Build the PyG data object from the raw files and save it.

        Reads the graph from ``self.raw_dir`` (heterogeneous or homogeneous
        depending on ``self.is_hetero``), attaches the node labels as
        ``y_dict`` / ``y``, applies ``self.pre_transform`` when one is set,
        and writes ``self.collate([data])`` to ``self.processed_paths[0]``.

        Labels are cast to ``torch.long`` when ``self.task_type`` contains
        "classification" and to ``torch.float32`` otherwise.
        """
        dataset_meta = self.meta_info[self.name]
        add_inverse_edge = dataset_meta["add_inverse_edge"] == "True"

        # The meta table stores the literal string 'None' when there are no
        # additional files, otherwise a comma-separated list of names.
        node_files_spec = dataset_meta["additional node files"]
        additional_node_files = ([] if node_files_spec == 'None' else
                                 node_files_spec.split(','))

        edge_files_spec = dataset_meta["additional edge files"]
        additional_edge_files = ([] if edge_files_spec == 'None' else
                                 edge_files_spec.split(','))

        if self.is_hetero:
            data = read_csv_heterograph_pyg(
                self.raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files)[0]

            node_label_dict = read_node_label_hetero(self.raw_dir)

            label_dtype = (torch.long if "classification" in self.task_type
                           else torch.float32)
            data.y_dict = {
                nodetype: torch.from_numpy(node_label).to(label_dtype)
                for nodetype, node_label in node_label_dict.items()
            }

        else:
            data = read_csv_graph_pyg(
                self.raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files)[0]

            ### adding prediction target
            node_label = pd.read_csv(osp.join(self.raw_dir,
                                              'node-label.csv.gz'),
                                     compression="gzip",
                                     header=None).values

            label_dtype = (torch.long if "classification" in self.task_type
                           else torch.float32)
            data.y = torch.from_numpy(node_label).to(label_dtype)

        # Keep the transform's return value: the previous bare conditional
        # expression discarded it, so a transform that returns a new object
        # (rather than mutating in place) had no effect on what was saved.
        if self.pre_transform is not None:
            data = self.pre_transform(data)

        print('Saving...')
        torch.save(self.collate([data]), self.processed_paths[0])
示例#2
0
    def pre_process(self):
        """Load the cached graph and labels, or build them and cache them.

        If ``<root>/processed/data_processed`` exists, it is read with
        ``load_pickle`` and its ``'graph'`` / ``'labels'`` entries are stored
        on ``self.graph`` / ``self.labels``.

        Otherwise the raw files under ``<root>/raw`` are parsed (npz readers
        when ``self.binary`` is set, csv readers otherwise), after
        downloading and extracting the dataset first if the expected raw
        file is missing and ``decide_download(url)`` agrees. The result is
        written with ``dump_pickle``.

        Raises:
            SystemExit: with code -1 when the raw files are missing and the
                download is declined.
        """
        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir, 'data_processed')

        if osp.exists(pre_processed_file_path):
            loaded_dict = load_pickle(pre_processed_file_path)
            self.graph = loaded_dict['graph']
            self.labels = loaded_dict['labels']
            return

        raw_dir = osp.join(self.root, 'raw')

        ### check download
        # Only the file matching the current (binary, hetero) combination
        # counts as evidence that the raw data is already present.
        if self.binary:
            # npz format
            necessary_file = ('edge_index_dict.npz'
                              if self.is_hetero else 'data.npz')
        else:
            # csv file
            necessary_file = ('triplet-type-list.csv.gz'
                              if self.is_hetero else 'edge.csv.gz')

        if not osp.exists(osp.join(raw_dir, necessary_file)):
            url = self.meta_info['url']
            if not decide_download(url):
                print('Stop download.')
                # The site-provided exit() helper is meant for interactive
                # use; raising SystemExit has the same effect for callers.
                raise SystemExit(-1)

            path = download_url(url, self.original_root)
            extract_zip(path, self.original_root)
            os.unlink(path)
            # Delete the target folder if it exists. ignore_errors keeps
            # this best-effort without a bare ``except`` that would also
            # swallow KeyboardInterrupt.
            shutil.rmtree(self.root, ignore_errors=True)
            shutil.move(osp.join(self.original_root, self.download_name),
                        self.root)

        ### pre-process and save
        add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True'

        # The meta table stores the literal string 'None' when there are no
        # additional files, otherwise a comma-separated list of names.
        node_files_spec = self.meta_info['additional node files']
        additional_node_files = ([] if node_files_spec == 'None' else
                                 node_files_spec.split(','))

        edge_files_spec = self.meta_info['additional edge files']
        additional_edge_files = ([] if edge_files_spec == 'None' else
                                 edge_files_spec.split(','))

        if self.is_hetero:
            if self.binary:
                self.graph = read_binary_heterograph_raw(
                    raw_dir,
                    add_inverse_edge=add_inverse_edge)[0]  # only a single graph

                # Materialise every array before the archive is closed.
                with np.load(osp.join(raw_dir, 'node-label.npz')) as npz:
                    self.labels = {key: npz[key] for key in npz.files}
            else:
                self.graph = read_csv_heterograph_raw(
                    raw_dir,
                    add_inverse_edge=add_inverse_edge,
                    additional_node_files=additional_node_files,
                    additional_edge_files=additional_edge_files
                )[0]  # only a single graph
                self.labels = read_node_label_hetero(raw_dir)

        else:
            if self.binary:
                self.graph = read_binary_graph_raw(
                    raw_dir,
                    add_inverse_edge=add_inverse_edge)[0]  # only a single graph
                with np.load(osp.join(raw_dir, 'node-label.npz')) as npz:
                    self.labels = npz['node_label']
            else:
                self.graph = read_csv_graph_raw(
                    raw_dir,
                    add_inverse_edge=add_inverse_edge,
                    additional_node_files=additional_node_files,
                    additional_edge_files=additional_edge_files
                )[0]  # only a single graph
                self.labels = pd.read_csv(osp.join(raw_dir,
                                                   'node-label.csv.gz'),
                                          compression='gzip',
                                          header=None).values

        print('Saving...')
        dump_pickle({
            'graph': self.graph,
            'labels': self.labels
        }, pre_processed_file_path)
示例#3
0
    def pre_process(self):
        """Load the cached DGL graph and labels, or build and cache them.

        The cache lives at ``<root>/processed/dgl_data_processed``. For
        homogeneous graphs it is read/written with ``load_graphs`` /
        ``save_graphs``; for heterogeneous graphs it is a pickle of
        ``([graph], label_dict)``.

        When no cache exists, the csv files under ``<root>/raw`` are parsed,
        after downloading and extracting the dataset first if the expected
        raw file is missing and ``decide_download(url)`` agrees. Labels are
        converted to ``torch.long`` for classification tasks without NaN
        labels and to ``torch.float32`` otherwise.

        Raises:
            SystemExit: with code -1 when the raw files are missing and the
                download is declined.
        """
        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir, 'dgl_data_processed')

        if osp.exists(pre_processed_file_path):
            if not self.is_hetero:
                self.graph, label_dict = load_graphs(pre_processed_file_path)
                self.labels = label_dict['labels']
            else:
                # NOTE(review): pickle executes code on load; this is only
                # safe because the file is a locally produced cache.
                with open(pre_processed_file_path, 'rb') as f:
                    self.graph, self.labels = pickle.load(f)
            return

        raw_dir = osp.join(self.root, "raw")

        ### check if the downloaded file exists
        necessary_file = ("triplet-type-list.csv.gz"
                          if self.is_hetero else "edge.csv.gz")

        if not osp.exists(osp.join(raw_dir, necessary_file)):
            url = self.meta_info[self.name]["url"]
            if not decide_download(url):
                print("Stop download.")
                # The site-provided exit() helper is meant for interactive
                # use; raising SystemExit has the same effect for callers.
                raise SystemExit(-1)

            path = download_url(url, self.original_root)
            extract_zip(path, self.original_root)
            os.unlink(path)
            # Delete the target folder if it exists. ignore_errors keeps
            # this best-effort without a bare ``except`` that would also
            # swallow KeyboardInterrupt.
            shutil.rmtree(self.root, ignore_errors=True)
            shutil.move(osp.join(self.original_root, self.download_name),
                        self.root)

        ### pre-process and save
        dataset_meta = self.meta_info[self.name]
        add_inverse_edge = dataset_meta["add_inverse_edge"] == "True"

        # The meta table stores the literal string 'None' when there are no
        # additional files, otherwise a comma-separated list of names.
        node_files_spec = dataset_meta["additional node files"]
        additional_node_files = ([] if node_files_spec == 'None' else
                                 node_files_spec.split(','))

        edge_files_spec = dataset_meta["additional edge files"]
        additional_edge_files = ([] if edge_files_spec == 'None' else
                                 edge_files_spec.split(','))

        if self.is_hetero:
            graph = read_csv_heterograph_dgl(
                raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files)[0]

            raw_label_dict = read_node_label_hetero(raw_dir)

            # convert into torch tensor
            is_classification = "classification" in self.task_type
            label_dict = {}
            for nodetype, node_label in raw_label_dict.items():
                # NaN cannot be represented in an integer tensor, so labels
                # containing NaN stay floating point even for classification.
                use_long = (is_classification
                            and not np.isnan(node_label).any())
                label_dict[nodetype] = torch.from_numpy(node_label).to(
                    torch.long if use_long else torch.float32)

            with open(pre_processed_file_path, 'wb') as f:
                pickle.dump(([graph], label_dict), f)

            # Read back what was just written (presumably so the in-memory
            # state matches what a later cached load produces).
            with open(pre_processed_file_path, 'rb') as f:
                self.graph, self.labels = pickle.load(f)

        else:
            graph = read_csv_graph_dgl(
                raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files)[0]

            ### adding prediction target
            node_label = pd.read_csv(osp.join(raw_dir, 'node-label.csv.gz'),
                                     compression="gzip",
                                     header=None).values

            # Same NaN rule as the heterogeneous branch.
            use_long = ("classification" in self.task_type
                        and not np.isnan(node_label).any())
            node_label = torch.from_numpy(node_label).to(
                torch.long if use_long else torch.float32)

            label_dict = {"labels": node_label}

            save_graphs(pre_processed_file_path, graph, label_dict)

            self.graph, label_dict = load_graphs(pre_processed_file_path)
            self.labels = label_dict['labels']
示例#4
0
    def process(self):
        """Read the raw graph, attach node labels, and save the PyG data.

        The graph is read from ``self.raw_dir`` in binary (npz) or csv form
        according to ``self.binary``. Labels go to ``y_dict`` for
        heterogeneous graphs and to ``y`` otherwise. They are cast to
        ``torch.long`` for classification tasks whose labels contain no NaN
        and to ``torch.float32`` in every other case. ``self.pre_transform``
        is applied when set, then the collated data is written to
        ``self.processed_paths[0]``.
        """
        meta = self.meta_info
        add_inverse_edge = meta['add_inverse_edge'] == 'True'

        # 'None' (as a string) marks the absence of additional files.
        node_files_spec = meta['additional node files']
        additional_node_files = ([] if node_files_spec == 'None' else
                                 node_files_spec.split(','))

        edge_files_spec = meta['additional edge files']
        additional_edge_files = ([] if edge_files_spec == 'None' else
                                 edge_files_spec.split(','))

        if not self.is_hetero:
            graph_data = read_graph_pyg(
                self.raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files,
                binary=self.binary)[0]

            # Prediction target for the single node type.
            if self.binary:
                label_path = osp.join(self.raw_dir, 'node-label.npz')
                raw_label = np.load(label_path)['node_label']
            else:
                label_path = osp.join(self.raw_dir, 'node-label.csv.gz')
                raw_label = pd.read_csv(label_path,
                                        compression='gzip',
                                        header=None).values

            is_classification = 'classification' in self.task_type
            # NaN labels force float32 even for classification tasks.
            wants_long = is_classification and not np.isnan(raw_label).any()
            graph_data.y = torch.from_numpy(raw_label).to(
                torch.long if wants_long else torch.float32)

        else:
            graph_data = read_heterograph_pyg(
                self.raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files,
                binary=self.binary)[0]

            if self.binary:
                archive = np.load(osp.join(self.raw_dir, 'node-label.npz'))
                labels_by_type = {
                    key: archive[key] for key in list(archive.keys())
                }
                del archive
            else:
                labels_by_type = read_node_label_hetero(self.raw_dir)

            is_classification = 'classification' in self.task_type
            y_dict = {}
            for nodetype, raw_label in labels_by_type.items():
                # NaN labels force float32 even for classification tasks.
                wants_long = (is_classification
                              and not np.isnan(raw_label).any())
                y_dict[nodetype] = torch.from_numpy(raw_label).to(
                    torch.long if wants_long else torch.float32)
            graph_data.y_dict = y_dict

        if self.pre_transform is not None:
            graph_data = self.pre_transform(graph_data)

        print('Saving...')
        torch.save(self.collate([graph_data]), self.processed_paths[0])
示例#5
0
文件: dataset.py 项目: yzh119/ogb
    def pre_process(self):
        """Load the cached graph and labels, or build them and cache them.

        If ``<root>/processed/data_processed`` exists, it is read with
        ``torch.load`` and its ``'graph'`` / ``'labels'`` entries are stored
        on ``self.graph`` / ``self.labels``.

        Otherwise the csv files under ``<root>/raw`` are parsed, after
        downloading and extracting the dataset first if the expected raw
        file is missing and ``decide_download(url)`` agrees. The result is
        saved with ``torch.save`` (pickle protocol 4).

        Raises:
            SystemExit: with code -1 when the raw files are missing and the
                download is declined.
        """
        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir, 'data_processed')

        if osp.exists(pre_processed_file_path):
            loaded_dict = torch.load(pre_processed_file_path)
            self.graph = loaded_dict['graph']
            self.labels = loaded_dict['labels']
            return

        raw_dir = osp.join(self.root, "raw")

        ### check download
        necessary_file = ("triplet-type-list.csv.gz"
                          if self.is_hetero else "edge.csv.gz")

        if not osp.exists(osp.join(raw_dir, necessary_file)):
            url = self.meta_info[self.name]["url"]
            if not decide_download(url):
                print("Stop download.")
                # The site-provided exit() helper is meant for interactive
                # use; raising SystemExit has the same effect for callers.
                raise SystemExit(-1)

            path = download_url(url, self.original_root)
            extract_zip(path, self.original_root)
            os.unlink(path)
            # Delete the target folder if it exists. ignore_errors keeps
            # this best-effort without a bare ``except`` that would also
            # swallow KeyboardInterrupt.
            shutil.rmtree(self.root, ignore_errors=True)
            shutil.move(osp.join(self.original_root, self.download_name),
                        self.root)

        ### pre-process and save
        dataset_meta = self.meta_info[self.name]
        add_inverse_edge = dataset_meta["add_inverse_edge"] == "True"

        # The meta table stores the literal string 'None' when there are no
        # additional files, otherwise a comma-separated list of names.
        node_files_spec = dataset_meta["additional node files"]
        additional_node_files = ([] if node_files_spec == 'None' else
                                 node_files_spec.split(','))

        edge_files_spec = dataset_meta["additional edge files"]
        additional_edge_files = ([] if edge_files_spec == 'None' else
                                 edge_files_spec.split(','))

        if self.is_hetero:
            self.graph = read_csv_heterograph_raw(
                raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files
            )[0]  # only a single graph
            self.labels = read_node_label_hetero(raw_dir)

        else:
            self.graph = read_csv_graph_raw(
                raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files
            )[0]  # only a single graph

            ### adding prediction target
            self.labels = pd.read_csv(osp.join(raw_dir, 'node-label.csv.gz'),
                                      compression="gzip",
                                      header=None).values

        print('Saving...')
        torch.save({
            'graph': self.graph,
            'labels': self.labels
        },
                   pre_processed_file_path,
                   pickle_protocol=4)
示例#6
0
    def pre_process(self):
        """Load the cached DGL graph and labels, or build and cache them.

        The cache at ``<root>/processed/dgl_data_processed`` is read and
        written with ``load_graphs`` / ``save_graphs``. For heterogeneous
        graphs ``self.labels`` is the whole label dict returned by
        ``load_graphs``; otherwise it is that dict's ``'labels'`` entry.

        When no cache exists, the raw files under ``<root>/raw`` are parsed
        (npz when ``self.binary`` is set, csv otherwise), after downloading
        and extracting the dataset first if the expected raw file is missing
        and ``decide_download(url)`` agrees. Labels are converted to
        ``torch.long`` for classification tasks without NaN labels and to
        ``torch.float32`` otherwise.

        Raises:
            SystemExit: with code -1 when the raw files are missing and the
                download is declined.
        """
        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir, 'dgl_data_processed')

        if not osp.exists(pre_processed_file_path):
            raw_dir = osp.join(self.root, 'raw')

            ### check if the downloaded file exists
            # Only the file matching the current (binary, hetero)
            # combination counts as evidence that the raw data is present.
            if self.binary:
                # npz format
                necessary_file = ('edge_index_dict.npz'
                                  if self.is_hetero else 'data.npz')
            else:
                # csv file
                necessary_file = ('triplet-type-list.csv.gz'
                                  if self.is_hetero else 'edge.csv.gz')

            if not osp.exists(osp.join(raw_dir, necessary_file)):
                url = self.meta_info['url']
                if not decide_download(url):
                    print('Stop download.')
                    # The site-provided exit() helper is meant for
                    # interactive use; raising SystemExit has the same
                    # effect for callers.
                    raise SystemExit(-1)

                path = download_url(url, self.original_root)
                extract_zip(path, self.original_root)
                os.unlink(path)
                # Delete the target folder if it exists. ignore_errors keeps
                # this best-effort without a bare ``except`` that would also
                # swallow KeyboardInterrupt.
                shutil.rmtree(self.root, ignore_errors=True)
                shutil.move(
                    osp.join(self.original_root, self.download_name),
                    self.root)

            ### pre-process and save
            add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True'

            # The meta table stores the literal string 'None' when there are
            # no additional files, otherwise a comma-separated list of names.
            node_files_spec = self.meta_info['additional node files']
            additional_node_files = ([] if node_files_spec == 'None' else
                                     node_files_spec.split(','))

            edge_files_spec = self.meta_info['additional edge files']
            additional_edge_files = ([] if edge_files_spec == 'None' else
                                     edge_files_spec.split(','))

            is_classification = 'classification' in self.task_type

            if self.is_hetero:
                graph = read_heterograph_dgl(
                    raw_dir,
                    add_inverse_edge=add_inverse_edge,
                    additional_node_files=additional_node_files,
                    additional_edge_files=additional_edge_files,
                    binary=self.binary)[0]

                if self.binary:
                    # Materialise every array before the archive is closed.
                    with np.load(osp.join(raw_dir, 'node-label.npz')) as npz:
                        raw_label_dict = {key: npz[key] for key in npz.files}
                else:
                    raw_label_dict = read_node_label_hetero(raw_dir)

                # convert into torch tensor
                label_dict = {}
                for nodetype, node_label in raw_label_dict.items():
                    # NaN cannot be represented in an integer tensor, so
                    # labels containing NaN stay floating point even for
                    # classification.
                    use_long = (is_classification
                                and not np.isnan(node_label).any())
                    label_dict[nodetype] = torch.from_numpy(node_label).to(
                        torch.long if use_long else torch.float32)

            else:
                graph = read_graph_dgl(
                    raw_dir,
                    add_inverse_edge=add_inverse_edge,
                    additional_node_files=additional_node_files,
                    additional_edge_files=additional_edge_files,
                    binary=self.binary)[0]

                ### adding prediction target
                if self.binary:
                    with np.load(osp.join(raw_dir, 'node-label.npz')) as npz:
                        node_label = npz['node_label']
                else:
                    node_label = pd.read_csv(osp.join(raw_dir,
                                                      'node-label.csv.gz'),
                                             compression='gzip',
                                             header=None).values

                # Same NaN rule as the heterogeneous branch.
                use_long = (is_classification
                            and not np.isnan(node_label).any())
                node_label = torch.from_numpy(node_label).to(
                    torch.long if use_long else torch.float32)

                label_dict = {'labels': node_label}

            print('Saving...')
            save_graphs(pre_processed_file_path, graph, label_dict)

        # Whether the cache already existed or was just written, the
        # in-memory state always comes from load_graphs.
        self.graph, label_dict = load_graphs(pre_processed_file_path)

        if self.is_hetero:
            self.labels = label_dict
        else:
            self.labels = label_dict['labels']