def extract_dataset():
    """Export each benchmark graph to ``.npy`` files in its own directory.

    Command-line arguments are parsed with ``register_data_args`` so that
    ``load_data`` can be reused for 'cora', 'citeseer', 'pubmed' and
    'reddit'; the remaining datasets come from the ``CoraFull``,
    ``Coauthor`` and ``AmazonCoBuy`` loaders.

    For every dataset a sub-directory named after it is created (if missing)
    in the current working directory, and ``edges.npy``, ``features.npy``,
    ``labels.npy`` and ``train_mask.npy`` are written into it. The working
    directory is restored after each dataset, even when loading or saving
    raises.

    Raises:
        Exception: if a dataset name has no loader.
    """
    parser = argparse.ArgumentParser(description='DATA')
    register_data_args(parser)
    args = parser.parse_args()
    dataset_name = [
        'cora', 'citeseer', 'pubmed', 'reddit', 'CoraFull', 'Coauthor_cs',
        'Coauthor_physics', 'AmazonCoBuy_computers', 'AmazonCoBuy_photo'
    ]

    print("Now PATH IS ", os.getcwd())
    for name in dataset_name:
        if not os.path.exists(name):
            os.mkdir(name)
        os.chdir(name)
        try:
            if name in ('cora', 'citeseer', 'pubmed', 'reddit'):
                _extract_load_data_dataset(args, name)
            else:
                _extract_dgl_builtin_dataset(name)
        finally:
            # Always return to the parent directory so that a failure on one
            # dataset does not make the next one nest inside it.
            os.chdir('..')
        print('change to ', os.getcwd())


def _extract_load_data_dataset(args, name):
    """Export a dataset obtained through ``load_data`` into the cwd."""
    args.dataset = name
    print('args.dataset = ', args.dataset)
    print("Now PATH IS ", os.getcwd())

    data = load_data(args)
    features = data.features
    labels = data.labels
    graph = data.graph
    train_mask = data.train_mask

    if name == 'reddit':
        n_nodes = features.shape[0]
        # number_of_edges is a method; call it so cut_graph receives the
        # edge count (presumably what it expects) rather than the method.
        n_edges = graph.number_of_edges()
        graph, features, labels, train_mask, _, _ = cut_graph(
            graph, n_nodes, n_edges, features, labels, train_mask,
            data.val_mask, data.test_mask, 0.85)

        edges = graph.edges()
        edge_x = edges[0].numpy().reshape((-1, 1))
        print(edge_x.shape)
        edge_y = edges[1].numpy().reshape((-1, 1))
        edges_list = np.hstack((edge_x, edge_y))
        print(edges_list.shape, edge_x.shape, edge_y.shape)
    else:
        # Each item of graph.edges is indexed as (src, dst).
        edges_list = _edge_pairs_to_array(graph.edges)

    print('features_shape', features.shape)
    print('labels_shape', labels.shape)
    print('edges_shape', edges_list.shape)

    _save_arrays(edges_list, features, labels, train_mask)
    print('Finish writing dataset', name)


def _extract_dgl_builtin_dataset(name):
    """Export one of the CoraFull / Coauthor / AmazonCoBuy graphs into the cwd."""
    if name == 'CoraFull':
        data = CoraFull()
    elif name == 'Coauthor_cs':
        data = Coauthor('cs')
    elif name == 'Coauthor_physics':
        data = Coauthor('physics')
    elif name == 'AmazonCoBuy_computers':
        data = AmazonCoBuy('computers')
    elif name == 'AmazonCoBuy_photo':
        data = AmazonCoBuy('photo')
    else:
        raise Exception("No such a dataset {}".format(name))

    graph = data.data[0]
    features = torch.FloatTensor(graph.ndata['feat']).numpy()
    labels = torch.LongTensor(graph.ndata['label']).numpy()

    print('dataset ', name)

    n_nodes = features.shape[0]
    edges_u, edges_v = graph.all_edges()
    edges_list = _edge_pairs_to_array(zip(edges_u.numpy(), edges_v.numpy()))

    print('features_shape', features.shape)
    print('labels_shape', labels.shape)
    print('edges_shape', edges_list.shape)

    # The first min(500, n_nodes) nodes are marked as training nodes; the
    # mask always has exactly n_nodes entries.
    train_mask = np.zeros(n_nodes, dtype=bool)
    train_mask[:500] = True

    _save_arrays(edges_list, features, labels, train_mask)
    print('Finish writing dataset', name)


