Example #1
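A training driver for a neighbor-sampled GAT on an OGB node-property dataset scored by ROC-AUC (the rocauc metric and multi-label BCE loss suggest ogbn-proteins). BatchSampler and DataLoaderWrapper are project-local helpers not shown here; a sketch of plausible implementations follows this example.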
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator,
        n_running):
    def evaluator_wrapper(pred, labels):
        return evaluator.eval({
            "y_pred": pred,
            "y_true": labels
        })["rocauc"]

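    # Ceil division: split the training set into 10 mini-batches per epoch.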
    train_batch_size = (len(train_idx) + 9) // 10
    # batch_size = len(train_idx)
    train_sampler = MultiLayerNeighborSampler(
        [16 for _ in range(args.n_layers)])
    # sampler = MultiLayerFullNeighborSampler(args.n_layers)
    train_dataloader = DataLoaderWrapper(
        NodeDataLoader(
            graph.cpu(),
            train_idx.cpu(),
            train_sampler,
            batch_sampler=BatchSampler(len(train_idx),
                                       batch_size=train_batch_size),
            num_workers=4,
        ))

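    # Evaluation uses a wider fanout (60 neighbors per layer) than training
    # for a closer approximation to full-neighborhood aggregation.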
    eval_sampler = MultiLayerNeighborSampler(
        [60 for _ in range(args.n_layers)])
    # sampler = MultiLayerFullNeighborSampler(args.n_layers)
    eval_dataloader = DataLoaderWrapper(
        NodeDataLoader(
            graph.cpu(),
            torch.cat([train_idx.cpu(),
                       val_idx.cpu(),
                       test_idx.cpu()]),
            eval_sampler,
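            # Assumes the train/val/test splits together cover every node,
            # so graph.number_of_nodes() equals len(eval_idx).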
            batch_sampler=BatchSampler(graph.number_of_nodes(),
                                       batch_size=32768),
            num_workers=4,
        ))

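    # Multi-label targets: binary cross-entropy over per-class logits.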
    criterion = nn.BCEWithLogitsLoss()

    model = gen_model(args).to(device)

    optimizer = optim.AdamW(model.parameters(),
                            lr=args.lr,
                            weight_decay=args.wd)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                        mode="max",
                                                        factor=0.75,
                                                        patience=50,
                                                        verbose=True)

    total_time = 0
    val_score, best_val_score, final_test_score = 0, 0, 0

    train_scores, val_scores, test_scores = [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []
    final_pred = None

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()

        loss = train(args, model, train_dataloader, labels, train_idx,
                     criterion, optimizer, evaluator_wrapper)

        toc = time.time()
        total_time += toc - tic

        if epoch == args.n_epochs or epoch % args.eval_every == 0 or epoch % args.log_every == 0:
            train_score, val_score, test_score, train_loss, val_loss, test_loss, pred = evaluate(
                args, model, eval_dataloader, labels, train_idx, val_idx,
                test_idx, criterion, evaluator_wrapper)

            if val_score > best_val_score:
                best_val_score = val_score
                final_test_score = test_score
                final_pred = pred

            if epoch % args.log_every == 0:
                print(
                    f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}s"
                )
                print(
                    f"Loss: {loss:.4f}\n"
                    f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                    f"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
                )

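            # Record this epoch's metrics for the curves plotted below.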
            for l, e in zip(
                [
                    train_scores, val_scores, test_scores, losses,
                    train_losses, val_losses, test_losses
                ],
                [
                    train_score, val_score, test_score, loss, train_loss,
                    val_loss, test_loss
                ],
            ):
                l.append(e)

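        # On epochs without an evaluation this reuses the previous val_score.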
        lr_scheduler.step(val_score)

    print("*" * 50)
    print(
        f"Best val score: {best_val_score}, Final test score: {final_test_score}"
    )
    print("*" * 50)

    if args.plot_curves:
        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.set_yticks(np.linspace(0, 1.0, 101))
        ax.tick_params(labeltop=True, labelright=True)
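        # Assumes metrics were recorded exactly once per log_every epochs;
        # other eval_every/n_epochs combinations make x and y lengths differ.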
        for y, label in zip([train_scores, val_scores, test_scores],
                            ["train score", "val score", "test score"]):
            plt.plot(range(1, args.n_epochs + 1, args.log_every),
                     y,
                     label=label,
                     linewidth=1)
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.01))
        ax.yaxis.set_minor_locator(AutoMinorLocator(2))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_score_{n_running}.png")

        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip([losses, train_losses, val_losses, test_losses],
                            ["loss", "train loss", "val loss", "test loss"]):
            plt.plot(range(1, args.n_epochs + 1, args.log_every),
                     y,
                     label=label,
                     linewidth=1)
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.1))
        ax.yaxis.set_minor_locator(AutoMinorLocator(5))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_loss_{n_running}.png")

    if args.save_pred:
        os.makedirs("./output", exist_ok=True)
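        # Note: softmax normalizes across classes; with the multi-label BCE
        # setup above, torch.sigmoid would give per-class probabilities.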
        torch.save(F.softmax(final_pred, dim=1), f"./output/{n_running}.pt")

    return best_val_score, final_test_score
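
Both examples depend on two project-local helpers, BatchSampler and DataLoaderWrapper, whose definitions are not included above. The sketch below is reconstructed from the call sites alone (a size, a batch_size, an optional shuffle flag, and an iterable usable as a torch DataLoader batch_sampler); it is an assumption, not the project's actual implementation.

import math

import torch


class BatchSampler:
    """Yields lists of positional indices into a dataset of n items.

    Sketch only: reconstructed from the call sites above, not the
    project's actual implementation.
    """

    def __init__(self, n, batch_size, shuffle=False):
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self):
        return math.ceil(self.n / self.batch_size)

    def __iter__(self):
        # Fresh permutation per epoch when shuffling, fixed order otherwise.
        order = torch.randperm(self.n) if self.shuffle else torch.arange(self.n)
        for i in range(0, self.n, self.batch_size):
            yield order[i:i + self.batch_size].tolist()


class DataLoaderWrapper:
    """Thin pass-through so the loader can be re-iterated once per epoch."""

    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __len__(self):
        return len(self.dataloader)

    def __iter__(self):
        return iter(self.dataloader)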
Example #2
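A variant of the same driver for a single-label task scored by accuracy (the argmax wrapper and acc metric suggest ogbn-arxiv or ogbn-products). It adds shuffled training batches, a custom loss, and an optional estimation_mode that scores only a subsample of the test set during training and re-evaluates the best checkpoint on the full test set afterwards; a hypothetical stand-in for the undefined custom_loss_function follows this example.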
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator,
        n_running):
    def evaluator_wrapper(pred, labels):
        return evaluator.eval({
            "y_pred": pred.argmax(dim=-1, keepdim=True),
            "y_true": labels
        })["acc"]
    criterion = custom_loss_function

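    # Ceil division: split the training set into 30 mini-batches per epoch.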
    n_train_samples = train_idx.shape[0]
    train_batch_size = (n_train_samples + 29) // 30
    train_sampler = MultiLayerNeighborSampler(
        [10 for _ in range(args.n_layers)])
    train_dataloader = DataLoaderWrapper(
        DataLoader(
            graph.cpu(),
            train_idx.cpu(),
            train_sampler,
            batch_sampler=BatchSampler(len(train_idx),
                                       batch_size=train_batch_size,
                                       shuffle=True),
            num_workers=4,
        ))

    eval_batch_size = 32768
    eval_sampler = MultiLayerNeighborSampler(
        [15 for _ in range(args.n_layers)])

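    # In estimation mode, only every 45th test node is scored during
    # training to cut evaluation cost; the best checkpoint is re-scored
    # on the full test set after training.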
    if args.estimation_mode:
        test_idx_during_training = test_idx[torch.arange(start=0,
                                                         end=len(test_idx),
                                                         step=45)]
    else:
        test_idx_during_training = test_idx

    eval_idx = torch.cat(
        [train_idx.cpu(),
         val_idx.cpu(),
         test_idx_during_training.cpu()])
    eval_dataloader = DataLoaderWrapper(
        DataLoader(
            graph.cpu(),
            eval_idx,
            eval_sampler,
            batch_sampler=BatchSampler(len(eval_idx),
                                       batch_size=eval_batch_size,
                                       shuffle=False),
            num_workers=4,
        ))

    model = gen_model(args).to(device)

    optimizer = optim.AdamW(model.parameters(),
                            lr=args.lr,
                            weight_decay=args.wd)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                        mode="max",
                                                        factor=0.7,
                                                        patience=20,
                                                        verbose=True,
                                                        min_lr=1e-4)

    best_model_state_dict = None

    total_time = 0
    val_score, best_val_score, final_test_score = 0, 0, 0

    scores, train_scores, val_scores, test_scores = [], [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()

        score, loss = train(args, model, train_dataloader, labels, train_idx,
                            criterion, optimizer, evaluator_wrapper)

        toc = time.time()
        total_time += toc - tic

        if epoch == args.n_epochs or epoch % args.eval_every == 0 or epoch % args.log_every == 0:
            train_score, val_score, test_score, train_loss, val_loss, test_loss = evaluate(
                args,
                model,
                eval_dataloader,
                labels,
                train_idx,
                val_idx,
                test_idx_during_training,
                criterion,
                evaluator_wrapper,
            )

            if val_score > best_val_score:
                best_val_score = val_score
                final_test_score = test_score
                if args.estimation_mode:
                    best_model_state_dict = {
                        k: v.to("cpu")
                        for k, v in model.state_dict().items()
                    }

            if epoch == args.n_epochs or epoch % args.log_every == 0:
                print(
                    f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2s}\n"
                    f"Loss: {loss:.4f}, Score: {score:.4f}\n"
                    f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                    f"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
                )

            for l, e in zip(
                [
                    scores, train_scores, val_scores, test_scores, losses,
                    train_losses, val_losses, test_losses
                ],
                [
                    score, train_score, val_score, test_score, loss,
                    train_loss, val_loss, test_loss
                ],
            ):
                l.append(e)

        lr_scheduler.step(val_score)

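    # Restore the best checkpoint and re-score it on the full test set.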
    if args.estimation_mode:
        model.load_state_dict(best_model_state_dict)
        eval_dataloader = DataLoaderWrapper(
            DataLoader(
                graph.cpu(),
                test_idx.cpu(),
                eval_sampler,
                batch_sampler=BatchSampler(len(test_idx),
                                           batch_size=eval_batch_size,
                                           shuffle=False),
                num_workers=4,
            ))
        final_test_score = evaluate(args, model, eval_dataloader, labels,
                                    train_idx, val_idx, test_idx, criterion,
                                    evaluator_wrapper)[2]

    print("*" * 50)
    print(
        f"Best val score: {best_val_score}, Final test score: {final_test_score}"
    )
    print("*" * 50)

    if args.plot_curves:
        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.set_yticks(np.linspace(0, 1.0, 101))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip([train_scores, val_scores, test_scores],
                            ["train score", "val score", "test score"]):
            plt.plot(range(1, args.n_epochs + 1, args.log_every),
                     y,
                     label=label,
                     linewidth=1)
        ax.xaxis.set_major_locator(MultipleLocator(10))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.01))
        ax.yaxis.set_minor_locator(AutoMinorLocator(2))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_score_{n_running}.png")

        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip([losses, train_losses, val_losses, test_losses],
                            ["loss", "train loss", "val loss", "test loss"]):
            plt.plot(range(1, args.n_epochs + 1, args.log_every),
                     y,
                     label=label,
                     linewidth=1)
        ax.xaxis.set_major_locator(MultipleLocator(10))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.1))
        ax.yaxis.set_minor_locator(AutoMinorLocator(5))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_loss_{n_running}.png")

    return best_val_score, final_test_score
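
Example #2 references custom_loss_function without defining it. The stand-in below is hypothetical, assuming plain cross-entropy on integer labels of shape (N, 1), which is consistent with the argmax-based evaluator wrapper above; substitute the project's real loss.

import torch.nn.functional as F


def custom_loss_function(pred, labels):
    # Hypothetical placeholder, not the original loss: standard
    # cross-entropy on single-label targets stored as (N, 1) integers.
    return F.cross_entropy(pred, labels[:, 0])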