Example #1
import math
import random
import time

import numpy as np
import torch
import torch.nn.functional as F
import visdom
from torch.optim.lr_scheduler import StepLR

# GAT (the model), removeIsolated, and collectGraph_train (data loading) are
# project-specific and assumed importable from the surrounding repository.


def train(args):
    ## load training data
    print "loading training data ......"
    node_num, class_num = removeIsolated(args.suffix)
    label, feature_map, adj_lists = collectGraph_train(node_num, class_num,
                                                       args.feat_dim,
                                                       args.num_sample,
                                                       args.suffix)
    label = torch.LongTensor(label)
    feature_map = torch.FloatTensor(feature_map)

    model = GAT(args.feat_dim, args.embed_dim, class_num, args.alpha,
                args.dropout, args.nheads, args.use_cuda)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.step_size,
                       gamma=args.learning_rate_decay)

    ## train
    np.random.seed(2)  # fix the random train/val split across runs
    random.seed(2)
    rand_indices = np.random.permutation(node_num)
    train_nodes = rand_indices[:args.train_num]
    val_nodes = rand_indices[args.train_num:]

    if args.use_cuda:
        model.cuda()
        label = label.cuda()
        feature_map = feature_map.cuda()

    epoch_num = args.epoch_num
    batch_size = args.batch_size
    iter_num = int(math.ceil(args.train_num / float(batch_size)))
    check_loss = []
    val_accuracy = []
    check_step = args.check_step
    train_loss = 0.0
    iter_cnt = 0
    for e in range(epoch_num):
        model.train()
        random.shuffle(train_nodes)
        for batch in range(iter_num):
            batch_nodes = train_nodes[batch * batch_size:(batch + 1) *
                                      batch_size]
            batch_label = label[batch_nodes].squeeze()
            batch_neighbors = [adj_lists[node] for node in batch_nodes]
            _, logit = model(feature_map, batch_nodes, batch_neighbors)
            loss = F.nll_loss(logit, batch_label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            iter_cnt += 1
            train_loss += loss.cpu().item()
            if iter_cnt % check_step == 0:
                check_loss.append(train_loss / check_step)
                print(time.strftime('%Y-%m-%d %H:%M:%S'),
                      "epoch: {}, iter: {}, loss: {:.4f}".format(
                          e, iter_cnt, train_loss / check_step))
                train_loss = 0.0

        # decay the learning rate once per epoch, after the optimizer updates
        scheduler.step()

        ## validation
        model.eval()

        group = int(math.ceil(len(val_nodes) / float(batch_size)))
        val_cnt = 0
        for batch in range(group):
            batch_nodes = val_nodes[batch * batch_size:(batch + 1) *
                                    batch_size]
            batch_label = label[batch_nodes].squeeze()
            batch_neighbors = [adj_lists[node] for node in batch_nodes]
            _, logit = model(feature_map, batch_nodes, batch_neighbors)
            batch_predict = np.argmax(logit.cpu().detach().numpy(), axis=1)
            val_cnt += np.sum(batch_predict == batch_label.cpu().numpy())
        val_accuracy.append(val_cnt / float(len(val_nodes)))
        print(time.strftime('%Y-%m-%d %H:%M:%S'),
              "Epoch: {}, Validation Accuracy: {:.4f}".format(
                  e, val_cnt / float(len(val_nodes))))
        print("******" * 10)

    checkpoint_path = 'checkpoint/checkpoint_{}.pth'.format(
        time.strftime('%Y%m%d%H%M'))
    torch.save(
        {
            'train_num': args.train_num,
            'epoch_num': args.epoch_num,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate,
            'embed_dim': args.embed_dim,
            'num_sample': args.num_sample,
            'graph_state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, checkpoint_path)

    vis = visdom.Visdom(env='GraphAttention', port=8099)
    vis.line(X=np.arange(1, len(check_loss) + 1) * check_step,
             Y=np.array(check_loss),
             opts=dict(title=time.strftime('%Y-%m-%d %H:%M:%S'),
                       xlabel='itr.',
                       ylabel='loss'))
    vis.line(X=np.arange(1, len(val_accuracy) + 1),
             Y=np.array(val_accuracy),
             opts=dict(title=time.strftime('%Y-%m-%d %H:%M:%S'),
                       xlabel='epoch',
                       ylabel='accuracy'))

    return checkpoint_path, class_num
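
A minimal driver for this example could look like the sketch below. The flag names mirror the args.* attributes that train() reads, but every default value is an assumption chosen for illustration, not a value taken from the original project.

# Hypothetical argparse driver; all defaults are illustrative assumptions.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--feat_dim', type=int, default=512)        # input feature size (assumed)
parser.add_argument('--embed_dim', type=int, default=256)       # GAT embedding size (assumed)
parser.add_argument('--nheads', type=int, default=8)            # attention heads (assumed)
parser.add_argument('--alpha', type=float, default=0.2)         # LeakyReLU slope (assumed)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--num_sample', type=int, default=10)       # neighbors sampled per node (assumed)
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--step_size', type=int, default=10)        # epochs between LR decays (assumed)
parser.add_argument('--learning_rate_decay', type=float, default=0.5)
parser.add_argument('--train_num', type=int, default=8000)      # nodes used for training (assumed)
parser.add_argument('--epoch_num', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--check_step', type=int, default=50)       # iterations between loss logs (assumed)
parser.add_argument('--suffix', type=str, default='')           # dataset suffix expected by the loaders (assumed)
parser.add_argument('--use_cuda', action='store_true')
args = parser.parse_args()

checkpoint_path, class_num = train(args)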
Example #2
import math
import time

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# GAT, collectGraph_test, and the evaluation globals (building, test_dataset,
# retrieval_result) are project-specific and assumed to be defined elsewhere
# in the repository.


def test(checkpoint_path, class_num, args):

    model = GAT(args.feat_dim, args.embed_dim, class_num, args.alpha,
                args.dropout, args.nheads, args.use_cuda)

    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['graph_state_dict'])
    if args.use_cuda:
        model.cuda()
    model.eval()

    for key in building.keys():
        node_num = test_dataset[key]['node_num']
        old_feature_map, adj_lists = collectGraph_test(
            test_dataset[key]['feature_path'], node_num, args.feat_dim,
            args.num_sample, args.suffix)
        old_feature_map = torch.FloatTensor(old_feature_map)
        if args.use_cuda:
            old_feature_map = old_feature_map.cuda()

        batch_num = int(math.ceil(node_num / float(args.batch_size)))
        new_feature_map = torch.FloatTensor()
        for batch in tqdm(range(batch_num)):
            start_node = batch * args.batch_size
            end_node = min((batch + 1) * args.batch_size, node_num)
            batch_nodes = list(range(start_node, end_node))
            batch_neighbors = [adj_lists[node] for node in batch_nodes]
            new_feature, _ = model(old_feature_map, batch_nodes,
                                   batch_neighbors)
            new_feature = F.normalize(new_feature, p=2, dim=1)
            new_feature_map = torch.cat(
                (new_feature_map, new_feature.cpu().detach()), dim=0)
        new_feature_map = new_feature_map.numpy()
        old_np = old_feature_map.cpu().numpy()
        old_similarity = np.dot(old_np, old_np.T)
        new_similarity = np.dot(new_feature_map, new_feature_map.T)
        mAP_old = building[key].evalRetrieval(old_similarity, retrieval_result)
        mAP_new = building[key].evalRetrieval(new_similarity, retrieval_result)
        print(time.strftime('%Y-%m-%d %H:%M:%S'), 'eval {}'.format(key))
        print('base feature: {}, new feature: {}'.format(
            old_feature_map.size(), new_feature_map.shape))
        print('base mAP: {:.4f}, new mAP: {:.4f}, improve: {:.4f}'.format(
            mAP_old, mAP_new, mAP_new - mAP_old))

        ## baseline: directly update each node's feature by mean-pooling the features of its neighbors
        meanAggregator = model.attentions[0]
        mean_feature_map = torch.FloatTensor()
        for batch in tqdm(range(batch_num)):
            start_node = batch * args.batch_size
            end_node = min((batch + 1) * args.batch_size, node_num)
            batch_nodes = list(range(start_node, end_node))
            batch_neighbors = [adj_lists[node] for node in batch_nodes]
            mean_feature = meanAggregator.meanAggregate(
                old_feature_map, batch_nodes, batch_neighbors)
            mean_feature = F.normalize(mean_feature, p=2, dim=1)
            mean_feature_map = torch.cat(
                (mean_feature_map, mean_feature.cpu().detach()), dim=0)
        mean_feature_map = mean_feature_map.numpy()
        mean_similarity = np.dot(mean_feature_map, mean_feature_map.T)
        mAP_mean = building[key].evalRetrieval(mean_similarity,
                                               retrieval_result)
        print('mean aggregation mAP: {:.4f}'.format(mAP_mean))
        print("")