def _edge_pairs_to_array(pairs):
    """Build an ``(E, 2)`` array from an iterable of ``(src, dst)`` items.

    Built in a single pass (row-by-row ``np.append`` copies the whole array
    on every call, which is quadratic in E). With no items the result is the
    1-D empty array ``np.array([])``.
    """
    return np.array([[item[0], item[1]] for item in pairs])


def _save_arrays(edges_list, features, labels, train_mask):
    """Write the four arrays as ``.npy`` files into the current directory."""
    np.save('edges.npy', edges_list)
    np.save('features.npy', features)
    np.save('labels.npy', labels)
    np.save('train_mask.npy', train_mask)
 def __init__(
     self,
     rw_hops=64,
     subgraph_size=64,
     restart_prob=0.8,
     positional_embedding_size=32,
     step_dist=None,
 ):
     """Store sampling hyper-parameters and load the pre-training graphs.

     Args:
         rw_hops: stored as ``self.rw_hops``.
         subgraph_size: stored as ``self.subgraph_size``.
         restart_prob: stored as ``self.restart_prob``.
         positional_embedding_size: stored as
             ``self.positional_embedding_size``; must be greater than 1.
         step_dist: sequence of weights that must sum to 1.0 (checked with
             a small tolerance). Defaults to a fresh ``[1.0, 0.0, 0.0]``
             list per instance.

     Graphs 0, 1 and 2 are read from ``data_bin/dgl/lscc_graphs.bin``; the
     Coauthor ("cs", "physics") and AmazonCoBuy ("computers", "photo")
     graphs are appended after removing their zero-in-degree nodes and
     calling ``readonly()`` on them. ``self.length`` is the total number of
     nodes over all loaded graphs.
     """
     # Two-argument form: the one-argument ``super(GraphDataset)`` is an
     # unbound super object and never calls the parent initialiser.
     super(GraphDataset, self).__init__()
     if step_dist is None:
         # Avoid a shared mutable default: each instance gets its own list.
         step_dist = [1.0, 0.0, 0.0]
     self.rw_hops = rw_hops
     self.subgraph_size = subgraph_size
     self.restart_prob = restart_prob
     self.positional_embedding_size = positional_embedding_size
     self.step_dist = step_dist
     # Tolerant comparison: floating-point sums such as 0.7 + 0.2 + 0.1
     # need not be exactly 1.0.
     assert abs(sum(step_dist) - 1.0) < 1e-6
     assert positional_embedding_size > 1
     graphs, _ = dgl.data.utils.load_graphs(
         "data_bin/dgl/lscc_graphs.bin", [0, 1, 2]
     )
     extra_sources = [
         (Coauthor, "cs"),
         (Coauthor, "physics"),
         (AmazonCoBuy, "computers"),
         (AmazonCoBuy, "photo"),
     ]
     for dataset_cls, name in extra_sources:
         g = dataset_cls(name)[0]
         # Drop nodes that have no incoming edges.
         g.remove_nodes((g.in_degrees() == 0).nonzero().squeeze())
         g.readonly()
         graphs.append(g)
     # More graphs are coming ...
     print("load graph done")
     self.graphs = graphs
     self.length = sum(g.number_of_nodes() for g in self.graphs)
