def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
    # The OGB evaluator returns a dict; unwrap the ROC-AUC score.
    evaluator_wrapper = lambda pred, labels: evaluator.eval(
        {"y_pred": pred, "y_true": labels}
    )["rocauc"]

    # Split the training set into 10 roughly equal mini-batches per epoch.
    train_batch_size = (len(train_idx) + 9) // 10
    # batch_size = len(train_idx)
    # Sample 16 neighbors per layer during training, 60 during evaluation.
    train_sampler = MultiLayerNeighborSampler([16 for _ in range(args.n_layers)])
    # sampler = MultiLayerFullNeighborSampler(args.n_layers)
    train_dataloader = DataLoaderWrapper(
        NodeDataLoader(
            graph.cpu(),
            train_idx.cpu(),
            train_sampler,
            batch_sampler=BatchSampler(len(train_idx), batch_size=train_batch_size),
            num_workers=4,
        )
    )

    eval_sampler = MultiLayerNeighborSampler([60 for _ in range(args.n_layers)])
    # sampler = MultiLayerFullNeighborSampler(args.n_layers)
    eval_dataloader = DataLoaderWrapper(
        NodeDataLoader(
            graph.cpu(),
            torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]),
            eval_sampler,
            batch_sampler=BatchSampler(graph.number_of_nodes(), batch_size=32768),
            num_workers=4,
        )
    )

    criterion = nn.BCEWithLogitsLoss()

    model = gen_model(args).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.75, patience=50, verbose=True
    )

    total_time = 0
    val_score, best_val_score, final_test_score = 0, 0, 0

    train_scores, val_scores, test_scores = [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []
    final_pred = None

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()
        loss = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper)
        toc = time.time()
        total_time += toc - tic

        if epoch == args.n_epochs or epoch % args.eval_every == 0 or epoch % args.log_every == 0:
            train_score, val_score, test_score, train_loss, val_loss, test_loss, pred = evaluate(
                args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
            )

            # Track the test score and predictions at the best validation score.
            if val_score > best_val_score:
                best_val_score = val_score
                final_test_score = test_score
                final_pred = pred

            if epoch % args.log_every == 0:
                print(
                    f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, "
                    f"Average epoch time: {total_time / epoch:.2f}s"
                )
                print(
                    f"Loss: {loss:.4f}\n"
                    f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                    f"Train/Val/Test/Best val/Final test score: "
                    f"{train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
                )

            for l, e in zip(
                [train_scores, val_scores, test_scores, losses, train_losses, val_losses, test_losses],
                [train_score, val_score, test_score, loss, train_loss, val_loss, test_loss],
            ):
                l.append(e)

        lr_scheduler.step(val_score)

    print("*" * 50)
    print(f"Best val score: {best_val_score}, Final test score: {final_test_score}")
    print("*" * 50)

    if args.plot_curves:
        # Score curves.
        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.set_yticks(np.linspace(0, 1.0, 101))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip(
            [train_scores, val_scores, test_scores],
            ["train score", "val score", "test score"],
        ):
            plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.01))
        ax.yaxis.set_minor_locator(AutoMinorLocator(2))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_score_{n_running}.png")

        # Loss curves.
        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip(
            [losses, train_losses, val_losses, test_losses],
            ["loss", "train loss", "val loss", "test loss"],
        ):
            plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.1))
        ax.yaxis.set_minor_locator(AutoMinorLocator(5))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_loss_{n_running}.png")

    if args.save_pred:
        os.makedirs("./output", exist_ok=True)
        torch.save(F.softmax(final_pred, dim=1), f"./output/{n_running}.pt")

    return best_val_score, final_test_score
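# ---------------------------------------------------------------------------
# NOTE: `BatchSampler` and `DataLoaderWrapper` are helpers defined elsewhere in
# this repo. The sketch below is a hypothetical reconstruction inferred only
# from how they are called above: the sampler yields index batches endlessly so
# the dataloader's workers are never torn down between epochs, and the wrapper
# slices the endless stream into one-epoch iterables.
class BatchSampler(object):
    def __init__(self, n, batch_size, shuffle=True):
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Number of batches that make up one epoch.
        self.num_batches = (n + batch_size - 1) // batch_size

    def __len__(self):
        # torch's DataLoader reports len(batch_sampler) as its own length.
        return self.num_batches

    def __iter__(self):
        while True:  # never exhausts, so the loader's workers stay alive
            perm = torch.randperm(self.n) if self.shuffle else torch.arange(self.n)
            for i in range(0, self.n, self.batch_size):
                yield perm[i : i + self.batch_size]


class DataLoaderWrapper(object):
    def __init__(self, dataloader):
        self.num_batches = len(dataloader)  # one epoch's worth of batches
        self.iterator = iter(dataloader)    # created once, reused every epoch

    def __iter__(self):
        for _ in range(self.num_batches):
            yield next(self.iterator)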
# A second variant of run(): accuracy-based evaluation with a custom loss and
# an optional "estimation mode" that subsamples the test set during training.
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
    # The OGB evaluator returns a dict; unwrap the accuracy score.
    evaluator_wrapper = lambda pred, labels: evaluator.eval(
        {"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels}
    )["acc"]

    criterion = custom_loss_function

    # Split the training set into 30 roughly equal mini-batches per epoch.
    n_train_samples = train_idx.shape[0]
    train_batch_size = (n_train_samples + 29) // 30
    # Sample 10 neighbors per layer during training, 15 during evaluation.
    train_sampler = MultiLayerNeighborSampler([10 for _ in range(args.n_layers)])
    train_dataloader = DataLoaderWrapper(
        DataLoader(
            graph.cpu(),
            train_idx.cpu(),
            train_sampler,
            batch_sampler=BatchSampler(len(train_idx), batch_size=train_batch_size, shuffle=True),
            num_workers=4,
        )
    )

    eval_batch_size = 32768
    eval_sampler = MultiLayerNeighborSampler([15 for _ in range(args.n_layers)])

    if args.estimation_mode:
        # Score only every 45th test node during training; the full test set is
        # evaluated once at the end, with the best checkpoint.
        test_idx_during_training = test_idx[torch.arange(start=0, end=len(test_idx), step=45)]
    else:
        test_idx_during_training = test_idx

    eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx_during_training.cpu()])
    eval_dataloader = DataLoaderWrapper(
        DataLoader(
            graph.cpu(),
            eval_idx,
            eval_sampler,
            batch_sampler=BatchSampler(len(eval_idx), batch_size=eval_batch_size, shuffle=False),
            num_workers=4,
        )
    )

    model = gen_model(args).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.7, patience=20, verbose=True, min_lr=1e-4
    )

    best_model_state_dict = None

    total_time = 0
    val_score, best_val_score, final_test_score = 0, 0, 0

    scores, train_scores, val_scores, test_scores = [], [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()
        score, loss = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper)
        toc = time.time()
        total_time += toc - tic

        if epoch == args.n_epochs or epoch % args.eval_every == 0 or epoch % args.log_every == 0:
            train_score, val_score, test_score, train_loss, val_loss, test_loss = evaluate(
                args,
                model,
                eval_dataloader,
                labels,
                train_idx,
                val_idx,
                test_idx_during_training,
                criterion,
                evaluator_wrapper,
            )

            if val_score > best_val_score:
                best_val_score = val_score
                final_test_score = test_score
                if args.estimation_mode:
                    # Keep a CPU copy of the best checkpoint for the final
                    # full-test evaluation.
                    best_model_state_dict = {k: v.to("cpu") for k, v in model.state_dict().items()}

            if epoch == args.n_epochs or epoch % args.log_every == 0:
                print(
                    f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, "
                    f"Average epoch time: {total_time / epoch:.2f}s\n"
                    f"Loss: {loss:.4f}, Score: {score:.4f}\n"
                    f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                    f"Train/Val/Test/Best val/Final test score: "
                    f"{train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
                )

            for l, e in zip(
                [scores, train_scores, val_scores, test_scores, losses, train_losses, val_losses, test_losses],
                [score, train_score, val_score, test_score, loss, train_loss, val_loss, test_loss],
            ):
                l.append(e)

        lr_scheduler.step(val_score)

    if args.estimation_mode:
        # Reload the best checkpoint and score the full test set.
        model.load_state_dict(best_model_state_dict)
        eval_dataloader = DataLoaderWrapper(
            DataLoader(
                graph.cpu(),
                test_idx.cpu(),
                eval_sampler,
                batch_sampler=BatchSampler(len(test_idx), batch_size=eval_batch_size, shuffle=False),
                num_workers=4,
            )
        )
        final_test_score = evaluate(
            args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
        )[2]

    print("*" * 50)
    print(f"Best val score: {best_val_score}, Final test score: {final_test_score}")
    print("*" * 50)

    if args.plot_curves:
        # Score curves.
        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.set_yticks(np.linspace(0, 1.0, 101))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip(
            [train_scores, val_scores, test_scores],
            ["train score", "val score", "test score"],
        ):
            plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
        ax.xaxis.set_major_locator(MultipleLocator(10))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.01))
        ax.yaxis.set_minor_locator(AutoMinorLocator(2))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_score_{n_running}.png")

        # Loss curves.
        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip(
            [losses, train_losses, val_losses, test_losses],
            ["loss", "train loss", "val loss", "test loss"],
        ):
            plt.plot(range(1, args.n_epochs + 1, args.log_every), y, label=label, linewidth=1)
        ax.xaxis.set_major_locator(MultipleLocator(10))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.1))
        ax.yaxis.set_minor_locator(AutoMinorLocator(5))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_loss_{n_running}.png")

    return best_val_score, final_test_score
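# ---------------------------------------------------------------------------
# NOTE: `custom_loss_function` is defined elsewhere in this repo. A minimal
# sketch is given below, assuming the log-adjusted cross-entropy ("loge") loss
# used by related OGB GAT examples; both the transform and the `epsilon`
# constant are part of that assumption, not confirmed by this file.
import math

epsilon = 1 - math.log(2)


def custom_loss_function(x, labels):
    # Standard per-sample cross-entropy against the class-index labels...
    y = F.cross_entropy(x, labels[:, 0], reduction="none")
    # ...then a log transform that shrinks the contribution of easy samples
    # while keeping the gradient large on hard ones.
    y = torch.log(epsilon + y) - math.log(epsilon)
    return torch.mean(y)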