Example #1
0
    def __init__(self, args=None):
        """Load the JKNet Cora graph and attach a random 80/20 node split.

        Args:
            args: Accepted for interface compatibility; not read here.

        The file at ``self.processed_paths[0]`` (presumably set up by the base
        class ``__init__``) is unpickled and must contain the keys
        ``'node_num'``, ``'xs'``, ``'ys'`` and ``'edges'``. Sets
        ``self.data``, ``self.num_nodes`` and ``self.num_classes``.

        NOTE(review): ``pickle.load`` can execute arbitrary code, so the
        processed file must come from a trusted source.
        """
        dataset = "jknet_cora"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(path, exist_ok=True)
        super(CoraDataset, self).__init__(path)
        with open(self.processed_paths[0], 'rb') as fin:
            load_data = pickle.load(fin)
        self.num_nodes = load_data['node_num']

        data = Data()
        # Raw values; converted to tensors below once the masks are built.
        data.x = load_data['xs']
        data.y = load_data['ys']

        # Random split: int(0.8 * N) nodes for training, the remainder for test.
        # Not seeded here, so the split varies between runs unless the caller
        # seeds NumPy's global RNG.
        train_size = int(self.num_nodes * 0.8)
        train_mask = np.zeros((self.num_nodes, ), dtype=bool)
        train_idx = np.random.choice(np.arange(self.num_nodes), size=train_size, replace=False)
        train_mask[train_idx] = True
        test_mask = ~train_mask
        # Validation uses the same nodes as test. Copy so the two masks do not
        # share storage (torch.from_numpy shares memory with its source array).
        val_mask = test_mask.copy()

        # Assumes 'edges' is a sequence of node pairs; transpose to shape
        # (2, num_edges). Use int64 explicitly: plain ``int`` is int32 on
        # Windows with NumPy < 2.0, which would give an IntTensor edge_index.
        edges = np.array(load_data['edges'], dtype=np.int64).transpose((1, 0))

        data.edge_index = torch.from_numpy(edges)
        data.train_mask = torch.from_numpy(train_mask)
        data.test_mask = torch.from_numpy(test_mask)
        data.val_mask = torch.from_numpy(val_mask)
        data.x = torch.Tensor(data.x)
        data.y = torch.LongTensor(data.y)

        self.data = data
        # Assumes labels are 0-based class ids -- TODO confirm.
        self.num_classes = torch.max(self.data.y).item() + 1
Example #2
0
def read_triplet_data(folder):
    """Read the train/valid/test triplet files into a single edge-list ``Data``.

    Each of ``train2id.txt``, ``valid2id.txt`` and ``test2id.txt`` in
    *folder* starts with a header line that must parse as an integer
    (presumably the triplet count; it is not used). Every following
    non-blank line holds whitespace-separated integers. The first two are
    the edge endpoints and the third is the relation id.

    Args:
        folder: Directory containing the three files.

    Returns:
        A ``Data`` object with these attributes:

        - ``edge_index``: LongTensor of shape ``(2, E)``.
        - ``edge_attr``: LongTensor of relation ids, shape ``(E,)``.
        - ``train_mask`` / ``val_mask`` / ``test_mask``: boolean masks of
          length ``E`` selecting the edges read from each file.

    Raises:
        ValueError: If a header or field is not an integer.
        IndexError: If a non-blank line has fewer than three fields.
    """
    filenames = ["train2id.txt", "valid2id.txt", "test2id.txt"]
    edge_index = []
    edge_attr = []
    boundaries = []  # cumulative edge count after each file
    for filename in filenames:
        with open(osp.join(folder, filename), "r") as f:
            _ = int(f.readline().strip())  # header: validated, not used
            for line in f:
                items = line.split()
                if not items:  # tolerate blank lines (e.g. a trailing newline)
                    continue
                edge_index.append([int(items[0]), int(items[1])])
                edge_attr.append(int(items[2]))
        boundaries.append(len(edge_attr))
    count = len(edge_attr)

    data = Data()
    # view(-1, 2) keeps the (2, 0) shape when no triplets were read.
    data.edge_index = torch.LongTensor(edge_index).view(-1, 2).t()
    data.edge_attr = torch.LongTensor(edge_attr)

    def generate_mask(start, end):
        # Boolean mask over all edges, True on the half-open range [start, end).
        mask = torch.zeros(count, dtype=torch.bool)
        mask[start:end] = True
        return mask

    train_end, valid_end, test_end = boundaries
    data.train_mask = generate_mask(0, train_end)
    data.val_mask = generate_mask(train_end, valid_end)
    data.test_mask = generate_mask(valid_end, test_end)
    return data
