Example #1
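# Imports assumed by these snippets. The remaining names (GCN, prepare_data,
# full_batch_sampler, package_mxl, ForwardWrapperMomentum,
# multi_level_momentum_step, and the globals sampler, samp_num_list,
# num_classes) come from the surrounding module and are not shown here.
import copy
import multiprocessing as mp
import numpy as np
import torch.optim as optim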
def sgcn_plus_v2(feat_data,
                 labels,
                 lap_matrix,
                 train_nodes,
                 valid_nodes,
                 test_nodes,
                 args,
                 device,
                 calculate_grad_vars=False):

    # sample training subgraphs asynchronously with a multiprocessing pool
    process_ids = np.arange(args.batch_num)
    pool = mp.Pool(args.pool_num)
    # element-wise square of the Laplacian, consumed by the sampler
    lap_matrix_sq = lap_matrix.multiply(lap_matrix)
    jobs = prepare_data(pool, sampler, process_ids, train_nodes, samp_num_list,
                        len(feat_data), lap_matrix, lap_matrix_sq,
                        args.n_layers)

    susage = GCN(nfeat=feat_data.shape[1],
                 nhid=args.nhid,
                 num_classes=num_classes,
                 layers=args.n_layers,
                 dropout=args.dropout).to(device)

    print(susage)

    adjs_full, input_nodes_full, sampled_nodes_full = full_batch_sampler(
        train_nodes, len(feat_data), lap_matrix, args.n_layers)
    adjs_full = package_mxl(adjs_full, device)

    # forward wrapper that keeps historical activations for the momentum
    # estimator; it is only needed for SGCN++
    forward_wrapper = ForwardWrapperMomentum(len(feat_data), args.nhid,
                                             args.n_layers, num_classes)

    # NOTE: the learning rate is hard-coded here, unlike sgcn below (args.lr)
    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 susage.parameters()),
                          lr=0.1)

    loss_train = []
    loss_test = []
    grad_variance_all = []
    loss_train_all = []

    best_model = copy.deepcopy(susage)
    best_val_loss = 10  # any sufficiently large initial value works
    best_val_index = 0
    best_val_cnt = 0

    for epoch in np.arange(args.epoch_num):
        # collect the subgraphs sampled for this epoch
        train_data = [job.get() for job in jobs]
        pool.close()
        pool.join()
        # immediately schedule sampling for the next epoch so it overlaps
        # with training
        pool = mp.Pool(args.pool_num)
        jobs = prepare_data(pool, sampler,
                            process_ids, train_nodes, samp_num_list,
                            len(feat_data), lap_matrix, lap_matrix_sq,
                            args.n_layers)

        inner_loop_num = args.batch_num
        # compared with sgcn_plus, the only difference is that we use
        # multi_level_momentum_step here
        cur_train_loss, cur_train_loss_all, grad_variance = multi_level_momentum_step(
            susage,
            optimizer,
            feat_data,
            labels,
            train_nodes,
            valid_nodes,
            adjs_full,
            sampled_nodes_full,
            train_data,
            inner_loop_num,
            forward_wrapper,
            device,
            calculate_grad_vars=calculate_grad_vars)
        loss_train_all.extend(cur_train_loss_all)
        grad_variance_all.extend(grad_variance)
        # compute the validation loss
        susage.eval()

        susage.zero_grad()
        val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels,
                                                 valid_nodes)

        # early stopping: keep the model whose validation loss improves on
        # the best seen so far by more than the 0.01 tolerance
        if val_loss + 0.01 < best_val_loss:
            best_val_loss = val_loss
            del best_model
            best_model = copy.deepcopy(susage)
            best_val_index = epoch
            best_val_cnt = 0

        # despite the name, this is the validation loss
        cur_test_loss = val_loss
        best_val_cnt += 1

        loss_train.append(cur_train_loss)
        loss_test.append(cur_test_loss)

        # print progress
        print('Epoch: ', epoch, '| train loss: %.8f' % cur_train_loss,
              '| val loss: %.8f' % cur_test_loss)

        # stop after 10 consecutive epochs without validation improvement
        if best_val_cnt > 10:
            break

    # drain the sampling pool; the subgraphs prepared for the next epoch are
    # discarded
    pool.close()
    pool.join()

    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels,
                                            test_nodes)
    # truncate the recorded curves at the best-validation epoch
    return (best_model, loss_train[:best_val_index],
            loss_test[:best_val_index], loss_train_all, f1_score_test,
            grad_variance_all)
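For context, a minimal sketch of how sgcn_plus_v2 might be driven. Everything in it is an assumption for illustration: the dataset tensors (feat_data, labels, lap_matrix, and the node splits) and the module-level sampler state are prepared elsewhere in the repo, and the hyperparameter values are placeholders.

# Hypothetical driver for sgcn_plus_v2; all values are illustrative.
import argparse
import torch

args = argparse.Namespace(nhid=256, n_layers=2, dropout=0.1,
                          batch_num=20, pool_num=4, epoch_num=100)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# feat_data, labels, lap_matrix and the node splits are assumed to be
# loaded already, e.g. by the repo's data-loading utilities.
(best_model, loss_train, loss_test,
 loss_train_all, f1_score_test, grad_variance_all) = sgcn_plus_v2(
    feat_data, labels, lap_matrix,
    train_nodes, valid_nodes, test_nodes,
    args, device, calculate_grad_vars=False)
print('test F1:', f1_score_test)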
Example #2
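# Same assumed module-level context as Example #1, with sgd_step and
# full_step in place of the momentum helpers.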
def sgcn(feat_data,
         labels,
         lap_matrix,
         train_nodes,
         valid_nodes,
         test_nodes,
         args,
         device,
         calculate_grad_vars=False,
         full_batch=False):

    # sample training subgraphs asynchronously with a multiprocessing pool
    process_ids = np.arange(args.batch_num)
    pool = mp.Pool(args.pool_num)
    # element-wise square of the Laplacian, consumed by the sampler
    lap_matrix_sq = lap_matrix.multiply(lap_matrix)
    jobs = prepare_data(pool, sampler, process_ids, train_nodes, samp_num_list,
                        len(feat_data), lap_matrix, lap_matrix_sq,
                        args.n_layers)

    susage = GCN(nfeat=feat_data.shape[1],
                 nhid=args.nhid,
                 num_classes=num_classes,
                 layers=args.n_layers,
                 dropout=args.dropout).to(device)

    print(susage)

    adjs_full, input_nodes_full, sampled_nodes_full = full_batch_sampler(
        train_nodes, len(feat_data), lap_matrix, args.n_layers)
    adjs_full = package_mxl(adjs_full, device)

    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 susage.parameters()),
                          lr=args.lr)

    loss_train = []
    loss_test = []
    grad_variance_all = []
    loss_train_all = []

    best_model = copy.deepcopy(susage)
    best_val_loss = 10  # any sufficiently large initial value works
    best_val_index = 0
    best_val_cnt = 0

    for epoch in np.arange(args.epoch_num):
        # collect the subgraphs sampled for this epoch
        train_data = [job.get() for job in jobs]
        pool.close()
        pool.join()
        # immediately schedule sampling for the next epoch so it overlaps
        # with training
        pool = mp.Pool(args.pool_num)
        jobs = prepare_data(pool, sampler,
                            process_ids, train_nodes, samp_num_list,
                            len(feat_data), lap_matrix, lap_matrix_sq,
                            args.n_layers)

        # full_batch=True runs full-batch gradient descent instead, ignoring
        # the sampled subgraphs
        inner_loop_num = args.batch_num
        if full_batch:
            cur_train_loss, cur_train_loss_all, grad_variance = full_step(
                susage,
                optimizer,
                feat_data,
                labels,
                train_nodes,
                valid_nodes,
                adjs_full,
                train_data,
                inner_loop_num,
                device,
                calculate_grad_vars=calculate_grad_vars)
        else:
            cur_train_loss, cur_train_loss_all, grad_variance = sgd_step(
                susage,
                optimizer,
                feat_data,
                labels,
                train_nodes,
                valid_nodes,
                adjs_full,
                train_data,
                inner_loop_num,
                device,
                calculate_grad_vars=calculate_grad_vars)
        loss_train_all.extend(cur_train_loss_all)
        grad_variance_all.extend(grad_variance)
        # compute the validation loss
        susage.eval()

        susage.zero_grad()
        val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels,
                                                 valid_nodes)

        # early stopping: keep the model whose validation loss improves on
        # the best seen so far by more than the 5e-4 tolerance
        if val_loss + 5e-4 < best_val_loss:
            best_val_loss = val_loss
            del best_model
            best_model = copy.deepcopy(susage)
            best_val_index = epoch
            best_val_cnt = 0

        # despite the name, this is the validation loss
        cur_test_loss = val_loss
        best_val_cnt += 1

        loss_train.append(cur_train_loss)
        loss_test.append(cur_test_loss)

        # print progress
        print('Epoch: ', epoch, '| train loss: %.8f' % cur_train_loss,
              '| val loss: %.8f' % cur_test_loss)

        # stop after 10 consecutive epochs without validation improvement
        if best_val_cnt > 10:
            break
    # drain the sampling pool; the subgraphs prepared for the next epoch are
    # discarded
    pool.close()
    pool.join()

    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels,
                                            test_nodes)
    return (best_model, loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, best_val_index)
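The plain sgcn baseline takes the same inputs; the full_batch flag selects between mini-batch SGD (sgd_step) and full-batch gradient descent (full_step), and the learning rate comes from args.lr rather than being hard-coded. A minimal sketch under the same assumptions as the driver in Example #1:

# Hypothetical calls; args and the data tensors as in the sketch above.
args.lr = 0.1  # sgcn reads the learning rate from args

# mini-batch SGD baseline
results_sgd = sgcn(feat_data, labels, lap_matrix, train_nodes, valid_nodes,
                   test_nodes, args, device, full_batch=False)

# full-batch gradient descent; the sampled subgraphs are ignored
results_full = sgcn(feat_data, labels, lap_matrix, train_nodes, valid_nodes,
                    test_nodes, args, device, full_batch=True)

# note sgcn also returns best_val_index as the last element
best_model = results_sgd[0]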