def main(args):
    """Train GIC embeddings and evaluate them with a classifier.

    Loads ``args.dataset`` ('cora', 'citeseer', 'pubmed' through
    ``load_data``; 'physics', 'cs', 'computers', 'photo' through
    ``Coauthor`` / ``AmazonCoBuy``) and keeps only its largest connected
    component, with self-loops removed and re-added. It then runs two
    independent random train/val/test splits. For each split:

    * GIC is trained with early stopping on its training loss.
    * The best GIC checkpoint is reloaded.
    * A classifier is trained on the frozen, per-row normalised embeddings.
    * The test accuracy of the classifier checkpoint with the best
      validation accuracy (after epoch 100) is recorded.

    The mean and standard deviation of the test accuracies are printed.

    Args:
        args: parsed command-line namespace. Among others it reads
            ``dataset``, ``gpu``, ``n_hidden``, ``n_layers``, ``dropout``,
            ``num_clusters``, ``beta``, ``alpha``, ``gic_lr``,
            ``classifier_lr``, ``weight_decay``, ``n_gic_epochs``,
            ``n_classifier_epochs`` and ``patience``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.

    Side effects: writes ``best_gic.pkl`` and ``best_class.pkl`` into the
    current working directory.
    """
    torch.manual_seed(1234)

    if args.dataset in ('cora', 'citeseer', 'pubmed'):
        data = load_data(args)
        g = data.graph
        if args.dataset == 'cora':
            # Materialise the self-loop generator before mutating the graph
            # it iterates over.
            g.remove_edges_from(list(nx.selfloop_edges(g)))
            g.add_edges_from(zip(g.nodes(), g.nodes()))
        g = DGLGraph(g)
        attr_matrix = data.features
        labels = data.labels
    else:
        if args.dataset == 'physics':
            data = Coauthor('physics')
        elif args.dataset == 'cs':
            data = Coauthor('cs')
        elif args.dataset == 'computers':
            data = AmazonCoBuy('computers')
        elif args.dataset == 'photo':
            data = AmazonCoBuy('photo')
        else:
            raise ValueError('Unknown dataset: {}'.format(args.dataset))
        g = data[0]
        attr_matrix = g.ndata['feat']
        labels = g.ndata['label']

    # Keep only the largest connected component (LCC) of the graph.
    n_components = 1
    sparse_graph = g.adjacency_matrix_scipy(return_edge_ids=False)
    _, component_indices = sp.csgraph.connected_components(sparse_graph)
    component_sizes = np.bincount(component_indices)
    # argsort is ascending; reverse it to take the largest components first.
    components_to_keep = np.argsort(component_sizes)[::-1][:n_components]
    nodes_to_keep = [
        idx for (idx, component) in enumerate(component_indices)
        if component in components_to_keep
    ]
    num_nodes = len(nodes_to_keep)

    adj_matrix = sparse_graph[nodes_to_keep][:, nodes_to_keep]
    g = DGLGraph(adj_matrix)
    # Drop any existing self-loops, then add them back via add_self_loop.
    g = remove_self_loop(g)
    g = add_self_loop(g)
    g = DGLGraph(g)

    g.ndata['feat'] = attr_matrix[nodes_to_keep]
    features = torch.FloatTensor(g.ndata['feat'].float())
    if args.dataset == 'cora' or args.dataset == 'pubmed':
        # L2-normalise feature rows; 1e-8 guards against all-zero rows.
        features = features / (features.norm(dim=1) + 1e-8)[:, None]
    g.ndata['label'] = labels[nodes_to_keep]
    labels = torch.LongTensor(g.ndata['label'])

    in_feats = features.shape[1]
    n_classes = len(np.unique(labels, return_counts=False))
    n_nodes = g.number_of_nodes()
    n_edges = g.number_of_edges()

    print('Number of nodes', n_nodes, 'Number of edges', n_edges)

    enc = OneHotEncoder()
    enc.fit(labels.reshape(-1, 1))
    ylabels = enc.transform(labels.reshape(-1, 1)).toarray()

    beta = args.beta
    K = args.num_clusters
    alpha = args.alpha
    sets = "imbalanced"  # split configuration: "imbalanced" or "balanced"

    accs = []
    t_st = time.time()

    cuda = args.gpu >= 0
    if cuda:
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()

    def to_mask(indices):
        # Boolean mask over the kept nodes that is True exactly at `indices`.
        mask = np.zeros(num_nodes)
        mask[indices] = 1
        mask = torch.BoolTensor(mask)
        return mask.cuda() if cuda else mask

    for _ in range(2):  # number of independent runs (random splits)
        random_state = np.random.RandomState()
        if sets == "imbalanced":
            train_idx, val_idx, test_idx = get_train_val_test_split(
                random_state,
                ylabels,
                train_examples_per_class=None,
                val_examples_per_class=None,
                test_examples_per_class=None,
                train_size=20 * n_classes,
                val_size=30 * n_classes,
                test_size=None)
        elif sets == "balanced":
            train_idx, val_idx, test_idx = get_train_val_test_split(
                random_state,
                ylabels,
                train_examples_per_class=20,
                val_examples_per_class=30,
                test_examples_per_class=None,
                train_size=None,
                val_size=None,
                test_size=None)
        else:
            raise ValueError(
                "No such set configuration (imbalanced/balanced)")

        train_mask = to_mask(train_idx)
        val_mask = to_mask(val_idx)
        test_mask = to_mask(test_idx)
        # Alternative: Planetoid split for Cora / CiteSeer / PubMed
        #   train_mask = torch.BoolTensor(data.train_mask)
        #   val_mask = torch.BoolTensor(data.val_mask)
        #   test_mask = torch.BoolTensor(data.test_mask)

        # ---- Train GIC with early stopping on its training loss ----
        gic = GIC(g, in_feats, args.n_hidden, args.n_layers,
                  nn.PReLU(args.n_hidden), args.dropout, K, beta, alpha)
        if cuda:
            gic.cuda()

        gic_optimizer = torch.optim.Adam(
            gic.parameters(),
            lr=args.gic_lr,
            weight_decay=args.weight_decay)

        cnt_wait = 0
        best = 1e9
        for _epoch in range(args.n_gic_epochs):
            gic.train()
            gic_optimizer.zero_grad()
            loss = gic(features)
            loss.backward()
            gic_optimizer.step()

            # Track a plain float rather than the loss tensor, which still
            # references its autograd graph.
            loss_value = loss.item()
            if loss_value < best:
                best = loss_value
                cnt_wait = 0
                torch.save(gic.state_dict(), 'best_gic.pkl')
            else:
                cnt_wait += 1

            if cnt_wait == args.patience:
                break

        # ---- Embed with the best GIC checkpoint ----
        gic.load_state_dict(torch.load('best_gic.pkl'))
        embeds = gic.encoder(features, corrupt=False)
        embeds = embeds / (embeds + 1e-8).norm(dim=1)[:, None]
        embeds = embeds.detach()

        # ---- Train the classifier on the frozen embeddings ----
        classifier = Classifier(args.n_hidden, n_classes)
        if cuda:
            classifier.cuda()

        classifier_optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.classifier_lr,
            weight_decay=args.weight_decay)

        best_a = 0
        saved_best = False
        for epoch in range(args.n_classifier_epochs):
            classifier.train()
            classifier_optimizer.zero_grad()
            preds = classifier(embeds)
            loss = F.nll_loss(preds[train_mask], labels[train_mask])
            loss.backward()
            classifier_optimizer.step()

            acc = evaluate(classifier, embeds, labels, val_mask)
            # Model selection on validation accuracy after a 100-epoch warm-up.
            if acc > best_a and epoch > 100:
                best_a = acc
                saved_best = True
                torch.save(classifier.state_dict(), 'best_class.pkl')

        # Test the checkpoint selected on validation. Fall back to the final
        # weights if no checkpoint was written during this run, so a stale
        # file from an earlier run is never loaded.
        if saved_best:
            classifier.load_state_dict(torch.load('best_class.pkl'))
        accs.append(evaluate(classifier, embeds, labels, test_mask))

    print('=================== ', ' alpha', alpha, ' beta ', beta, 'K', K)
    print(args.dataset, ' Acc (mean)', mean(accs), ' (std)', stdev(accs))
    print('=================== time', int((time.time() - t_st) / 60))