def __init__(self, args=None):
    """Load the preprocessed JKNet Cora dataset from a pickled cache file.

    Builds a `Data` object with node features, labels, edges, and a random
    80/20 train/test node split (the validation mask reuses the test
    indices). Sets `self.num_nodes`, `self.data`, and `self.num_classes`.

    Args:
        args: Unused; kept for interface compatibility with callers.
    """
    dataset = "jknet_cora"
    path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
    if not osp.exists(path):
        os.makedirs(path)
    super(CoraDataset, self).__init__(path)
    # NOTE(review): pickle.load is only safe on trusted files — this reads a
    # locally generated cache, which is presumably trusted.
    with open(self.processed_paths[0], 'rb') as fin:
        load_data = pickle.load(fin)
    self.num_nodes = load_data['node_num']
    data = Data()
    # Convert features/labels once instead of assigning the raw arrays first
    # and re-converting them later (the original assigned data.x/data.y twice).
    data.x = torch.Tensor(load_data['xs'])
    data.y = torch.LongTensor(load_data['ys'])
    # Random 80/20 split. NOTE(review): no RNG seed, so the split differs
    # between runs — confirm that is intended.
    train_size = int(self.num_nodes * 0.8)
    train_mask = np.zeros((self.num_nodes, ), dtype=bool)
    train_idx = np.random.choice(np.arange(self.num_nodes), size=train_size, replace=False)
    train_mask[train_idx] = True
    test_mask = np.ones((self.num_nodes, ), dtype=bool)
    test_mask[train_idx] = False
    # Bug fix: the original did `val_mask = test_mask`, aliasing the same
    # numpy buffer — both torch masks then shared storage, so mutating one
    # mask silently changed the other. Copy to decouple them (same values).
    val_mask = test_mask.copy()
    edges = load_data['edges']
    # Edge list arrives as (num_edges, 2); transpose to the (2, num_edges)
    # layout expected for edge_index.
    edges = np.array(edges, dtype=int).transpose((1, 0))
    data.edge_index = torch.from_numpy(edges)
    data.train_mask = torch.from_numpy(train_mask)
    data.test_mask = torch.from_numpy(test_mask)
    data.val_mask = torch.from_numpy(val_mask)
    self.data = data
    # Labels are assumed to be contiguous ints starting at 0.
    self.num_classes = torch.max(self.data.y).item() + 1
def read_triplet_data(folder):
    """Read knowledge-graph triples from train/valid/test id files.

    Each file starts with a line holding the declared triple count, followed
    by one `head tail relation` triple per line. Returns a `Data` object with
    `edge_index` (2, num_triples), `edge_attr` (relation ids), and boolean
    train/val/test masks over the triples, in file order.

    Args:
        folder: Directory containing train2id.txt, valid2id.txt, test2id.txt.

    Returns:
        Data with edge_index, edge_attr, train_mask, val_mask, test_mask.
    """
    filenames = ["train2id.txt", "valid2id.txt", "test2id.txt"]
    count = 0
    edge_index = []
    edge_attr = []
    count_list = []
    for filename in filenames:
        with open(osp.join(folder, filename), "r") as f:
            # Skip the declared count on the first line; we trust the actual
            # number of data lines instead. (The original bound this to an
            # unused local `num`.)
            _ = int(f.readline().strip())
            for line in f:
                items = line.strip().split()
                edge_index.append([int(items[0]), int(items[1])])
                edge_attr.append(int(items[2]))
                count += 1
        # Cumulative triple count marks each split's end boundary.
        count_list.append(count)
    edge_index = torch.LongTensor(edge_index).t()
    edge_attr = torch.LongTensor(edge_attr)
    data = Data()
    data.edge_index = edge_index
    data.edge_attr = edge_attr

    def generate_mask(start, end):
        # Bug fix: torch.BoolTensor(count) returns an *uninitialized* tensor
        # that then had to be manually cleared; torch.zeros gives a defined
        # all-False tensor directly.
        mask = torch.zeros(count, dtype=torch.bool)
        mask[start:end] = True
        return mask

    data.train_mask = generate_mask(0, count_list[0])
    data.val_mask = generate_mask(count_list[0], count_list[1])
    data.test_mask = generate_mask(count_list[1], count_list[2])
    return data
def read_triplet_data(folder):
    """Read knowledge-graph triples plus entity/relation statistics.

    NOTE(review): this redefines `read_triplet_data` from earlier in the file
    and shadows it at import time — confirm both definitions are intended to
    coexist (only this one is reachable).

    Each file starts with a declared triple count line, then one
    `head tail relation` triple per line. In addition to the masked `Data`
    object, returns the triples as (head, relation, tail) tuples, the start
    offset of each split within that list, and the number of distinct
    entity/relation tokens seen.

    Args:
        folder: Directory containing train2id.txt, valid2id.txt, test2id.txt.

    Returns:
        Tuple of (data, triples, train_start_idx, valid_start_idx,
        test_start_idx, num_entities, num_relations).
    """
    filenames = ["train2id.txt", "valid2id.txt", "test2id.txt"]
    count = 0
    edge_index = []
    edge_attr = []
    count_list = []
    triples = []
    num_entities = 0
    num_relations = 0
    entity_dic = {}
    relation_dic = {}
    for filename in filenames:
        with open(osp.join(folder, filename), "r") as f:
            # Skip the declared count; the actual data lines are counted below.
            _ = int(f.readline().strip())
            # Record where each split begins in the flat `triples` list.
            if "train" in filename:
                train_start_idx = len(triples)
            elif "valid" in filename:
                valid_start_idx = len(triples)
            elif "test" in filename:
                test_start_idx = len(triples)
            for line in f:
                items = line.strip().split()
                edge_index.append([int(items[0]), int(items[1])])
                edge_attr.append(int(items[2]))
                # Stored as (head, relation, tail).
                triples.append((int(items[0]), int(items[2]), int(items[1])))
                # Count distinct entity/relation tokens on first sight; the
                # dict values are sequential ids (currently unused downstream).
                if items[0] not in entity_dic:
                    entity_dic[items[0]] = num_entities
                    num_entities += 1
                if items[1] not in entity_dic:
                    entity_dic[items[1]] = num_entities
                    num_entities += 1
                if items[2] not in relation_dic:
                    relation_dic[items[2]] = num_relations
                    num_relations += 1
                count += 1
        # Cumulative triple count marks each split's end boundary.
        count_list.append(count)
    edge_index = torch.LongTensor(edge_index).t()
    edge_attr = torch.LongTensor(edge_attr)
    data = Data()
    data.edge_index = edge_index
    data.edge_attr = edge_attr

    def generate_mask(start, end):
        # Bug fix: torch.BoolTensor(count) returns an *uninitialized* tensor
        # that then had to be manually cleared; torch.zeros gives a defined
        # all-False tensor directly.
        mask = torch.zeros(count, dtype=torch.bool)
        mask[start:end] = True
        return mask

    data.train_mask = generate_mask(0, count_list[0])
    data.val_mask = generate_mask(count_list[0], count_list[1])
    data.test_mask = generate_mask(count_list[1], count_list[2])
    return data, triples, train_start_idx, valid_start_idx, test_start_idx, num_entities, num_relations