def _validation_accuracy(model, feature_map, label, adj_lists, nodes,
                         batch_size):
    """Return the fraction of ``nodes`` whose argmax prediction equals ``label``.

    The model is put in eval mode and run under ``torch.no_grad()`` so no
    autograd graph is built during validation. Returns ``float('nan')`` when
    ``nodes`` is empty (no validation split) instead of dividing by zero.
    """
    if len(nodes) == 0:
        return float('nan')
    model.eval()
    correct = 0
    with torch.no_grad():
        for start in range(0, len(nodes), batch_size):
            batch_nodes = nodes[start:start + batch_size]
            # view(-1) rather than squeeze(): squeeze() collapses a 1-node batch
            # to a 0-d tensor.
            batch_label = label[batch_nodes].view(-1)
            batch_neighbors = [adj_lists[node] for node in batch_nodes]
            _, logit = model(feature_map, batch_nodes, batch_neighbors)
            batch_predict = np.argmax(logit.cpu().numpy(), axis=1)
            correct += np.sum(batch_predict == batch_label.cpu().numpy())
    return correct / float(len(nodes))


def train(args):
    """Train a GAT node classifier, save a checkpoint and plot training curves.

    Loads the training graph via ``removeIsolated``/``collectGraph_train``,
    splits the nodes (seeded permutation) into the first ``args.train_num``
    training nodes and the remaining validation nodes, and trains with Adam
    plus a ``StepLR`` schedule. Averaged training loss is logged every
    ``args.check_step`` iterations, and validation accuracy is logged once per
    epoch. Both curves are sent to a visdom server on port 8099.

    Returns:
        (checkpoint_path, class_num): path of the saved checkpoint file and the
        number of classes reported by ``removeIsolated``.
    """
    import os  # local import: only used to ensure the checkpoint dir exists

    ## load training data
    print("loading training data ......")
    node_num, class_num = removeIsolated(args.suffix)
    label, feature_map, adj_lists = collectGraph_train(
        node_num, class_num, args.feat_dim, args.num_sample, args.suffix)
    label = torch.LongTensor(label)
    feature_map = torch.FloatTensor(feature_map)

    model = GAT(args.feat_dim, args.embed_dim, class_num, args.alpha,
                args.dropout, args.nheads, args.use_cuda)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                 weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.step_size,
                       gamma=args.learning_rate_decay)

    ## train
    # Fixed seeds make the train/validation split and the shuffling reproducible.
    np.random.seed(2)
    random.seed(2)
    rand_indices = np.random.permutation(node_num)
    train_nodes = rand_indices[:args.train_num]
    val_nodes = rand_indices[args.train_num:]

    if args.use_cuda:
        model.cuda()
        label = label.cuda()
        feature_map = feature_map.cuda()

    batch_size = args.batch_size
    check_step = args.check_step
    # Use the real split size: args.train_num may exceed node_num, and then
    # ceil(train_num / batch_size) would produce empty trailing batches.
    iter_num = int(math.ceil(len(train_nodes) / float(batch_size)))
    check_loss = []
    val_accuracy = []
    train_loss = 0.0
    iter_cnt = 0

    for e in range(args.epoch_num):
        model.train()
        # NOTE(review): stepping the scheduler at the start of the epoch follows
        # the PyTorch < 1.1 convention. On PyTorch >= 1.1 it should come after
        # optimizer.step() (otherwise the LR decays one epoch early). Left as-is
        # to preserve the existing schedule -- confirm the target torch version.
        scheduler.step()
        random.shuffle(train_nodes)
        for batch in range(iter_num):
            batch_nodes = train_nodes[batch * batch_size:(batch + 1) * batch_size]
            # view(-1) instead of squeeze(): squeeze() turns a 1-node batch into
            # a 0-d target, which F.nll_loss rejects.
            batch_label = label[batch_nodes].view(-1)
            batch_neighbors = [adj_lists[node] for node in batch_nodes]
            _, logit = model(feature_map, batch_nodes, batch_neighbors)
            loss = F.nll_loss(logit, batch_label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_cnt += 1
            train_loss += loss.item()
            if iter_cnt % check_step == 0:
                avg_loss = train_loss / check_step
                check_loss.append(avg_loss)
                print('{} epoch: {}, iter: {}, loss:{:.4f}'.format(
                    time.strftime('%Y-%m-%d %H:%M:%S'), e, iter_cnt, avg_loss))
                train_loss = 0.0

        ## validation
        accuracy = _validation_accuracy(model, feature_map, label, adj_lists,
                                        val_nodes, batch_size)
        val_accuracy.append(accuracy)
        print('{} Epoch: {}, Validation Accuracy: {:.4f}'.format(
            time.strftime('%Y-%m-%d %H:%M:%S'), e, accuracy))
        print("******" * 10)

    checkpoint_path = 'checkpoint/checkpoint_{}.pth'.format(
        time.strftime('%Y%m%d%H%M'))
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)  # torch.save does not create parent dirs
    torch.save(
        {
            'train_num': args.train_num,
            'epoch_num': args.epoch_num,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate,
            'embed_dim': args.embed_dim,
            'num_sample': args.num_sample,
            'graph_state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, checkpoint_path)

    vis = visdom.Visdom(env='GraphAttention', port='8099')
    vis.line(X=np.arange(1, len(check_loss) + 1, 1) * check_step,
             Y=np.array(check_loss),
             opts=dict(title=time.strftime('%Y-%m-%d %H:%M:%S'),
                       xlabel='itr.', ylabel='loss'))
    vis.line(X=np.arange(1, len(val_accuracy) + 1, 1),
             Y=np.array(val_accuracy),
             opts=dict(title=time.strftime('%Y-%m-%d %H:%M:%S'),
                       xlabel='epoch', ylabel='accuracy'))

    return checkpoint_path, class_num
def _embed_all_nodes(embed_fn, feature_map, adj_lists, node_num, batch_size):
    """Apply ``embed_fn`` to nodes ``0..node_num-1`` in batches.

    ``embed_fn(feature_map, batch_nodes, batch_neighbors)`` is expected to
    return one feature row per node in ``batch_nodes``. Each row is
    L2-normalised, and the rows are stacked in node order into a numpy array.
    """
    chunks = []
    batch_num = int(math.ceil(node_num / float(batch_size)))
    for batch in tqdm(range(batch_num)):
        start_node = batch * batch_size
        end_node = min((batch + 1) * batch_size, node_num)
        # Explicit list so the batch is a list on Python 2 and 3 alike.
        batch_nodes = list(range(start_node, end_node))
        batch_neighbors = [adj_lists[node] for node in batch_nodes]
        feature = embed_fn(feature_map, batch_nodes, batch_neighbors)
        feature = F.normalize(feature, p=2, dim=1)
        chunks.append(feature.cpu().detach())
    if not chunks:
        return torch.FloatTensor().numpy()
    # Concatenate once at the end instead of re-copying a growing tensor per batch.
    return torch.cat(chunks, dim=0).numpy()


def test(checkpoint_path, class_num, args):
    """Compare retrieval mAP of base, GAT-refined and mean-aggregated features.

    For each key of the module-level ``building`` dict, the function loads the
    test graph described by ``test_dataset[key]`` and then:

    1. Embeds every node with the trained GAT and scores retrieval with
       dot-product similarity against the raw (base) features.
    2. Repeats the scoring using ``model.attentions[0].meanAggregate``, i.e.
       each node's features pooled from its neighbours.

    All mAP values are printed. Nothing is returned.
    """
    model = GAT(args.feat_dim, args.embed_dim, class_num, args.alpha,
                args.dropout, args.nheads, args.use_cuda)
    # Load tensors onto the CPU first so a checkpoint saved from a GPU run also
    # loads when args.use_cuda is False. The model is moved to the GPU below if
    # requested.
    checkpoint = torch.load(checkpoint_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['graph_state_dict'])
    if args.use_cuda:
        model.cuda()
    model.eval()
    mean_aggregator = model.attentions[0]

    def gat_embed(feature_map, batch_nodes, batch_neighbors):
        # The model returns (embedding, logit); only the embedding is needed.
        new_feature, _ = model(feature_map, batch_nodes, batch_neighbors)
        return new_feature

    for key in building.keys():
        node_num = test_dataset[key]['node_num']
        old_feature_map, adj_lists = collectGraph_test(
            test_dataset[key]['feature_path'], node_num, args.feat_dim,
            args.num_sample, args.suffix)
        old_feature_map = torch.FloatTensor(old_feature_map)
        if args.use_cuda:
            old_feature_map = old_feature_map.cuda()

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            new_feature_map = _embed_all_nodes(
                gat_embed, old_feature_map, adj_lists, node_num,
                args.batch_size)

        old_feature_np = old_feature_map.cpu().numpy()
        old_similarity = np.dot(old_feature_np, old_feature_np.T)
        new_similarity = np.dot(new_feature_map, new_feature_map.T)
        mAP_old = building[key].evalRetrieval(old_similarity, retrieval_result)
        mAP_new = building[key].evalRetrieval(new_similarity, retrieval_result)
        print('{} eval {}'.format(time.strftime('%Y-%m-%d %H:%M:%S'), key))
        print('base feature: {}, new feature: {}'.format(
            old_feature_map.size(), new_feature_map.shape))
        print('base mAP: {:.4f}, new mAP: {:.4f}, improve: {:.4f}'.format(
            mAP_old, mAP_new, mAP_new - mAP_old))

        ## directly update node's features by mean pooling features of its neighbors.
        with torch.no_grad():
            mean_feature_map = _embed_all_nodes(
                mean_aggregator.meanAggregate, old_feature_map, adj_lists,
                node_num, args.batch_size)
        mean_similarity = np.dot(mean_feature_map, mean_feature_map.T)
        mAP_mean = building[key].evalRetrieval(mean_similarity, retrieval_result)
        print('mean aggregation mAP: {:.4f}'.format(mAP_mean))
        print("")