import sys
import os.path as osp
import pickle as pkl

import torch


def read_planetoid_data(folder, prefix):
    """Read a Planetoid-format citation dataset (cora/citeseer/pubmed) from
    `folder` and return a `Graph` with features, labels, edges, and the
    standard train/val/test masks."""
    prefix = prefix.lower()
    names = ["x", "tx", "allx", "y", "ty", "ally", "graph", "test.index"]
    objects = []
    for item in names[:-1]:
        with open(f"{folder}/ind.{prefix}.{item}", "rb") as f:
            if sys.version_info > (3, 0):
                objects.append(pkl.load(f, encoding="latin1"))
            else:
                objects.append(pkl.load(f))
    test_index = parse_index_file(f"{folder}/ind.{prefix}.{names[-1]}")
    test_index = torch.tensor(test_index, dtype=torch.long)
    test_index_reorder = test_index.sort()[0]

    x, tx, allx, y, ty, ally, graph = tuple(objects)
    x, tx, allx = tuple(
        torch.from_numpy(item.todense()).float() for item in [x, tx, allx])
    y, ty, ally = tuple(
        torch.from_numpy(item).float() for item in [y, ty, ally])

    train_index = torch.arange(y.size(0), dtype=torch.long)
    val_index = torch.arange(y.size(0), y.size(0) + 500, dtype=torch.long)

    if prefix == "citeseer":
        # There are some isolated nodes in the Citeseer graph, resulting in
        # non-consecutive test indices. We need to identify them and add them
        # as zero vectors to `tx` and `ty`.
        len_test_indices = (test_index.max() - test_index.min()).item() + 1

        tx_ext = torch.zeros(len_test_indices, tx.size(1))
        tx_ext[test_index_reorder - test_index.min(), :] = tx
        ty_ext = torch.zeros(len_test_indices, ty.size(1))
        ty_ext[test_index_reorder - test_index.min(), :] = ty

        tx, ty = tx_ext, ty_ext

    x = torch.cat([allx, tx], dim=0).float()
    y = torch.cat([ally, ty], dim=0).max(dim=1)[1].long()

    # The raw test rows are stored in shuffled order; put each test row back
    # at its original node index.
    x[test_index] = x[test_index_reorder]
    y[test_index] = y[test_index_reorder]

    train_mask = index_to_mask(train_index, size=y.size(0))
    val_mask = index_to_mask(val_index, size=y.size(0))
    test_mask = index_to_mask(test_index, size=y.size(0))

    edge_index = edge_index_from_dict(graph, num_nodes=y.size(0))

    data = Graph(x=x, edge_index=edge_index, y=y)
    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    return data
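# `read_planetoid_data` relies on three helpers and a `Graph` container that
# are defined elsewhere in this package. The implementations below are minimal
# sketches written to match how they are called above, assuming the usual
# Planetoid-loader semantics; the package's real versions may differ.


def parse_index_file(filename):
    # One integer per line, e.g. the shuffled test-node ids.
    with open(filename) as f:
        return [int(line.strip()) for line in f]


def index_to_mask(index, size):
    # Turn a LongTensor of node indices into a boolean mask of length `size`.
    mask = torch.zeros(size, dtype=torch.bool)
    mask[index] = True
    return mask


def edge_index_from_dict(graph_dict, num_nodes=None):
    # Flatten the pickled {node: [neighbors]} adjacency dict into a
    # (2, num_edges) LongTensor of directed edges.
    rows, cols = [], []
    for src, neighbors in graph_dict.items():
        rows += [src] * len(neighbors)
        cols += list(neighbors)
    return torch.tensor([rows, cols], dtype=torch.long)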
def read_triplet_data(folder):
    """Read knowledge-graph triplets from the `*2id.txt` files in `folder`
    (the first line holds the triple count, then one `head tail relation`
    triple per line) and return a `Graph` plus split bookkeeping."""
    filenames = ["train2id.txt", "valid2id.txt", "test2id.txt"]
    count = 0
    edge_index = []
    edge_attr = []
    count_list = []
    triples = []
    num_entities = 0
    num_relations = 0
    entity_dic = {}
    relation_dic = {}
    for filename in filenames:
        with open(osp.join(folder, filename), "r") as f:
            # The first line only stores the number of triplets in the file.
            _ = int(f.readline().strip())
            if "train" in filename:
                train_start_idx = len(triples)
            elif "valid" in filename:
                valid_start_idx = len(triples)
            elif "test" in filename:
                test_start_idx = len(triples)
            for line in f:
                items = line.strip().split()
                edge_index.append([int(items[0]), int(items[1])])
                edge_attr.append(int(items[2]))
                # Store triples as (head, relation, tail).
                triples.append((int(items[0]), int(items[2]), int(items[1])))
                if items[0] not in entity_dic:
                    entity_dic[items[0]] = num_entities
                    num_entities += 1
                if items[1] not in entity_dic:
                    entity_dic[items[1]] = num_entities
                    num_entities += 1
                if items[2] not in relation_dic:
                    relation_dic[items[2]] = num_relations
                    num_relations += 1
                count += 1
            count_list.append(count)

    edge_index = torch.LongTensor(edge_index).t()
    edge_attr = torch.LongTensor(edge_attr)
    data = Graph()
    data.edge_index = edge_index
    data.edge_attr = edge_attr

    def generate_mask(start, end):
        mask = torch.zeros(count, dtype=torch.bool)
        mask[start:end] = True
        return mask

    data.train_mask = generate_mask(0, count_list[0])
    data.val_mask = generate_mask(count_list[0], count_list[1])
    data.test_mask = generate_mask(count_list[1], count_list[2])
    return data, triples, train_start_idx, valid_start_idx, test_start_idx, num_entities, num_relations
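# Usage sketch for both readers. The "data/..." paths and dataset names are
# assumptions for illustration; any folder with the expected file layout
# (Planetoid `ind.*` files, or `train2id.txt`/`valid2id.txt`/`test2id.txt`)
# would work.
if __name__ == "__main__":
    # Node classification: Cora with the standard Planetoid split.
    cora = read_planetoid_data("data/cora", "cora")
    print(cora.x.shape, cora.train_mask.sum().item())

    # Knowledge-graph triplets in the *2id.txt format.
    data, triples, train_start, valid_start, test_start, n_ent, n_rel = \
        read_triplet_data("data/fb15k")
    print(len(triples), n_ent, n_rel)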