コード例 #1
0
def main(_args=None):
    """Train one model, then save edge sets for a sweep of thresholds.

    Parses the command line, trains ``args.model`` on split ``args.split_id``
    of ``args.dataset`` for ``EPOCHS_CONFIG[dataset]`` epochs, and feeds the
    model's softmax class probabilities to an ``Edges<Mode>`` modifier
    (looked up in module globals from ``args.mode``). For each threshold
    0.00, 0.01, ..., 1.00 the modifier produces an adjacency matrix whose
    non-zero entries are saved as an edge index with ``torch.save`` to
    ``<out_dir>/<dataset>/<mode>/<model>_<split_id>_<threshold>.edges``.

    Args:
        _args: Optional list of command-line arguments forwarded to
            ``parse_args``; ``None`` parses ``sys.argv``.
    """
    args = create_parser().parse_args(_args)
    dataset_key = args.dataset.lower()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = load_dataset(args.dataset)
    data = dataset[0]

    model_name = args.model.upper()
    # SGC is replaced by the USGC variant on the "cs" and "physics" datasets.
    if model_name == "SGC" and dataset_key in ["cs", "physics"]:
        model_name = "USGC"
        print(model_name, end=" ")
    model_cls = globals()[model_name]
    data = data.to(device)
    edges = data.edge_index.to(device)

    train_mask, val_mask, test_mask = load_split(
        f"{args.splits_dir}/{dataset_key}_{args.split_id}.mask")
    epochs = EPOCHS_CONFIG[dataset_key]

    print("Preparing Model...")
    torch.manual_seed(args.pre_seed)
    model = model_cls(dataset.num_features, dataset.num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=0.01,
                                 weight_decay=5e-4)

    for _ in range(epochs):
        train(model, optimizer, data, edges, train_mask)
    print("Done! ", test(model, data, edges, train_mask, val_mask, test_mask))
    sys.stdout.flush()

    # Inference only: eval() puts layers such as dropout into their
    # deterministic mode, and no_grad() keeps the probabilities from holding
    # the autograd graph alive for the whole threshold sweep below.
    model.eval()
    with torch.no_grad():
        y = F.softmax(model(data.x, edges), dim=1)

    modifier_cls = globals()["Edges" + args.mode.capitalize()]
    modifier = modifier_cls(y, edges, device, normalization=args.normalization)

    all_thresholds = [0.01 * x for x in range(101)]  # 0.00, 0.01, ..., 1.00

    dataset_dir = os.path.join(args.out_dir, dataset_key, args.mode.lower())
    # exist_ok=True avoids the race between an isdir() check and makedirs().
    os.makedirs(dataset_dir, exist_ok=True)

    # Each threshold's adjacency matrix is passed back into modify() for the
    # next threshold, starting from None.
    adj_m = None
    for threshold in all_thresholds:
        print(f"{args.dataset} {args.model} {args.split_id} {threshold:.2f}",
              end=" ")
        save_path = os.path.join(
            dataset_dir,
            f"{args.model.lower()}_{args.split_id}_{threshold:.2f}.edges")
        adj_m = modifier.modify(threshold, adj_m)
        new_edges = adj_m.nonzero().t().to(device)
        torch.save(new_edges, save_path)
        print("|", datetime.now(), new_edges.size(1))
コード例 #2
0
def test_loader(data_path='./data/nyu_depth_v2_labeled.mat', batch_size=16):
    """Show one training sample from the NYU Depth v2 loader.

    Builds the training ``DataLoader`` and takes its first batch. It prints
    the size of the RGB batch, then displays the first RGB image and its
    ground-truth depth map (divided by its maximum) with matplotlib.

    Args:
        data_path: Path to the labeled NYU Depth v2 ``.mat`` file.
        batch_size: Number of samples per batch.

    Raises:
        RuntimeError: If the loader yields no batch at all.
    """
    # 1. Load data
    train_lists, _val_lists, _test_lists = load_split()

    print("Loading data...")
    train_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, train_lists),
        batch_size=batch_size, shuffle=True, drop_last=True)
    # next(iter(...)) makes an empty loader fail with a clear error rather
    # than leaving the image variable unbound (or bound to a builtin).
    try:
        rgb, depth = next(iter(train_loader))
    except StopIteration:
        raise RuntimeError(
            "train_loader yielded no batches; with drop_last=True the "
            "dataset must hold at least batch_size samples") from None
    print(rgb.size())

    # permute(1, 2, 0) moves channels last for imshow; assumes each sample is
    # (C, H, W) -- TODO confirm against NyuDepthLoader.
    rgb_image = rgb[0].detach().permute(1, 2, 0).cpu().numpy()
    gt_depth_image = depth[0][0].detach().cpu().numpy().astype(np.float32)

    peak = np.max(gt_depth_image)
    if peak > 0:  # leave an all-zero depth map as is instead of dividing by 0
        gt_depth_image /= peak
    plt.imshow(rgb_image)
    plt.show()
    plt.imshow(gt_depth_image, cmap="viridis")
    plt.show()