Example #3
0
File: kg_data.py Project: zrt/cogdl
def read_triplet_data(folder):
    """Read the train/valid/test triplet files into graph and triple form.

    Each of ``train2id.txt``, ``valid2id.txt`` and ``test2id.txt`` in
    *folder* starts with a header line that must parse as an integer
    (presumably the triplet count; it is not used). Every following
    non-blank line holds whitespace-separated integers ``a b r``. Here
    ``a``/``b`` are the edge endpoints and ``r`` is the relation id.

    Args:
        folder: Directory containing the three files.

    Returns:
        A 7-tuple ``(data, triples, train_start_idx, valid_start_idx,
        test_start_idx, num_entities, num_relations)``:

        - ``data``: ``Data`` with ``edge_index`` (LongTensor, shape
          ``(2, E)``), ``edge_attr`` (relation ids, shape ``(E,)``) and
          boolean ``train_mask`` / ``val_mask`` / ``test_mask`` over the edges.
        - ``triples``: list of ``(a, r, b)`` int tuples in file order.
        - ``*_start_idx``: index in ``triples`` where each file's triplets begin.
        - ``num_entities``: number of distinct ids in the first two columns.
          NOTE(review): this is not ``max id + 1`` if ids are non-contiguous.
          Confirm that callers do not use it as an embedding size.
        - ``num_relations``: number of distinct ids in the third column.

    Raises:
        ValueError: If a header or field is not an integer.
        IndexError: If a non-blank line has fewer than three fields.
    """
    filenames = ["train2id.txt", "valid2id.txt", "test2id.txt"]
    edge_index = []
    edge_attr = []
    triples = []
    entities = set()
    relations = set()
    boundaries = []  # cumulative triplet count after each file
    for filename in filenames:
        with open(osp.join(folder, filename), "r") as f:
            _ = int(f.readline().strip())  # header: validated, not used
            for line in f:
                items = line.split()
                if not items:  # tolerate blank lines (e.g. a trailing newline)
                    continue
                head, tail, rel = int(items[0]), int(items[1]), int(items[2])
                edge_index.append([head, tail])
                edge_attr.append(rel)
                triples.append((head, rel, tail))
                entities.add(head)
                entities.add(tail)
                relations.add(rel)
        boundaries.append(len(triples))
    count = len(triples)

    data = Data()
    # view(-1, 2) keeps the (2, 0) shape when no triplets were read.
    data.edge_index = torch.LongTensor(edge_index).view(-1, 2).t()
    data.edge_attr = torch.LongTensor(edge_attr)

    def generate_mask(start, end):
        # Boolean mask over all edges, True on the half-open range [start, end).
        mask = torch.zeros(count, dtype=torch.bool)
        mask[start:end] = True
        return mask

    train_end, valid_end, test_end = boundaries
    data.train_mask = generate_mask(0, train_end)
    data.val_mask = generate_mask(train_end, valid_end)
    data.test_mask = generate_mask(valid_end, test_end)

    # Files are read in train, valid, test order, so each one starts where the
    # previous one ended.
    train_start_idx = 0
    valid_start_idx = train_end
    test_start_idx = valid_end
    return data, triples, train_start_idx, valid_start_idx, test_start_idx, len(entities), len(relations)