Example #1
def get_tg_dataset(args, dataset_name, use_cache=True, remove_feature=False):
    # "Cora", "CiteSeer" and "PubMed"
    if dataset_name in ['Cora', 'CiteSeer', 'PubMed']:
        dataset = tg.datasets.Planetoid(root='datasets/' + dataset_name,
                                        name=dataset_name)
    else:
        try:
            dataset = load_tg_dataset(dataset_name)
        except Exception:
            raise NotImplementedError

    # precompute shortest path
    if not os.path.isdir('datasets'):
        os.mkdir('datasets')
    if not os.path.isdir('datasets/cache'):
        os.mkdir('datasets/cache')
    f1_name = 'datasets/cache/' + dataset_name + str(
        args.approximate) + '_dists.dat'
    f2_name = 'datasets/cache/' + dataset_name + str(
        args.approximate) + '_dists_removed.dat'
    f3_name = 'datasets/cache/' + dataset_name + str(
        args.approximate) + '_links_train.dat'
    f4_name = 'datasets/cache/' + dataset_name + str(
        args.approximate) + '_links_val.dat'
    f5_name = 'datasets/cache/' + dataset_name + str(
        args.approximate) + '_links_test.dat'

    if use_cache and ((os.path.isfile(f2_name) and args.task == 'link') or
                      (os.path.isfile(f1_name) and args.task != 'link')):
        with open(f3_name, 'rb') as f3, \
            open(f4_name, 'rb') as f4, \
            open(f5_name, 'rb') as f5:
            links_train_list = pickle.load(f3)
            links_val_list = pickle.load(f4)
            links_test_list = pickle.load(f5)
        if args.task == 'link':
            with open(f2_name, 'rb') as f2:
                dists_removed_list = pickle.load(f2)
        else:
            with open(f1_name, 'rb') as f1:
                dists_list = pickle.load(f1)

        print('Cache loaded!')
        data_list = []
        for i, data in enumerate(dataset):
            if args.task == 'link':
                data.mask_link_positive = deduplicate_edges(
                    data.edge_index.numpy())
            data.mask_link_positive_train = links_train_list[i]
            data.mask_link_positive_val = links_val_list[i]
            data.mask_link_positive_test = links_test_list[i]
            get_link_mask(data, resplit=False)

            if args.task == 'link':
                data.dists = torch.from_numpy(dists_removed_list[i]).float()
                data.edge_index = torch.from_numpy(
                    duplicate_edges(data.mask_link_positive_train)).long()
            else:
                data.dists = torch.from_numpy(dists_list[i]).float()
            if remove_feature:
                data.x = torch.ones((data.x.shape[0], 1))
            data_list.append(data)
    else:
        data_list = []
        dists_list = []
        dists_removed_list = []
        links_train_list = []
        links_val_list = []
        links_test_list = []
        for i, data in enumerate(dataset):
            if 'link' in args.task:
                print(f"args.task = {args.task}")
                get_link_mask(
                    data,
                    args.remove_link_ratio,
                    resplit=True,
                    infer_link_positive=True if args.task == 'link' else False)
            links_train_list.append(data.mask_link_positive_train)
            links_val_list.append(data.mask_link_positive_val)
            links_test_list.append(data.mask_link_positive_test)
            if args.task == 'link':
                dists_removed = precompute_dist_data(
                    data.mask_link_positive_train,
                    data.num_nodes,
                    approximate=args.approximate)
                dists_removed_list.append(dists_removed)
                data.dists = torch.from_numpy(dists_removed).float()
                data.edge_index = torch.from_numpy(
                    duplicate_edges(data.mask_link_positive_train)).long()

            else:
                dists = precompute_dist_data(data.edge_index.numpy(),
                                             data.num_nodes,
                                             approximate=args.approximate)
                dists_list.append(dists)
                data.dists = torch.from_numpy(dists).float()
            if remove_feature:
                data.x = torch.ones((data.x.shape[0], 1))
            data_list.append(data)

        with open(f1_name, 'wb') as f1, \
            open(f2_name, 'wb') as f2, \
            open(f3_name, 'wb') as f3, \
            open(f4_name, 'wb') as f4, \
            open(f5_name, 'wb') as f5:

            if args.task == 'link':
                pickle.dump(dists_removed_list, f2)
            else:
                pickle.dump(dists_list, f1)
            pickle.dump(links_train_list, f3)
            pickle.dump(links_val_list, f4)
            pickle.dump(links_test_list, f5)
        print('Cache saved!')
    return data_list
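
A minimal usage sketch for the loader above. The `args` namespace only needs the fields the function actually reads (`task`, `approximate`, `remove_link_ratio`); the concrete values below are assumptions for illustration, not the original training configuration.

from types import SimpleNamespace

# Hypothetical args object; only the attributes read by get_tg_dataset are set.
args = SimpleNamespace(
    task='link',            # 'link' selects the branch that recomputes dists on the train-only graph
    approximate=-1,         # keys the cache file names and is passed to precompute_dist_data (value assumed)
    remove_link_ratio=0.2,  # passed to get_link_mask; assumed to be the held-out link fraction
)

data_list = get_tg_dataset(args, 'Cora', use_cache=True, remove_feature=False)
print(len(data_list), data_list[0].dists.shape)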
Example #2
def get_tg_dataset(args, dataset_name, use_cache=True, remove_feature=False, hash_overwrite=False, hash_concat=False):
    # "Cora", "CiteSeer" and "PubMed"
    if dataset_name in ['Cora', 'CiteSeer', 'PubMed']:
        dataset = tg.datasets.Planetoid(root='datasets/' + dataset_name, name=dataset_name)
    else:
        try:
            dataset = load_tg_dataset(dataset_name)
        except Exception:
            raise NotImplementedError

    # precompute shortest path
    if not os.path.isdir('datasets/cache'):
        os.makedirs('datasets/cache')
    f1_name = 'datasets/cache/' + dataset_name + str(args.approximate) + '_dists.dat'
    f2_name = 'datasets/cache/' + dataset_name + str(args.approximate) + '_dists_removed.dat'
    f3_name = 'datasets/cache/' + dataset_name + str(args.approximate) + '_links_train.dat'
    f4_name = 'datasets/cache/' + dataset_name + str(args.approximate) + '_links_val.dat'
    f5_name = 'datasets/cache/' + dataset_name + str(args.approximate) + '_links_test.dat'
    # cache for dists_all
    f6_name = 'datasets/cache/' + dataset_name + str(args.approximate) + '_dists_all.dat'

    if use_cache and ((os.path.isfile(f2_name) and args.task == 'link') or
                      (os.path.isfile(f1_name) and args.task != 'link')):
        with open(f3_name, 'rb') as f3, \
                open(f4_name, 'rb') as f4, \
                open(f5_name, 'rb') as f5, \
                open(f6_name, 'rb') as f6:

            links_train_list = pickle.load(f3)
            links_val_list = pickle.load(f4)
            links_test_list = pickle.load(f5)
            # load a list of dists_all (one per connected component)
            dists_all_list = pickle.load(f6)
        if args.task == 'link':
            with open(f2_name, 'rb') as f2:
                dists_removed_list = pickle.load(f2)
        else:
            with open(f1_name, 'rb') as f1:
                dists_list = pickle.load(f1)

        print('Cache loaded!')
        data_list = []
        start_node = 0
        for i, data in enumerate(dataset):
            if args.task == 'link':
                data.mask_link_positive = deduplicate_edges(data.edge_index.numpy())
            data.mask_link_positive_train = links_train_list[i]
            data.mask_link_positive_val = links_val_list[i]
            data.mask_link_positive_test = links_test_list[i]
            get_link_mask(data, resplit=False)

            if args.task == 'link':
                data.dists = torch.from_numpy(dists_removed_list[i]).float()
                data.edge_index = torch.from_numpy(duplicate_edges(data.mask_link_positive_train)).long()
            else:
                data.dists = torch.from_numpy(dists_list[i]).float()
            if remove_feature:
                data.x = torch.ones((data.x.shape[0], 1))
            if hash_overwrite:
                x = np.zeros(data.x.shape)
                for m in range(data.x.shape[0]):
                    x[m] = int_to_hash_vector(start_node + m, data.x.shape[1])
                data.x = torch.from_numpy(x).float()
                start_node += data.x.shape[0]
            if hash_concat:
                x = np.zeros((data.x.shape[0], data.x.shape[1] * 2))
                for m in range(data.x.shape[0]):
                    x[m] = np.concatenate((data.x[m], int_to_hash_vector(start_node + m, data.x.shape[1])))
                data.x = torch.from_numpy(x).float()
                start_node += data.x.shape[0]

            # assign dists_all to each data object (one per connected component)
            data.dists_all = dists_all_list[i]

            # generate graph dist ranks
            data.dists_ranks = gen_graph_dist_rank_data(data.dists_all)

            data_list.append(data)
    else:
        dists_all_list = []  # dists_all stores dists for all nodes (regardless of whether they fall in train, val or test)
        data_list = []
        dists_list = []
        dists_removed_list = []
        links_train_list = []
        links_val_list = []
        links_test_list = []
        start_node = 0
        for i, data in enumerate(dataset):
            if 'link' in args.task:
                get_link_mask(data, args.remove_link_ratio, resplit=True,
                              infer_link_positive=True if args.task == 'link' else False)
            links_train_list.append(data.mask_link_positive_train)
            links_val_list.append(data.mask_link_positive_val)
            links_test_list.append(data.mask_link_positive_test)
            if args.task == 'link':
                dists_removed = precompute_dist_data(data.mask_link_positive_train, data.num_nodes,
                                                     approximate=args.approximate)
                dists_removed_list.append(dists_removed)
                data.dists = torch.from_numpy(dists_removed).float()
                data.edge_index = torch.from_numpy(duplicate_edges(data.mask_link_positive_train)).long()

            else:
                dists = precompute_dist_data(data.edge_index.numpy(), data.num_nodes, approximate=args.approximate)
                dists_list.append(dists)
                data.dists = torch.from_numpy(dists).float()

            # compute dists for all nodes in the connected component, regardless of whether the task is 'link'
            dists_all = precompute_dist_data(data.edge_index.numpy(), data.num_nodes, approximate=args.approximate)
            dists_all_list.append(dists_all)
            data.dists_all = dists_all

            if remove_feature:
                data.x = torch.ones((data.x.shape[0], 1))
            if hash_overwrite:
                x = np.zeros(data.x.shape)
                for m in range(data.x.shape[0]):
                    x[m] = int_to_hash_vector(start_node + m, data.x.shape[1])
                data.x = torch.from_numpy(x).float()
                start_node += data.x.shape[0]
            if hash_concat:
                x = np.zeros((data.x.shape[0], data.x.shape[1] * 2))
                for m in range(data.x.shape[0]):
                    x[m] = np.concatenate((data.x[m], int_to_hash_vector(start_node + m, data.x.shape[1])))
                data.x = torch.from_numpy(x).float()
                start_node += data.x.shape[0]

            # generate graph dist ranks
            data.dists_ranks = gen_graph_dist_rank_data(data.dists_all)
            data_list.append(data)

        with open(f1_name, 'wb') as f1, \
                open(f2_name, 'wb') as f2, \
                open(f3_name, 'wb') as f3, \
                open(f4_name, 'wb') as f4, \
                open(f5_name, 'wb') as f5, \
                open(f6_name, 'wb') as f6:

            if args.task == 'link':
                pickle.dump(dists_removed_list, f2)
            else:
                pickle.dump(dists_list, f1)
            pickle.dump(links_train_list, f3)
            pickle.dump(links_val_list, f4)
            pickle.dump(links_test_list, f5)
            pickle.dump(dists_all_list, f6)
        print('Cache saved!')

    return data_list
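
Example #2 additionally calls int_to_hash_vector(node_id, dim) and gen_graph_dist_rank_data(dists_all), two helpers that are not shown here. Purely as a sketch of the assumed contract of the first (map a global node index to a deterministic, fixed-length feature vector), a hypothetical stand-in could look like this:

import hashlib
import numpy as np

def int_to_hash_vector(node_id, dim):
    # Hypothetical stand-in, not the original helper: derive a reproducible
    # length-`dim` vector in [0, 1) from the integer node id.
    seed = int.from_bytes(hashlib.sha256(str(node_id).encode()).digest()[:8], 'little')
    rng = np.random.default_rng(seed)
    return rng.random(dim)

With a helper of this shape, the hash_overwrite branch replaces each node's features with a node-specific code, hash_concat appends it to the existing features, and start_node keeps the codes distinct across the graphs in the dataset.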