import copy
import multiprocessing as mp

import numpy as np
import torch.optim as optim


def sgcn_plus_v2(feat_data, labels, lap_matrix, train_nodes, valid_nodes,
                 test_nodes, args, device, calculate_grad_vars=False):
    # Sample training data with a multiprocess pool. `sampler`,
    # `samp_num_list`, and `num_classes` are module-level globals.
    process_ids = np.arange(args.batch_num)
    pool = mp.Pool(args.pool_num)
    lap_matrix_sq = lap_matrix.multiply(lap_matrix)
    jobs = prepare_data(pool, sampler, process_ids, train_nodes,
                        samp_num_list, len(feat_data), lap_matrix,
                        lap_matrix_sq, args.n_layers)

    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid,
                 num_classes=num_classes, layers=args.n_layers,
                 dropout=args.dropout).to(device)
    print(susage)

    adjs_full, input_nodes_full, sampled_nodes_full = full_batch_sampler(
        train_nodes, len(feat_data), lap_matrix, args.n_layers)
    adjs_full = package_mxl(adjs_full, device)

    # This wrapper is only used for SGCN++.
    forward_wrapper = ForwardWrapperMomentum(len(feat_data), args.nhid,
                                             args.n_layers, num_classes)

    # Note: the learning rate is hard-coded here, unlike sgcn below,
    # which uses args.lr.
    optimizer = optim.SGD(
        filter(lambda p: p.requires_grad, susage.parameters()), lr=0.1)

    loss_train = []
    loss_test = []
    grad_variance_all = []
    loss_train_all = []

    best_model = copy.deepcopy(susage)
    best_val_loss = 10  # any sufficiently large initial value works
    best_val_index = 0
    best_val_cnt = 0

    for epoch in np.arange(args.epoch_num):
        # Fetch this epoch's training data.
        train_data = [job.get() for job in jobs]
        pool.close()
        pool.join()
        # Prepare the next epoch's training data.
        pool = mp.Pool(args.pool_num)
        jobs = prepare_data(pool, sampler, process_ids, train_nodes,
                            samp_num_list, len(feat_data), lap_matrix,
                            lap_matrix_sq, args.n_layers)

        inner_loop_num = args.batch_num

        # Compared with sgcn_plus, the only difference is that
        # multi_level_momentum_step is used here.
        cur_train_loss, cur_train_loss_all, grad_variance = multi_level_momentum_step(
            susage, optimizer, feat_data, labels, train_nodes, valid_nodes,
            adjs_full, sampled_nodes_full, train_data, inner_loop_num,
            forward_wrapper, device, calculate_grad_vars=calculate_grad_vars)
        loss_train_all.extend(cur_train_loss_all)
        grad_variance_all.extend(grad_variance)

        # Calculate validation loss.
        susage.eval()
        susage.zero_grad()
        val_loss, _ = susage.calculate_loss_grad(
            feat_data, adjs_full, labels, valid_nodes)

        # Keep a copy of the best model; reset the patience counter whenever
        # validation loss improves by more than the margin.
        if val_loss + 0.01 < best_val_loss:
            best_val_loss = val_loss
            del best_model
            best_model = copy.deepcopy(susage)
            best_val_index = epoch
            best_val_cnt = 0
        cur_test_loss = val_loss
        best_val_cnt += 1

        loss_train.append(cur_train_loss)
        loss_test.append(cur_test_loss)

        # Print progress.
        print('Epoch: ', epoch,
              '| train loss: %.8f' % cur_train_loss,
              '| test loss: %.8f' % cur_test_loss)
        if best_val_cnt > 10:
            break

    f1_score_test = best_model.calculate_f1(
        feat_data, adjs_full, labels, test_nodes)

    return (best_model, loss_train[:best_val_index], loss_test[:best_val_index],
            loss_train_all, f1_score_test, grad_variance_all)
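# Both training drivers in this file use the same early-stopping pattern: keep
# a deep copy of the best-so-far model and stop once validation loss has not
# improved by a fixed margin for more than ten epochs. A minimal,
# self-contained sketch of that pattern (the names `should_stop`, `state`,
# `min_delta`, and `patience` are illustrative, not part of this codebase):
def should_stop(val_loss, state, min_delta=5e-4, patience=10):
    """Return True once val_loss stops improving by min_delta for `patience` epochs."""
    if val_loss + min_delta < state['best']:
        state['best'] = val_loss  # new best: record it and reset the counter
        state['count'] = 0
    state['count'] += 1  # incremented every epoch, as in sgcn_plus_v2/sgcn
    return state['count'] > patience
# Example: state = {'best': float('inf'), 'count': 0}; call once per epoch.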
def sgcn(feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes,
         args, device, calculate_grad_vars=False, full_batch=False):
    # Sample training data with a multiprocess pool.
    process_ids = np.arange(args.batch_num)
    pool = mp.Pool(args.pool_num)
    lap_matrix_sq = lap_matrix.multiply(lap_matrix)
    jobs = prepare_data(pool, sampler, process_ids, train_nodes,
                        samp_num_list, len(feat_data), lap_matrix,
                        lap_matrix_sq, args.n_layers)

    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid,
                 num_classes=num_classes, layers=args.n_layers,
                 dropout=args.dropout).to(device)
    print(susage)

    adjs_full, input_nodes_full, sampled_nodes_full = full_batch_sampler(
        train_nodes, len(feat_data), lap_matrix, args.n_layers)
    adjs_full = package_mxl(adjs_full, device)

    optimizer = optim.SGD(
        filter(lambda p: p.requires_grad, susage.parameters()), lr=args.lr)

    loss_train = []
    loss_test = []
    grad_variance_all = []
    loss_train_all = []

    best_model = copy.deepcopy(susage)
    best_val_loss = 10  # any sufficiently large initial value works
    best_val_index = 0
    best_val_cnt = 0

    for epoch in np.arange(args.epoch_num):
        # Fetch this epoch's training data.
        train_data = [job.get() for job in jobs]
        pool.close()
        pool.join()
        # Prepare the next epoch's training data.
        pool = mp.Pool(args.pool_num)
        jobs = prepare_data(pool, sampler, process_ids, train_nodes,
                            samp_num_list, len(feat_data), lap_matrix,
                            lap_matrix_sq, args.n_layers)

        inner_loop_num = args.batch_num

        # Full-batch gradient descent can also be run by skipping the
        # sampled mini-batches entirely.
        if full_batch:
            cur_train_loss, cur_train_loss_all, grad_variance = full_step(
                susage, optimizer, feat_data, labels, train_nodes, valid_nodes,
                adjs_full, train_data, inner_loop_num, device,
                calculate_grad_vars=calculate_grad_vars)
        else:
            cur_train_loss, cur_train_loss_all, grad_variance = sgd_step(
                susage, optimizer, feat_data, labels, train_nodes, valid_nodes,
                adjs_full, train_data, inner_loop_num, device,
                calculate_grad_vars=calculate_grad_vars)
        loss_train_all.extend(cur_train_loss_all)
        grad_variance_all.extend(grad_variance)

        # Calculate validation loss.
        susage.eval()
        susage.zero_grad()
        val_loss, _ = susage.calculate_loss_grad(
            feat_data, adjs_full, labels, valid_nodes)

        # Keep a copy of the best model; reset the patience counter whenever
        # validation loss improves by more than the margin.
        if val_loss + 5e-4 < best_val_loss:
            best_val_loss = val_loss
            del best_model
            best_model = copy.deepcopy(susage)
            best_val_index = epoch
            best_val_cnt = 0
        cur_test_loss = val_loss
        best_val_cnt += 1

        loss_train.append(cur_train_loss)
        loss_test.append(cur_test_loss)

        # Print progress.
        print('Epoch: ', epoch,
              '| train loss: %.8f' % cur_train_loss,
              '| test loss: %.8f' % cur_test_loss)
        if best_val_cnt > 10:
            break

    f1_score_test = best_model.calculate_f1(
        feat_data, adjs_full, labels, test_nodes)

    return (best_model, loss_train, loss_test, loss_train_all,
            f1_score_test, grad_variance_all, best_val_index)
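# The epoch loops above overlap data preparation with training: jobs for the
# next epoch are submitted to a fresh pool as soon as the current epoch's
# mini-batches have been fetched. A self-contained sketch of this
# double-buffering pattern, with a stand-in `sample_batch` in place of the
# real `prepare_data`/`sampler` pipeline (all names below are illustrative):
def sample_batch(seed):
    # Stand-in for the real sampler: pretend to build one mini-batch.
    return list(range(seed, seed + 3))


if __name__ == '__main__':
    pool = mp.Pool(2)
    jobs = [pool.apply_async(sample_batch, (i,)) for i in range(4)]
    for epoch in range(3):
        # Consume the batches prepared during the previous iteration.
        train_data = [job.get() for job in jobs]
        pool.close()
        pool.join()
        # Immediately queue the next epoch's sampling on a fresh pool, so it
        # runs in the background while this epoch would be training.
        pool = mp.Pool(2)
        jobs = [pool.apply_async(sample_batch, (i,)) for i in range(4)]
        print('epoch', epoch, '->', len(train_data), 'batches ready')
    pool.close()
    pool.join()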