Example #1
0
def pretrain(dataset):
    """Pretrain the GAT model on ``dataset`` with an adjacency reconstruction loss.

    Each epoch performs one full-batch optimisation step that minimises the
    binary cross-entropy between the model's predicted adjacency ``A_pred``
    and ``dataset.adj_label``. After the step the embeddings ``z`` are
    clustered with KMeans and passed to ``eva`` together with the ground-truth
    labels. Every fifth epoch the model weights are written to
    ``./pretrain/predaegc_{args.name}_{epoch}.pkl``.

    Hyper-parameters are read from the module-level ``args`` (``input_dim``,
    ``hidden_size``, ``embedding_size``, ``alpha``, ``lr``, ``weight_decay``,
    ``max_epoch``, ``n_clusters``, ``name``) and tensors are placed on the
    module-level ``device``.

    Args:
        dataset: Graph dataset object; it is passed through
            ``utils.data_preprocessing`` and must afterwards expose ``adj``,
            ``adj_label``, ``x`` and ``y``.
    """
    import os  # local import: only needed to prepare the checkpoint directory

    model = GAT(
        num_features=args.input_dim,
        hidden_size=args.hidden_size,
        embedding_size=args.embedding_size,
        alpha=args.alpha,
    ).to(device)
    print(model)
    optimizer = Adam(model.parameters(),
                     lr=args.lr,
                     weight_decay=args.weight_decay)

    # data process
    dataset = utils.data_preprocessing(dataset)
    adj = dataset.adj.to(device)
    adj_label = dataset.adj_label.to(device)
    M = utils.get_M(adj).to(device)

    # data and label
    x = torch.Tensor(dataset.x).to(device)
    y = dataset.y.cpu().numpy()

    # torch.save does not create missing parent directories, so make sure the
    # checkpoint directory exists before the first save at epoch 0.
    os.makedirs("./pretrain", exist_ok=True)

    for epoch in range(args.max_epoch):
        model.train()
        A_pred, _ = model(x, adj, M)
        loss = F.binary_cross_entropy(A_pred.view(-1), adj_label.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Evaluate the updated embeddings in eval mode so that any
        # train-only layers (e.g. dropout) do not perturb the clustering;
        # model.train() at the top of the loop restores training mode.
        model.eval()
        with torch.no_grad():
            _, z = model(x, adj, M)
            kmeans = KMeans(n_clusters=args.n_clusters,
                            n_init=20).fit(z.data.cpu().numpy())
            # eva's return value (originally unpacked as acc, nmi, ari, f1)
            # is not used here; presumably it reports the metrics itself.
            eva(y, kmeans.labels_, epoch)
        if epoch % 5 == 0:
            torch.save(model.state_dict(),
                       f"./pretrain/predaegc_{args.name}_{epoch}.pkl")
Example #2
0
def main(args):
    """Train a GAT node classifier on the Cora graph and print test accuracy.

    Loads ``CoraGraphDataset``, normalises self-loops, builds a ``GAT`` model
    and trains it full-batch with Adam and cross-entropy on the training
    mask. Per-epoch loss, train/validation accuracy and timing are printed.

    Args:
        args: Namespace providing ``gpu`` (negative means CPU),
            ``num_heads``, ``num_layers``, ``num_out_heads``, ``num_hidden``,
            ``in_drop``, ``attn_drop``, ``negative_slope``, ``residual``,
            ``early_stop``, ``fastmode``, ``lr``, ``weight_decay`` and
            ``epochs``.
    """
    # load and preprocess dataset
    data = CoraGraphDataset()

    g = data[0]
    cuda = args.gpu >= 0
    if cuda:
        g = g.int().to(args.gpu)

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    num_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    # The blocking ``input()`` pause and the unused node-id tensor that used
    # to sit here were debugging leftovers; they stalled non-interactive runs.
    print(features)
    print(g)
    print("""----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.int().sum().item(),
           val_mask.int().sum().item(),
           test_mask.int().sum().item()))

    # add self loop
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)
    n_edges = g.number_of_edges()
    # create model
    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
    model = GAT(args.num_layers,
                num_feats,
                args.num_hidden,
                n_classes,
                heads,
                F.elu,
                args.in_drop,
                args.attn_drop,
                args.negative_slope,
                args.residual)
    print(model)
    if args.early_stop:
        stopper = EarlyStopping(patience=100)
    if cuda:
        # Place the model on the same device as the graph; a bare .cuda()
        # would use the current default device, which may differ from
        # args.gpu.
        model.cuda(args.gpu)
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # initialize graph
    dur = []
    # Tracks whether the early stopper was ever consulted, i.e. whether it
    # had a chance to write its checkpoint.
    stopper_stepped = False
    for epoch in range(args.epochs):
        model.train()
        # The first three epochs are excluded from the timing statistics.
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        train_acc = accuracy(logits[train_mask], labels[train_mask])

        if args.fastmode:
            # Reuse the training-mode logits instead of a separate eval pass.
            val_acc = accuracy(logits[val_mask], labels[val_mask])
        else:
            val_acc = evaluate(model, g, features, labels, val_mask)
            if args.early_stop:
                stopper_stepped = True
                if stopper.step(val_acc, model):
                    break

        # No timings exist for the first epochs; report NaN explicitly rather
        # than calling np.mean on an empty list (which warns and yields NaN).
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
              " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".
              format(epoch, mean_dur, loss.item(), train_acc,
                     val_acc, n_edges / mean_dur / 1000))

    print()
    # Only restore the early-stopping checkpoint if the stopper actually ran
    # (it never does in fastmode or with zero epochs); otherwise the file may
    # be missing or stale. Presumably 'es_checkpoint.pt' is what
    # EarlyStopping writes -- confirm against its implementation.
    if args.early_stop and stopper_stepped:
        model.load_state_dict(torch.load('es_checkpoint.pt'))
    acc = evaluate(model, g, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))
Example #3
0
def train(args):
    """Train a GAT node classifier in mini-batches and save a checkpoint.

    The graph is loaded via ``removeIsolated`` / ``collectGraph_train``, nodes
    are split into ``args.train_num`` training nodes and the remainder for
    validation (fixed seed 2), and the model is optimised with Adam, a
    ``StepLR`` schedule and NLL loss. After every epoch validation accuracy is
    computed. At the end a checkpoint is written under ``checkpoint/`` and the
    loss / accuracy curves are sent to a visdom server.

    Args:
        args: Namespace providing ``suffix``, ``feat_dim``, ``num_sample``,
            ``embed_dim``, ``alpha``, ``dropout``, ``nheads``, ``use_cuda``,
            ``learning_rate``, ``weight_decay``, ``step_size``,
            ``learning_rate_decay``, ``train_num``, ``epoch_num``,
            ``batch_size`` and ``check_step``.

    Returns:
        Tuple ``(checkpoint_path, class_num)``.
    """
    import os  # local import: only needed to prepare the checkpoint directory

    ## load training data
    # Single-argument print(...) emits identical text on Python 2 and 3.
    print("loading training data ......")
    node_num, class_num = removeIsolated(args.suffix)
    label, feature_map, adj_lists = collectGraph_train(node_num, class_num,
                                                       args.feat_dim,
                                                       args.num_sample,
                                                       args.suffix)
    label = torch.LongTensor(label)
    feature_map = torch.FloatTensor(feature_map)

    model = GAT(args.feat_dim, args.embed_dim, class_num, args.alpha,
                args.dropout, args.nheads, args.use_cuda)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.step_size,
                       gamma=args.learning_rate_decay)

    ## train
    np.random.seed(2)
    random.seed(2)
    rand_indices = np.random.permutation(node_num)
    train_nodes = rand_indices[:args.train_num]
    val_nodes = rand_indices[args.train_num:]

    if args.use_cuda:
        model.cuda()
        label = label.cuda()
        feature_map = feature_map.cuda()

    epoch_num = args.epoch_num
    batch_size = args.batch_size
    iter_num = int(math.ceil(args.train_num / float(batch_size)))
    check_loss = []
    val_accuracy = []
    check_step = args.check_step
    train_loss = 0.0
    iter_cnt = 0
    for e in range(epoch_num):
        model.train()
        # NOTE(review): on PyTorch >= 1.1 the documented order is
        # optimizer.step() before scheduler.step(); left in place because
        # moving it changes the LR schedule -- confirm the target version.
        scheduler.step()

        random.shuffle(train_nodes)
        for batch in range(iter_num):
            batch_nodes = train_nodes[batch * batch_size:(batch + 1) *
                                      batch_size]
            # reshape(-1) instead of squeeze(): squeeze() would collapse a
            # single-node batch to a 0-d tensor and break the loss call.
            batch_label = label[batch_nodes].reshape(-1)
            batch_neighbors = [adj_lists[node] for node in batch_nodes]
            _, logit = model(feature_map, batch_nodes, batch_neighbors)
            loss = F.nll_loss(logit, batch_label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            iter_cnt += 1
            train_loss += loss.cpu().item()
            if iter_cnt % check_step == 0:
                # Record and report the mean loss over the last check_step
                # iterations, then reset the accumulator.
                check_loss.append(train_loss / check_step)
                print("{} epoch: {}, iter: {}, loss:{:.4f}".format(
                    time.strftime('%Y-%m-%d %H:%M:%S'),
                    e, iter_cnt, train_loss / check_step))
                train_loss = 0.0

        ## validation
        model.eval()

        val_total = len(val_nodes)
        group = int(math.ceil(val_total / float(batch_size)))
        val_cnt = 0
        # No gradients are needed for validation; skipping graph construction
        # saves memory and time without changing the predictions.
        with torch.no_grad():
            for batch in range(group):
                batch_nodes = val_nodes[batch * batch_size:(batch + 1) *
                                        batch_size]
                batch_label = label[batch_nodes].reshape(-1)
                batch_neighbors = [adj_lists[node] for node in batch_nodes]
                _, logit = model(feature_map, batch_nodes, batch_neighbors)
                batch_predict = np.argmax(logit.cpu().detach().numpy(),
                                          axis=1)
                val_cnt += np.sum(batch_predict == batch_label.cpu().numpy())
        # With no validation nodes (train_num >= node_num) accuracy is
        # undefined; report NaN instead of raising ZeroDivisionError.
        if val_total:
            val_acc = val_cnt / float(val_total)
        else:
            val_acc = float('nan')
        val_accuracy.append(val_acc)
        print("{} Epoch: {}, Validation Accuracy: {:.4f}".format(
            time.strftime('%Y-%m-%d %H:%M:%S'), e, val_acc))
        print("******" * 10)

    # torch.save does not create missing directories; make sure the target
    # exists so a finished training run is not lost at save time.
    if not os.path.isdir('checkpoint'):
        os.makedirs('checkpoint')
    checkpoint_path = 'checkpoint/checkpoint_{}.pth'.format(
        time.strftime('%Y%m%d%H%M'))
    torch.save(
        {
            'train_num': args.train_num,
            'epoch_num': args.epoch_num,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate,
            'embed_dim': args.embed_dim,
            'num_sample': args.num_sample,
            'graph_state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, checkpoint_path)

    # Plot the loss curve (x axis in iterations) and the per-epoch validation
    # accuracy on the visdom server.
    vis = visdom.Visdom(env='GraphAttention', port='8099')
    vis.line(X=np.arange(1,
                         len(check_loss) + 1, 1) * check_step,
             Y=np.array(check_loss),
             opts=dict(title=time.strftime('%Y-%m-%d %H:%M:%S'),
                       xlabel='itr.',
                       ylabel='loss'))
    vis.line(X=np.arange(1,
                         len(val_accuracy) + 1, 1),
             Y=np.array(val_accuracy),
             opts=dict(title=time.strftime('%Y-%m-%d %H:%M:%S'),
                       xlabel='epoch',
                       ylabel='accuracy'))

    return checkpoint_path, class_num