예제 #1
0
def read_planetoid_data(folder, prefix):
    """Load a Planetoid-style dataset from ``folder`` into a ``Graph``.

    Reads the pickled files ``ind.<prefix>.{x,tx,allx,y,ty,ally,graph}`` and
    the index file ``ind.<prefix>.test.index``, concatenates ``allx``/``tx``
    into the feature matrix and ``ally``/``ty`` into the label vector, and
    attaches train/validation/test masks.

    Args:
        folder: Directory that contains the ``ind.<prefix>.*`` files.
        prefix: Dataset name; it is lower-cased before file names are built.

    Returns:
        A ``Graph`` constructed with ``x``, ``edge_index`` and ``y``, with
        ``train_mask``, ``val_mask`` and ``test_mask`` set from
        ``index_to_mask``.

    Raises:
        OSError: If one of the pickled files cannot be opened.
    """
    prefix = prefix.lower()
    pickled_names = ("x", "tx", "allx", "y", "ty", "ally", "graph")
    objects = []
    for item in pickled_names:
        with open(f"{folder}/ind.{prefix}.{item}", "rb") as f:
            # NOTE(security): unpickling can execute arbitrary code; only
            # load dataset files from a trusted source.
            # `latin1` lets Python 3 read pickles written by Python 2
            # (presumably how these files were produced -- TODO confirm).
            objects.append(pkl.load(f, encoding="latin1"))

    # Build the index tensor directly as int64; going through the default
    # float32 tensor type would lose precision for indices above 2**24.
    test_index = torch.tensor(
        parse_index_file(f"{folder}/ind.{prefix}.test.index"),
        dtype=torch.long)
    test_index_reorder = test_index.sort()[0]

    # The pickled `x` matrix is not used: the full feature matrix is rebuilt
    # from `allx` and `tx` below, so it is not densified here.
    _, tx, allx, y, ty, ally, graph = objects
    tx, allx = (
        torch.from_numpy(item.todense()).float() for item in (tx, allx))
    y, ty, ally = (
        torch.from_numpy(item).float() for item in (y, ty, ally))

    # Training nodes are the first `len(y)` rows; the following 500 rows are
    # used for validation.
    num_train = y.size(0)
    train_index = torch.arange(num_train, dtype=torch.long)
    val_index = torch.arange(num_train, num_train + 500, dtype=torch.long)

    if prefix == "citeseer":
        # There are some isolated nodes in the Citeseer graph, resulting in
        # non-consecutive test indices. We need to identify them and add them
        # as zero vectors to `tx` and `ty`.
        min_test = test_index.min()
        len_test_indices = (test_index.max() - min_test).item() + 1
        rows = test_index_reorder - min_test

        tx_ext = torch.zeros(len_test_indices, tx.size(1))
        tx_ext[rows, :] = tx
        ty_ext = torch.zeros(len_test_indices, ty.size(1))
        ty_ext[rows, :] = ty

        tx, ty = tx_ext, ty_ext

    x = torch.cat([allx, tx], dim=0).float()
    # Labels are stored one row per node; take the arg-max column as class id.
    y = torch.cat([ally, ty], dim=0).max(dim=1)[1].long()

    # Move each row sitting at the k-th sorted test position to the node id
    # listed k-th in the index file (presumably `tx`/`ty` rows follow the
    # file order of `test.index` -- TODO confirm).
    x[test_index] = x[test_index_reorder]
    y[test_index] = y[test_index_reorder]

    num_nodes = y.size(0)
    train_mask = index_to_mask(train_index, size=num_nodes)
    val_mask = index_to_mask(val_index, size=num_nodes)
    test_mask = index_to_mask(test_index, size=num_nodes)

    edge_index = edge_index_from_dict(graph, num_nodes=num_nodes)

    data = Graph(x=x, edge_index=edge_index, y=y)
    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    return data
예제 #2
0
def read_triplet_data(folder):
    """Read triple files ``train2id.txt``, ``valid2id.txt`` and ``test2id.txt``.

    Each file starts with one integer header line (parsed but otherwise
    unused), followed by lines holding at least three whitespace-separated
    integers: two entity ids and a relation id. Blank lines are skipped.

    Args:
        folder: Directory that contains the three ``*2id.txt`` files.

    Returns:
        A tuple ``(data, triples, train_start_idx, valid_start_idx,
        test_start_idx, num_entities, num_relations)`` where

        * ``data`` is a ``Graph`` with ``edge_index`` (2 x N), ``edge_attr``
          (N relation ids) and boolean ``train_mask`` / ``val_mask`` /
          ``test_mask`` of length N marking which file each edge came from;
        * ``triples`` is a list of ``(first_entity, relation, second_entity)``
          tuples in file order;
        * the ``*_start_idx`` values are the offsets into ``triples`` at
          which each file's rows begin;
        * ``num_entities`` / ``num_relations`` count the distinct integer
          ids seen in the entity and relation columns.

    Raises:
        OSError: If one of the files cannot be opened.
        ValueError: If the header line or a column is not an integer.
        IndexError: If a non-blank line has fewer than three columns.
    """
    filenames = ["train2id.txt", "valid2id.txt", "test2id.txt"]
    edge_index = []
    edge_attr = []
    triples = []
    split_starts = []
    split_ends = []
    seen_entities = set()
    seen_relations = set()

    for filename in filenames:
        with open(osp.join(folder, filename), "r") as f:
            # The first line is a header; it must be an integer but its
            # value is not used.
            int(f.readline().strip())
            split_starts.append(len(triples))
            for line in f:
                items = line.split()
                if not items:
                    # Tolerate blank lines such as a trailing newline.
                    continue
                first, second, relation = (
                    int(items[0]), int(items[1]), int(items[2]))
                edge_index.append([first, second])
                edge_attr.append(relation)
                triples.append((first, relation, second))
                # Count by parsed id so the totals agree with the integer
                # ids stored in `edge_index` / `edge_attr`.
                seen_entities.add(first)
                seen_entities.add(second)
                seen_relations.add(relation)
            split_ends.append(len(triples))

    data = Graph()
    data.edge_index = torch.LongTensor(edge_index).t()
    data.edge_attr = torch.LongTensor(edge_attr)

    num_triples = len(triples)

    def generate_mask(start, end):
        # Boolean mask over all triples that is True on [start, end).
        mask = torch.zeros(num_triples, dtype=torch.bool)
        mask[start:end] = True
        return mask

    data.train_mask = generate_mask(0, split_ends[0])
    data.val_mask = generate_mask(split_ends[0], split_ends[1])
    data.test_mask = generate_mask(split_ends[1], split_ends[2])

    train_start_idx, valid_start_idx, test_start_idx = split_starts
    num_entities = len(seen_entities)
    num_relations = len(seen_relations)
    return (data, triples, train_start_idx, valid_start_idx, test_start_idx,
            num_entities, num_relations)