コード例 #3
0
def main(_args=None):
    """Train models on pre-computed edge sets and record accuracy curves.

    For every split, threshold index and seed, this loads the edge file
    ``<edges_dir>/<dataset>/<mode>/<model>_<split>_<threshold>.edges``. It
    then trains a fresh model for ``EPOCHS_CONFIG[dataset]`` epochs and
    records, per epoch, the train accuracy, the best validation accuracy so
    far, and the test accuracy at that best-validation epoch.

    Results are kept in an array of shape
    ``(3, 101, num_splits, num_seeds, epochs)``. The array is pickled to the
    output path after every split; ``--resume k`` reloads it and continues
    from split ``k``.

    Args:
        _args: Optional list of command-line arguments forwarded to
            ``parse_args``; ``None`` parses ``sys.argv``.
    """
    args = create_parser().parse_args(_args)
    dataset_key = args.dataset.lower()

    epochs = EPOCHS_CONFIG[dataset_key]
    num_seeds = args.num_seeds
    num_splits = args.num_splits

    model_name = args.model.upper()
    if model_name == "SGC" and dataset_key in ["cs", "physics"]:
        model_name = "USGC"
        print(model_name, end=" ")
    if (model_name == "SGC" and args.mode == "adder"
            and dataset_key in ["computers"]):
        model_name = "USGC"
        print(model_name, end=" ")
    model_cls = globals()[model_name]

    # GAT on "physics" only runs thresholds 0.00-0.59. The dataset name is
    # compared lower-cased, like every other dataset check in this function.
    num_thresholds = 101
    if model_name == "GAT" and dataset_key == "physics":
        num_thresholds = 60

    seeds = list(range(num_seeds))
    splits = list(range(num_splits))

    dataset = load_dataset(args.dataset)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the graph once; every run below reuses it.
    data = dataset[0].to(device)

    if args.output is None:
        output = f"{args.out_dir}/{args.model.lower()}_{dataset_key}_{args.mode}_{num_splits}_{num_seeds}.np"
    elif args.output == "time":
        output = "{time}.np".format(
            time=datetime.now().strftime(r"%Y-%m-%d_%H:%M:%S"))
    else:
        output = args.output
    # Create the output directory up front so the per-split dump below cannot
    # fail only after a whole split has already been trained.
    output_dir = os.path.dirname(output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Axis 0: train acc / best val acc / test acc at best val. The threshold
    # axis always has 101 slots (0.00-1.00); slots at or beyond
    # num_thresholds stay zero.
    result = np.zeros((3, 101, num_splits, num_seeds, epochs))

    if args.resume is not None:
        result = load_data(output)
        splits = list(range(args.resume, num_splits))

    log = "Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}"
    edges_dir = os.path.join(args.edges_dir, dataset_key, args.mode)
    start_time = datetime.now()
    for split in splits:
        train_mask, val_mask, test_mask = load_split(
            f"{args.splits_dir}/{dataset_key}_{split}.mask")
        for threshold_i in range(num_thresholds):
            threshold = 0.01 * threshold_i
            # The edge set depends only on (split, threshold), so it is loaded
            # once for all seeds. map_location lets files saved from a GPU be
            # read on a CPU-only machine (and vice versa).
            edges = torch.load(
                os.path.join(
                    edges_dir,
                    f"{args.model.lower()}_{split}_{threshold:.2f}.edges",
                ),
                map_location=device)
            for seed in seeds:
                torch.manual_seed(seed)  # set seed here
                model = model_cls(dataset.num_features,
                                  dataset.num_classes).to(device)
                optimizer = torch.optim.Adam(model.parameters(),
                                             lr=0.01,
                                             weight_decay=5e-4)

                best_val_acc = test_acc = 0
                for epoch in range(epochs):
                    train(model, optimizer, data, edges, train_mask)
                    train_acc, val_acc, tmp_test_acc = test(
                        model, data, edges, train_mask, val_mask, test_mask)
                    # Report test accuracy at the best validation epoch.
                    if val_acc > best_val_acc:
                        best_val_acc = val_acc
                        test_acc = tmp_test_acc

                    result[0][threshold_i][split][seed][epoch] = train_acc
                    result[1][threshold_i][split][seed][epoch] = best_val_acc
                    result[2][threshold_i][split][seed][epoch] = test_acc

                    if args.verbose:
                        print(
                            f"{datetime.now()}",
                            log.format(epoch + 1, train_acc, best_val_acc,
                                       test_acc),
                        )

                    if epoch == epochs - 1:
                        now = datetime.now()
                        print(
                            f"{now} {args.model.lower()} split: {split} threshold: {threshold:.2f} seed: {seed}",
                            log.format(epoch + 1, train_acc, best_val_acc,
                                       test_acc),
                            f"Elapsed: {now - start_time}",
                        )
                sys.stdout.flush()
        # Checkpoint after each split so an interrupted run can be resumed.
        with open(output, "wb") as f:
            pickle.dump(result, f)

    print(result)
コード例 #4
0
ファイル: main.py プロジェクト: yang-han/P-reg
def run():
    """Train P-reg models over data splits and seeds and save the results.

    All hyper-parameters come from the command line. Each split is either
    the dataset's built-in masks (when ``num_splits == 1``) or the mask file
    ``splits/<dataset>_<k>.mask``. For each split and each seed a fresh model
    is trained with early stopping on validation accuracy.

    The values at the best validation epoch are stored in ``result``, with
    shape ``(4, num_seeds, num_splits)`` and rows: train acc, val acc,
    test acc, epoch. They are summarised on stdout and written with
    ``np.save`` to
    ``results/<model>[_stand]/<dataset>_<mu>_<lr>_<wd>_<patience>.npy``
    in this file's directory.

    Raises:
        NotImplementedError: If ``args.model`` is not one of the supported
            P-reg models.
    """
    # hyper-parameter parsing
    args = create_parser().parse_args()
    epochs = args.epochs
    num_seeds = args.num_seeds
    mu = args.mu
    patience = args.patience
    lr = args.lr
    weight_decay = args.weight_decay
    num_splits = args.num_splits

    print(args)
    if num_splits == 1:
        print("Running using the standard split...")
    else:
        print("Running using {} random splits...".format(num_splits))

    if args.model.upper() in [
            'PREGGCN', 'PREGGAT', 'PREGMLP', 'PREGGAT_PUBMED'
    ]:
        model_cls = globals()[args.model.upper()]
    else:
        raise NotImplementedError("model selection error")

    seeds = list(range(num_seeds))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = load_dataset(args.dataset, T.NormalizeFeatures())
    # Move the graph once; all splits and seeds reuse it.
    data = dataset[0].to(device)

    result = np.zeros((4, num_seeds, num_splits))

    # To use a different data split, point path_split at your own mask files.
    # NOTE: with num_splits == 1 the dataset's standard (Planetoid) split is
    # used instead of these files.
    path_split = "splits"

    for split in range(num_splits):
        if num_splits == 1:
            # Using the standard split
            splits = data.train_mask, data.val_mask, data.test_mask
        else:
            splits = load_split(
                os.path.join(path_split,
                             args.dataset.lower() + '_' + str(split) +
                             '.mask'))
        # For each split, run num_seeds times
        for seed in seeds:
            torch.manual_seed(seed)
            model = model_cls(dataset.num_features,
                              dataset.num_classes).to(device)
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=lr,
                                         weight_decay=weight_decay)

            best_val_acc = best_val_train_acc = best_val_test_acc = 0.
            best_val_epoch = 0
            cnt_wait = 0

            for epoch in range(epochs):
                train(model, optimizer, data, splits, mu)
                train_acc, val_acc, test_acc = test(model, data, splits)

                # Only a strict improvement resets the patience counter; a
                # tie still records this epoch as the best one, and only a
                # drop counts towards patience.
                if val_acc > best_val_acc:
                    cnt_wait = 0
                if val_acc >= best_val_acc:
                    best_val_train_acc = train_acc
                    best_val_acc = val_acc
                    best_val_test_acc = test_acc
                    best_val_epoch = epoch
                else:
                    cnt_wait += 1

                if cnt_wait > patience:
                    break

            result[0][seed][split] = best_val_train_acc
            result[1][seed][split] = best_val_acc
            result[2][seed][split] = best_val_test_acc
            result[3][seed][split] = best_val_epoch
            print('seed:', seed, 'Epoch:', best_val_epoch, 'Train Acc:',
                  best_val_train_acc, 'Val Acc:', best_val_acc, 'Test Acc:',
                  best_val_test_acc)

    # Summarize and store the result.
    path = os.path.dirname(os.path.abspath(__file__))
    symbol = "\u00b1"
    data_avr = np.mean(result, axis=(1, 2))
    data_std = np.std(result, axis=(1, 2))
    log2 = "Result: Dataset: {}, Model: {}, Avg. Epochs: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.5f}{}{:.5f}"
    print(
        log2.format(dataset, model_cls, int(data_avr[3]), data_avr[0],
                    data_avr[1], data_avr[2], symbol, data_std[2]))
    para = '{:.02f}'.format(mu) + '_' + str(lr) + '_' + str(
        weight_decay) + '_' + str(patience)
    outfile = args.dataset.lower() + '_' + para + '.npy'
    # Standard-split runs go to "<model>_stand", random-split runs to "<model>".
    model_dir = args.model.lower() + ("_stand" if num_splits == 1 else "")
    out_dir = os.path.join(path, "results", model_dir)
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, outfile), 'wb') as f:
        np.save(f, result)
