示例#1
0
文件: dataset.py 项目: yzh119/ogb
    def get_idx_split(self, split_type=None):
        """Load the train/valid/test index split stored on disk.

        The split is read from ``<self.root>/split/<split_type>``.

        Args:
            split_type: Name of the split sub-directory. When ``None``, the
                value of ``self.meta_info[self.name]["split"]`` is used.

        Returns:
            A dict with the keys ``"train"``, ``"valid"`` and ``"test"``.
            When ``self.is_hetero`` is true, the values are the three
            mappings returned by ``read_nodesplitidx_split_hetero(path)``
            (presumably keyed by node type), passed through unchanged.
            Otherwise each value is the first column of the matching
            header-less ``<part>.csv.gz`` file as a NumPy array.
        """
        if split_type is None:
            split_type = self.meta_info[self.name]["split"]

        path = osp.join(self.root, "split", split_type)

        if self.is_hetero:
            train_idx_dict, valid_idx_dict, test_idx_dict = (
                read_nodesplitidx_split_hetero(path))
            # Return unconditionally. A previous version returned from inside
            # a per-node-type loop, which yielded None whenever the mapping
            # had no node types; the loop itself only re-assigned every entry
            # to itself, so nothing is lost by dropping it.
            return {
                "train": train_idx_dict,
                "valid": valid_idx_dict,
                "test": test_idx_dict,
            }

        def read_idx(part):
            # Header-less, gzip-compressed CSV; keep only its first column.
            return pd.read_csv(osp.join(path, part + ".csv.gz"),
                               compression="gzip",
                               header=None).values.T[0]

        return {part: read_idx(part) for part in ("train", "valid", "test")}
示例#2
0
    def get_idx_split(self, split_type=None):
        """Load the train/valid/test index split as ``torch.long`` tensors.

        The split is read from ``<self.root>/split/<split_type>``.

        Args:
            split_type: Name of the split sub-directory. When ``None``,
                ``self.meta_info['split']`` is used.

        Returns:
            If ``split_dict.pt`` exists in the split directory, whatever
            ``torch.load`` returns for that file. Otherwise a dict with the
            keys ``'train'``, ``'valid'`` and ``'test'``: for a hetero
            dataset each value is a dict mapping node type to a
            ``torch.long`` tensor; for any other dataset each value is a
            ``torch.long`` tensor built from the first column of the matching
            header-less ``<part>.csv.gz`` file.
        """
        if split_type is None:
            split_type = self.meta_info['split']

        path = osp.join(self.root, 'split', split_type)

        # Short-cut if a pre-built split_dict.pt exists.
        # NOTE(review): torch.load is pickle-based; this presumably only ever
        # reads files shipped with the dataset -- confirm they are trusted.
        split_dict_path = osp.join(path, 'split_dict.pt')
        if osp.isfile(split_dict_path):
            return torch.load(split_dict_path)

        def to_long(array):
            return torch.from_numpy(array).to(torch.long)

        if self.is_hetero:
            train_idx_dict, valid_idx_dict, test_idx_dict = read_nodesplitidx_split_hetero(path)
            # Convert every node type and only then return. Returning from
            # inside the loop left all but the first node type as NumPy
            # arrays and produced None for an empty mapping.
            for nodetype in train_idx_dict.keys():
                train_idx_dict[nodetype] = to_long(train_idx_dict[nodetype])
                valid_idx_dict[nodetype] = to_long(valid_idx_dict[nodetype])
                test_idx_dict[nodetype] = to_long(test_idx_dict[nodetype])

            return {'train': train_idx_dict, 'valid': valid_idx_dict, 'test': test_idx_dict}

        def read_idx(part):
            # Header-less, gzip-compressed CSV; keep only its first column.
            frame = pd.read_csv(osp.join(path, part + '.csv.gz'), compression='gzip', header=None)
            return to_long(frame.values.T[0])

        return {part: read_idx(part) for part in ('train', 'valid', 'test')}