コード例 #5
0
def main():
    """Train FCRN for depth prediction on NYU Depth v2.

    Steps:
      1. Build the train/val loaders from ``load_split()``.
      2. Create FCRN and load the converted pretrained weights, optionally
         resuming from a checkpoint.
      3. Train with the reverse Huber loss and Adam.
      4. After every epoch, compute the mean validation loss and save a
         checkpoint to ``./model/`` when it improves and after the last
         epoch.
      5. Multiply the learning rate by ``lr_decay`` every ``lr_decay_every``
         epochs.

    During the last epoch, one validation sample (RGB, ground-truth depth and
    predicted depth) is written to ``./result/``.
    """
    batch_size = 16
    data_path = './data/nyu_depth_v2_labeled.mat'
    learning_rate = 1.0e-4
    num_epochs = 100
    lr_decay = 0.8        # multiplicative learning-rate decay factor
    lr_decay_every = 10   # apply lr_decay after every this many epochs

    # 1. Load data
    train_lists, val_lists, _test_lists = load_split()
    print("Loading data...")
    # Shuffle only the training data. With drop_last=True a shuffled
    # validation loader would drop a different subset every epoch and make
    # the validation error noisy.
    train_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, train_lists),
        batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, val_lists),
        batch_size=batch_size, shuffle=False, drop_last=True)
    print(train_loader)

    # 2. Load model
    print("Loading model...")
    model = FCRN(batch_size)
    # Load the official pretrained parameters (converted from TensorFlow).
    model.load_state_dict(load_weights(model, weights_file, dtype))

    # Optionally resume from a saved training checkpoint. start_epoch must be
    # initialised before this point so a resumed value is not reset later.
    resume_from_file = False
    resume_file = './model/model_300.pth'
    start_epoch = 0
    checkpoint = None
    if resume_from_file:
        if os.path.isfile(resume_file):
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("can not find!")
    model = model.cuda()

    # 3. Loss: the paper's reverse Huber loss. Alternatives are
    # torch.nn.MSELoss() (built-in MSE) or loss_mse() (custom MSE).
    loss_fn = loss_huber()
    print("loss_fn set...")

    # 4. Optimizer (restored from the checkpoint when resuming)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if checkpoint is not None and 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    print("optimizer set...")

    # 5. Train
    os.makedirs('./model', exist_ok=True)
    os.makedirs('./result', exist_ok=True)
    best_val_err = float('inf')

    for epoch in range(num_epochs):
        epoch_no = start_epoch + epoch + 1
        is_last_epoch = epoch == num_epochs - 1
        print('Starting train epoch %d / %d' %
              (epoch_no, num_epochs + start_epoch))
        model.train()
        running_loss = 0.0
        count = 0
        for rgb, depth in train_loader:
            input_var = rgb.type(dtype)
            depth_var = depth.type(dtype)

            output = model(input_var)
            loss = loss_fn(output, depth_var)
            loss_value = loss.item()
            print('loss: %f' % loss_value)
            count += 1
            running_loss += loss_value

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count if count else float('nan')
        print('epoch loss:', epoch_loss)

        # Validate: mean loss over the validation batches.
        model.eval()
        val_loss_sum = 0.0
        num_batches = 0
        with torch.no_grad():
            for rgb, depth in val_loader:
                input_var = rgb.type(dtype)
                depth_var = depth.type(dtype)

                output = model(input_var)
                if is_last_epoch and num_batches == 0:
                    # Save one sample (from the first validation batch).
                    rgb_image = rgb[0].permute(1, 2, 0).cpu().numpy()
                    gt_depth_image = depth_var[0][0].cpu().numpy().astype(
                        np.float32)
                    pred_depth_image = output[0].squeeze().cpu().numpy(
                    ).astype(np.float32)
                    # Scale each depth map to [0, 1]; skip all-zero maps.
                    for image in (gt_depth_image, pred_depth_image):
                        peak = np.max(image)
                        if peak > 0:
                            image /= peak

                    plot.imsave(
                        './result/input_rgb_epoch_{}.png'.format(epoch_no),
                        rgb_image)
                    plot.imsave(
                        './result/gt_depth_epoch_{}.png'.format(epoch_no),
                        gt_depth_image,
                        cmap="viridis")
                    plot.imsave(
                        './result/pred_depth_epoch_{}.png'.format(epoch_no),
                        pred_depth_image,
                        cmap="viridis")

                val_loss_sum += loss_fn(output, depth_var).item()
                num_batches += 1

        err = val_loss_sum / num_batches if num_batches else float('inf')
        print('val_error: %f' % err)

        improved = err < best_val_err
        if improved:
            best_val_err = err
        if improved or is_last_epoch:
            torch.save(
                {
                    'epoch': epoch_no,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, './model/model_' + str(epoch_no) + '.pth')

        # Step decay: scale the learning rate stored in the optimizer itself;
        # rebinding a local variable would have no effect on Adam.
        if (epoch + 1) % lr_decay_every == 0:
            for group in optimizer.param_groups:
                group['lr'] *= lr_decay
コード例 #6
0
def _main(_args=None):
    """Self-training sweep: enlarge the training set with pseudo-labels.

    Two models are trained on split ``args.split_id``: one fits the train
    mask (validating on val), the other has the two masks swapped. Each
    model's nodes are ranked by confidence.

    For every percentage ``pct`` in
    ``range(args.s_thres, args.e_thres, args.j_thres)``, the candidates are
    test nodes that are in both models' top ``pct``% and on which both
    models predict the same class. A class-balanced subset of candidates
    (the same count per class, namely the smallest candidate class count)
    is added to the training mask with model 0's prediction as its label.
    ``args.num_seeds`` fresh models are then trained on the enlarged set.

    A percentage is skipped when its per-class count is 0 (except the first
    one) or equals the previous percentage's count.

    Args:
        _args: Optional list of command-line arguments forwarded to
            ``parse_args``; ``None`` parses ``sys.argv``.

    Raises:
        ValueError: If ``args.num_models`` is not 2 (exactly two models are
            trained below).
    """
    args = create_parser().parse_args(_args)
    dataset_key = args.dataset.lower()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    epochs = EPOCHS_CONFIG[dataset_key]

    model_name = args.model.upper()
    if model_name == "SGC" and dataset_key in ["cs", "physics"]:
        model_name = "USGC"
        print(model_name, end=" ")

    dataset = load_dataset(dataset_key)
    data = dataset[0]
    num_nodes = data.num_nodes
    num_classes = dataset.num_classes
    edges = data.edge_index.to(device)
    data = data.to(device)
    model_cls = globals()[model_name]
    o_train_mask, o_val_mask, o_test_mask = load_split(
        f"{args.splits_dir}/{dataset_key}_{args.split_id}.mask")

    o_y = data.y.clone()
    num_models = args.num_models
    # Exactly two models are trained below; any extra rows of preds/nodes
    # would stay all-zero and corrupt the candidate intersection.
    if num_models != 2:
        raise ValueError(f"num_models must be 2, got {num_models}")
    model_seeds = list(range(num_models))

    preds = torch.zeros((num_models, num_nodes),
                        dtype=torch.int,
                        device=device)
    nodes = torch.zeros((num_models, num_nodes), dtype=torch.long)

    # Model 0 fits the train mask and validates on val; model 1 swaps them.
    mask_pairs = ((o_train_mask, o_val_mask), (o_val_mask, o_train_mask))
    for i, (fit_mask, select_mask) in enumerate(mask_pairs):
        model = train_model(
            model_cls,
            dataset,
            o_y,
            edges,
            fit_mask,
            select_mask,
            o_test_mask,
            epochs,
            device,
            model_seeds[i],
            end="\n",
        )
        # NOTE(review): the model is used in whatever mode train_model leaves
        # it; if that can be training mode, call model.eval() here.
        with torch.no_grad():
            probs = F.softmax(model(data.x, edges), dim=1)
        top = probs.max(dim=1)
        preds[i] = top.indices
        # Confidences are rounded to float16 before ranking (most confident first).
        nodes[i] = top.values.to(device, dtype=torch.float16).argsort(
            descending=True)
        del model

    base_preds = preds[0]
    add_per_c_num = -1
    for pct in range(args.s_thres, args.e_thres, args.j_thres):
        new_y = o_y.clone()
        new_train_mask = o_train_mask.clone()
        pre_add_train_mask = torch.zeros_like(o_train_mask, dtype=torch.bool)
        num_pre_selected = num_nodes * pct // 100
        pre_selected = nodes[:, :num_pre_selected]
        print(f"{pct} pre_selected: {num_pre_selected}", end=" ")

        # Nodes in the top-pct% of every model.
        # NOTE(review): np.intersect1d returns node ids sorted ascending, so
        # the per-class quota below favours low node ids rather than high
        # confidence -- confirm this is intended.
        selected = pre_selected[0]
        for j in range(1, num_models):
            selected = np.intersect1d(selected, pre_selected[j])
        selected = torch.from_numpy(selected)
        print("selected:", selected.size(0), end=" ")

        # Candidates: test nodes on which all models agree. Both masks are
        # kept on the CPU so they can index the CPU `selected` tensor.
        all_preds = preds[:, selected]
        in_test = o_test_mask[selected].bool().cpu()
        agree = (all_preds == all_preds[0]).all(dim=0).cpu()
        inconsistent = int((in_test & ~agree).sum())
        pre_add_train_mask[selected[in_test & agree]] = True

        candidate_preds = base_preds[pre_add_train_mask]
        add_c_num = [(candidate_preds == c).sum().item()
                     for c in range(num_classes)]
        last_add_per_c_num = add_per_c_num
        add_per_c_num = min(add_c_num)
        if add_per_c_num == 0 and pct != args.s_thres:
            print(add_c_num, "skipping for not improving")
            continue

        if last_add_per_c_num == add_per_c_num:
            print(add_c_num, "skipping for not improving")
            continue

        # Add up to add_per_c_num candidates of each predicted class.
        remain_per_c_num = [add_per_c_num] * num_classes
        num_added = num_yes = num_no = 0
        for node in selected:
            base_pred = base_preds[node].item()
            if pre_add_train_mask[node] and remain_per_c_num[base_pred] > 0:
                new_train_mask[node] = True
                num_added += 1
                remain_per_c_num[base_pred] -= 1
                new_y[node] = base_preds[node]
                # Track how many pseudo-labels match the true label.
                if base_pred == o_y[node].item():
                    num_yes += 1
                else:
                    num_no += 1
        print(
            f"num_added: {num_added} inconsistent: {inconsistent} per_c: {add_c_num} | add_per_c: {add_per_c_num} num_yes/no: {num_yes} {num_no}",
            end=" ",
        )
        for seed in range(args.num_seeds):
            print(" $ ", end="")
            train_model(
                model_cls,
                dataset,
                new_y,
                edges,
                new_train_mask,
                o_val_mask,
                o_test_mask,
                epochs,
                device,
                seed=seed,
            )
        print()
コード例 #7
0
def _main(_args=None):
    """Evaluate self-training with per-(split, seed) tuned thresholds.

    Two models are trained on split ``args.split_id`` (train/val masks
    swapped for the second), and each ranks all nodes by confidence. For
    every (split, seed) pair, the values from ``load_tna_thres`` and
    ``load_tu_thres`` choose:

      * the ``enlarge`` threshold used to add pseudo-labelled nodes to the
        training set;
      * which pre-computed ``7_edges/<dataset>/dropper2`` edge file to
        train on.

    ``train_model_accs`` then trains a model. The mean and variance of the
    test accuracies are printed at the end.

    Args:
        _args: Optional list of command-line arguments forwarded to
            ``parse_args``; ``None`` parses ``sys.argv``.

    Raises:
        ValueError: If ``args.num_models`` is not 2 (exactly two models are
            trained below).
    """
    args = create_parser().parse_args(_args)
    dataset_key = args.dataset.lower()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    epochs = EPOCHS_CONFIG[dataset_key]

    model_name = args.model.upper()
    if model_name == "SGC" and dataset_key in ["cs", "physics"]:
        model_name = "USGC"
        print(model_name, end=" ")

    dataset = load_dataset(dataset_key)
    data = dataset[0]
    num_nodes = data.num_nodes
    num_classes = dataset.num_classes
    edges = data.edge_index.to(device)
    data = data.to(device)
    model_cls = globals()[model_name]
    o_train_mask, o_val_mask, o_test_mask = load_split(
        f"{args.splits_dir}/{dataset_key}_{args.split_id}.mask")

    o_y = data.y.clone()
    num_models = args.num_models
    # Exactly two models are trained below; any extra rows of preds/nodes
    # would stay all-zero and corrupt enlarge()'s candidate selection.
    if num_models != 2:
        raise ValueError(f"num_models must be 2, got {num_models}")
    model_seeds = list(range(num_models))

    preds = torch.zeros((num_models, num_nodes),
                        dtype=torch.int,
                        device=device)
    nodes = torch.zeros((num_models, num_nodes), dtype=torch.long)

    # Model 0 fits the train mask and validates on val; model 1 swaps them.
    mask_pairs = ((o_train_mask, o_val_mask), (o_val_mask, o_train_mask))
    for i, (fit_mask, select_mask) in enumerate(mask_pairs):
        model = train_model(
            model_cls,
            dataset,
            o_y,
            edges,
            fit_mask,
            select_mask,
            o_test_mask,
            epochs,
            device,
            model_seeds[i],
            end="\n",
        )
        # NOTE(review): the model is used in whatever mode train_model leaves
        # it; if that can be training mode, call model.eval() here.
        with torch.no_grad():
            probs = F.softmax(model(data.x, edges), dim=1)
        top = probs.max(dim=1)
        preds[i] = top.indices
        # Confidences are rounded to float16 before ranking (most confident first).
        nodes[i] = top.values.to(device, dtype=torch.float16).argsort(
            descending=True)
        del model

    tna_thres = load_tna_thres(args.template, args.num_splits)
    tu_thres = load_tu_thres(args.filename)

    # NOTE(review): masks, models and edge files all come from args.split_id;
    # the loop's split_id only selects which tuned thresholds to apply.
    # Confirm this is intended.
    results = np.zeros((3, args.num_splits, args.num_seeds))
    for split_id in range(args.num_splits):
        for seed in range(args.num_seeds):
            tna = tna_thres[split_id, seed]
            tu = tu_thres[split_id, seed]
            print(split_id, seed, tna, tu, end=" ")
            new_train_mask, new_y = enlarge(
                o_y,
                o_train_mask,
                o_test_mask,
                num_nodes,
                num_classes,
                num_models,
                nodes,
                preds,
                tna,
            )

            # Separate name so the original graph `edges` is not clobbered.
            # map_location lets edge files saved from a GPU be read on a
            # CPU-only machine (and vice versa).
            run_edges = torch.load(
                os.path.join(
                    "7_edges",
                    dataset_key,
                    "dropper2",
                    f"{args.model.lower()}_{args.split_id}_{0.01*tu:.2f}.edges",
                ),
                map_location=device)
            train_acc, val_acc, test_acc = train_model_accs(
                model_cls,
                dataset,
                new_y,
                run_edges,
                new_train_mask,
                o_val_mask,
                o_test_mask,
                epochs,
                device,
                seed=seed,
                end="\n",
            )
            results[0, split_id, seed] = train_acc
            results[1, split_id, seed] = val_acc
            results[2, split_id, seed] = test_acc
    print(results[2].mean())
    print(results[2].var())
    print()