def test(model, dataloader, args):
    model.eval()
    epoch_stats = EpochStats()
    for batch_idx, batch in enumerate(tqdm(dataloader, desc="Batch")):
        _, test_batch, _ = concurrent_multi_task_train_test_split(
            batch, False, tasks=args.tasks)
        test_batch = test_batch[0]
        test_batch = test_batch.to(args.device)

        with torch.no_grad():
            gc_test_logit, nc_test_logit, lp_test_logit = model(test_batch)

        # GC
        if "gc" in args.tasks:
            gc_loss = F.cross_entropy(gc_test_logit, test_batch.y)
            with torch.no_grad():
                gc_acc = ut.get_accuracy(gc_test_logit, test_batch.y)
            epoch_stats.update("gc", test_batch, gc_loss, gc_acc, False)

        # NC
        if "nc" in args.tasks:
            node_labels = test_batch.node_y.argmax(1)
            train_mask = test_batch.train_mask.squeeze()
            test_mask = (train_mask == 0).float()
            nc_loss = F.cross_entropy(nc_test_logit[test_mask == 1],
                                      node_labels[test_mask == 1])
            with torch.no_grad():
                nc_acc = ut.get_accuracy(nc_test_logit[test_mask == 1],
                                         node_labels[test_mask == 1])
            epoch_stats.update("nc", test_batch, nc_loss, nc_acc, False)

        # LP
        if "lp" in args.tasks:
            test_link_labels = data_utils.get_link_labels(
                test_batch.pos_edge_index, test_batch.neg_edge_index)
            lp_loss = F.binary_cross_entropy_with_logits(
                lp_test_logit.squeeze(), test_link_labels)
            with torch.no_grad():
                test_labels = test_link_labels.detach().cpu().numpy()
                test_predictions = lp_test_logit.detach().cpu().numpy()
                lp_acc = roc_auc_score(test_labels, test_predictions.squeeze())
            epoch_stats.update("lp", test_batch, lp_loss, lp_acc, False)

    tasks_test_stats = epoch_stats.get_average_stats()
    bl_ut.print_test_stats(tasks_test_stats)
    return tasks_test_stats
def eval_baseline_nn_output_model(output_model, dataloader, output_task,
                                  device="cpu"):
    output_model.eval()
    epoch_stats = EpochStats()
    for batch_idx, batch in enumerate(tqdm(dataloader, desc="Eval Batch")):
        batch = prepare_batch_for_task(batch, output_task, train=False)
        batch = batch.to(device)

        with torch.no_grad():
            # Forward pass
            if output_task == "gc":
                test_logit = output_model(batch.node_embeddings, batch.batch)
            elif output_task == "nc":
                test_logit = output_model(batch.node_embeddings)
            elif output_task == "lp":
                test_logit = output_model(batch.node_embeddings,
                                          batch.pos_edge_index,
                                          batch.neg_edge_index)

        # Evaluate Loss and Accuracy
        if output_task == "gc":
            loss = F.cross_entropy(test_logit, batch.y)
            with torch.no_grad():
                acc = ut.get_accuracy(test_logit, batch.y)
        elif output_task == "nc":
            node_labels = batch.node_y.argmax(1)
            train_mask = batch.train_mask.squeeze()
            test_mask = (train_mask == 0).float()
            loss = F.cross_entropy(test_logit[test_mask == 1],
                                   node_labels[test_mask == 1])
            with torch.no_grad():
                acc = ut.get_accuracy(test_logit[test_mask == 1],
                                      node_labels[test_mask == 1])
        elif output_task == "lp":
            test_link_labels = data_utils.get_link_labels(
                batch.pos_edge_index, batch.neg_edge_index)
            loss = F.binary_cross_entropy_with_logits(test_logit.squeeze(),
                                                      test_link_labels)
            with torch.no_grad():
                test_labels = test_link_labels.detach().cpu().numpy()
                test_predictions = test_logit.detach().cpu().numpy()
                acc = roc_auc_score(test_labels, test_predictions.squeeze())

        epoch_stats.update(output_task, batch, loss, acc, False)

    task_test_stats = epoch_stats.get_average_stats()
    bl_ut.print_test_stats(task_test_stats)
    return task_test_stats
def get_data_for_linear_classifier(data, task, shuffle=True):
    """Build (X, y) numpy arrays from precomputed node embeddings, to be fed
    to a linear classifier for the given task ("gc", "nc", or "lp")."""
    X = []
    y = []
    for d in data:
        node_embeddings = d.node_embeddings.detach().cpu().numpy()
        if task == "gc":
            # One example per graph: mean-pooled node embeddings.
            X.append(node_embeddings.mean(axis=0))
            y.append(d.y.detach().cpu().numpy())
        elif task == "nc":
            # One example per node.
            X.append(node_embeddings)
            node_labels = d.node_y.argmax(1).detach().cpu().numpy()
            node_labels = np.expand_dims(node_labels, axis=1)
            y.append(node_labels)
        elif task == "lp":
            # One example per (positive or negative) edge: the concatenation
            # of the two endpoint embeddings.
            train_data_list, test_data_list = data_utils.prepare_data_for_link_prediction(
                [d], train_ratio=0.9, neg_to_pos_edge_ratio=1,
                rnd_labeled_edges=False)
            pos_edge_idx = test_data_list[0].pos_edge_index.detach().cpu().numpy()
            neg_edge_idx = test_data_list[0].neg_edge_index.detach().cpu().numpy()
            lp_labels = data_utils.get_link_labels(
                test_data_list[0].pos_edge_index,
                test_data_list[0].neg_edge_index).detach().cpu().numpy()
            node_a = np.take(node_embeddings,
                             np.concatenate((pos_edge_idx[0], neg_edge_idx[0])),
                             axis=0)
            node_b = np.take(node_embeddings,
                             np.concatenate((pos_edge_idx[1], neg_edge_idx[1])),
                             axis=0)
            X.append(np.concatenate((node_a, node_b), axis=1))
            y.append(np.expand_dims(lp_labels, axis=1))

    X = np.vstack(X)
    y = np.vstack(y)
    if shuffle:
        perm = np.arange(X.shape[0])
        np.random.shuffle(perm)
        X = X[perm]
        y = y[perm]
    return X, y
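# A minimal usage sketch (not part of the original code): fit a linear probe on
# the features produced by get_data_for_linear_classifier. The data splits, the
# probe choice (scikit-learn LogisticRegression), and max_iter are illustrative
# assumptions.
from sklearn.linear_model import LogisticRegression

def fit_linear_probe(train_data, test_data, task):
    X_train, y_train = get_data_for_linear_classifier(train_data, task)
    X_test, y_test = get_data_for_linear_classifier(test_data, task,
                                                    shuffle=False)
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, y_train.ravel())
    # Mean accuracy of the probe on the held-out split.
    return clf.score(X_test, y_test.ravel())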
def test(model, dataloader, args):
    model.eval()
    epoch_stats = EpochStats()
    for batch_idx, batch in enumerate(tqdm(dataloader, desc="Batch")):
        test_batch = prepare_batch_for_task(batch, args.task, train=False)
        test_batch = test_batch.to(args.device)

        with torch.no_grad():
            test_logit = model(test_batch)

        if args.task == "gc":
            loss = F.cross_entropy(test_logit, test_batch.y)
            with torch.no_grad():
                acc = ut.get_accuracy(test_logit, test_batch.y)
        elif args.task == "nc":
            node_labels = test_batch.node_y.argmax(1)
            train_mask = test_batch.train_mask.squeeze()
            test_mask = (train_mask == 0).float()
            loss = F.cross_entropy(test_logit[test_mask == 1],
                                   node_labels[test_mask == 1])
            with torch.no_grad():
                acc = ut.get_accuracy(test_logit[test_mask == 1],
                                      node_labels[test_mask == 1])
        elif args.task == "lp":
            test_link_labels = data_utils.get_link_labels(
                test_batch.pos_edge_index, test_batch.neg_edge_index)
            loss = F.binary_cross_entropy_with_logits(
                test_logit.squeeze(), test_link_labels)
            with torch.no_grad():
                test_labels = test_link_labels.detach().cpu().numpy()
                test_predictions = test_logit.detach().cpu().numpy()
                acc = roc_auc_score(test_labels, test_predictions.squeeze())

        epoch_stats.update(args.task, test_batch, loss, acc, False)

    task_test_stats = epoch_stats.get_average_stats()
    bl_ut.print_test_stats(task_test_stats)
    return task_test_stats
def train_baseline_nn_output_model(output_model, dataloader, output_task, epochs,
                                   lr, early_stopping=False, es_tmpdir=None,
                                   val_dataloader=None, device="cpu"):
    output_model.train()
    optimizer = torch.optim.Adam(output_model.parameters(), lr=lr)

    if early_stopping:
        best_val_score = 0
        if not es_tmpdir:
            es_tmpdir = "emb_to_" + output_task + "_bst_early_stopping_tmp"

    for epoch in trange(epochs, desc="Epoch"):
        epoch_stats = EpochStats()
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Train Batch")):
            optimizer.zero_grad()
            batch = prepare_batch_for_task(batch, output_task, train=True)
            batch = batch.to(device)

            # Forward pass
            if output_task == "gc":
                train_logit = output_model(batch.node_embeddings, batch.batch)
            elif output_task == "nc":
                train_logit = output_model(batch.node_embeddings)
            elif output_task == "lp":
                train_logit = output_model(batch.node_embeddings,
                                           batch.pos_edge_index,
                                           batch.neg_edge_index)

            # Evaluate Loss and Accuracy
            if output_task == "gc":
                loss = F.cross_entropy(train_logit, batch.y)
                with torch.no_grad():
                    acc = ut.get_accuracy(train_logit, batch.y)
            elif output_task == "nc":
                node_labels = batch.node_y.argmax(1)
                train_mask = batch.train_mask.squeeze()
                loss = F.cross_entropy(train_logit[train_mask == 1],
                                       node_labels[train_mask == 1])
                with torch.no_grad():
                    acc = ut.get_accuracy(train_logit[train_mask == 1],
                                          node_labels[train_mask == 1])
            elif output_task == "lp":
                train_link_labels = data_utils.get_link_labels(
                    batch.pos_edge_index, batch.neg_edge_index)
                loss = F.binary_cross_entropy_with_logits(train_logit.squeeze(),
                                                          train_link_labels)
                with torch.no_grad():
                    train_labels = train_link_labels.detach().cpu().numpy()
                    train_predictions = train_logit.detach().cpu().numpy()
                    acc = roc_auc_score(train_labels,
                                        train_predictions.squeeze())

            epoch_stats.update(output_task, batch, loss, acc, True)

            # Backprop and update parameters
            loss.backward()
            optimizer.step()

        if early_stopping and epoch > 5 and epoch % 5 == 0:
            model_copy = copy.deepcopy(output_model)
            tqdm.write("\nTest on Validation Set")
            val_stats = eval_baseline_nn_output_model(model_copy, val_dataloader,
                                                      output_task, device=device)
            epoch_acc = val_stats[output_task]["acc"]
            if epoch_acc > best_val_score:
                best_val_score = epoch_acc
                model_copy.to("cpu")
                # Ad-hoc namespace so the validation stats get saved alongside
                # the checkpoint.
                args = type('', (), {})()
                args.early_stopping_stats = val_stats
                args.early_stopping_epoch_acc = epoch_acc
                args.early_stopping_epoch = epoch
                ut.save_model(model_copy, es_tmpdir, "best_val", args)

        task_epoch_stats = epoch_stats.get_average_stats()
        bl_ut.print_train_epoch_stats(epoch, task_epoch_stats)

    if early_stopping:
        ut.recover_early_stopping_best_weights(output_model, es_tmpdir)
def train(model, dataloader, args, val_dataloader=False):
    model.train()

    if args.weight_unc:
        # Learned per-task log-variances for uncertainty-based loss weighting.
        log_var_nc = torch.zeros((1, ), requires_grad=True, device=args.device)
        log_var_gc = torch.zeros((1, ), requires_grad=True, device=args.device)
        log_var_lp = torch.zeros((1, ), requires_grad=True, device=args.device)
        log_vars = {"nc": log_var_nc, "gc": log_var_gc, "lp": log_var_lp}
        p_list = [param for param in model.parameters()
                  ] + [log_var_nc, log_var_gc, log_var_lp]
        optimizer = torch.optim.Adam(p_list, lr=args.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    if args.early_stopping:
        best_val_score = 0
        if not args.es_tmpdir:
            args.es_tmpdir = "bmt_early_stopping_tmp"

    for epoch in trange(args.epochs, desc="Epoch"):
        epoch_stats = EpochStats()
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Batch")):
            optimizer.zero_grad()
            _, train_batch, _ = concurrent_multi_task_train_test_split(
                batch, True, tasks=args.tasks)
            train_batch = train_batch[0]
            train_batch = train_batch.to(args.device)

            # Forward pass
            gc_train_logit, nc_train_logit, lp_train_logit = model(train_batch)

            # Evaluate Loss and Accuracy
            gc_loss = nc_loss = lp_loss = 0

            # GC
            if "gc" in args.tasks:
                gc_loss = F.cross_entropy(gc_train_logit, train_batch.y)
                with torch.no_grad():
                    gc_acc = ut.get_accuracy(gc_train_logit, train_batch.y)
                epoch_stats.update("gc", train_batch, gc_loss, gc_acc, True)

            # NC
            if "nc" in args.tasks:
                node_labels = train_batch.node_y.argmax(1)
                train_mask = train_batch.train_mask.squeeze()
                nc_loss = F.cross_entropy(nc_train_logit[train_mask == 1],
                                          node_labels[train_mask == 1])
                with torch.no_grad():
                    nc_acc = ut.get_accuracy(nc_train_logit[train_mask == 1],
                                             node_labels[train_mask == 1])
                epoch_stats.update("nc", train_batch, nc_loss, nc_acc, True)

            # LP
            if "lp" in args.tasks:
                train_link_labels = data_utils.get_link_labels(
                    train_batch.pos_edge_index, train_batch.neg_edge_index)
                lp_loss = F.binary_cross_entropy_with_logits(
                    lp_train_logit.squeeze(), train_link_labels)
                with torch.no_grad():
                    train_labels = train_link_labels.detach().cpu().numpy()
                    train_predictions = lp_train_logit.detach().cpu().numpy()
                    lp_acc = roc_auc_score(train_labels,
                                           train_predictions.squeeze())
                epoch_stats.update("lp", train_batch, lp_loss, lp_acc, True)

            if args.weight_unc:
                gc_precision = torch.exp(
                    -log_vars["gc"]) if "gc" in args.tasks else 0
                nc_precision = torch.exp(
                    -log_vars["nc"]) if "nc" in args.tasks else 0
                lp_precision = torch.exp(
                    -log_vars["lp"]) if "lp" in args.tasks else 0
                loss = torch.sum(gc_precision * gc_loss + log_vars["gc"], -1) + \
                    torch.sum(nc_precision * nc_loss + log_vars["nc"], -1) + \
                    torch.sum(lp_precision * lp_loss + log_vars["lp"], -1)
            else:
                loss = gc_loss + nc_loss + lp_loss

            # Backprop and update parameters
            loss.backward()
            optimizer.step()

        if args.early_stopping and epoch % 10 == 0:
            model_copy = copy.deepcopy(model)
            tqdm.write("\nTest on Validation Set")
            val_stats = test(model_copy, val_dataloader, args)
            tot_acc = 0
            for task in val_stats:
                tot_acc += val_stats[task]["acc"]
            if tot_acc > best_val_score:
                best_val_score = tot_acc
                model_copy.to("cpu")
                args.early_stopping_stats = val_stats
                args.early_stopping_tot_acc = tot_acc
                args.early_stopping_epoch = epoch
                ut.save_model(model_copy, args.es_tmpdir, "best_val", args)

        tasks_epoch_stats = epoch_stats.get_average_stats()
        bl_ut.print_train_epoch_stats(epoch, tasks_epoch_stats)

    if args.early_stopping:
        ut.recover_early_stopping_best_weights(model, args.es_tmpdir)
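# Standalone sketch of the uncertainty-based task weighting used above when
# args.weight_unc is set: each task loss L_t is scaled by exp(-s_t) and the
# learned log-variance s_t is added as a regularizer, i.e.
# loss = sum_t exp(-s_t) * L_t + s_t. The helper below is illustrative only and
# is not part of the original code.
def weighted_multitask_loss(task_losses, log_vars):
    """task_losses / log_vars: dicts keyed by task name, e.g. "gc", "nc", "lp"."""
    total = 0
    for t, task_loss in task_losses.items():
        precision = torch.exp(-log_vars[t])
        total = total + torch.sum(precision * task_loss + log_vars[t], -1)
    return total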
def adapt_and_test(model, task, train_batch, test_batch, args, log_vars=None):
    """Adapt model on train_batch, and test it on test_batch. Returns
    statistics, inner loss, and outer loss (loss on test_batch with adapted
    parameters) that can be used for the global update (outer loop)."""
    train_logit = model(train_batch, task_selector=task)

    if task == "gc":
        train_targets = train_batch.y
        test_targets = test_batch.y
        inner_loss = F.cross_entropy(train_logit, train_targets)
        if log_vars and (args.weight_unc == 1):
            precision = torch.exp(-log_vars[task])
            inner_loss = torch.sum(precision * inner_loss + log_vars[task], -1)
            if log_vars[task].grad:
                log_vars[task].grad.zero_()
        model.zero_grad()
        adapted_params = update_parameters_gd(model, inner_loss,
                                              step_size=args.step_size,
                                              first_order=args.first_order)
        test_logit = model(test_batch, task_selector=task,
                           params=adapted_params)
        outer_loss = F.cross_entropy(test_logit, test_targets)
        with torch.no_grad():
            test_acc = ut.get_accuracy(test_logit, test_targets)

    elif task == "nc":
        node_labels = train_batch.node_y.argmax(1)
        train_mask = train_batch.train_mask.squeeze()
        test_mask = (train_mask == 0).float()
        inner_loss = F.cross_entropy(train_logit[train_mask == 1],
                                     node_labels[train_mask == 1])
        if log_vars and (args.weight_unc == 1):
            precision = torch.exp(-log_vars[task])
            inner_loss = torch.sum(precision * inner_loss + log_vars[task], -1)
            if log_vars[task].grad:
                log_vars[task].grad.zero_()
        model.zero_grad()
        adapted_params = update_parameters_gd(model, inner_loss,
                                              step_size=args.step_size,
                                              first_order=args.first_order)
        test_logit = model(train_batch, task_selector=task,
                           params=adapted_params)
        outer_loss = F.cross_entropy(test_logit[test_mask == 1],
                                     node_labels[test_mask == 1])
        with torch.no_grad():
            test_acc = ut.get_accuracy(test_logit[test_mask == 1],
                                       node_labels[test_mask == 1])

    elif task == "lp":
        train_link_labels = data_utils.get_link_labels(
            train_batch.pos_edge_index, train_batch.neg_edge_index)
        test_link_labels = data_utils.get_link_labels(
            test_batch.pos_edge_index, test_batch.neg_edge_index)
        inner_loss = F.binary_cross_entropy_with_logits(
            train_logit.squeeze(), train_link_labels)
        if log_vars and (args.weight_unc == 1):
            precision = torch.exp(-log_vars[task])
            inner_loss = torch.sum(precision * inner_loss + log_vars[task], -1)
            if log_vars[task].grad:
                log_vars[task].grad.zero_()
        model.zero_grad()
        adapted_params = update_parameters_gd(model, inner_loss,
                                              step_size=args.step_size,
                                              first_order=args.first_order)
        test_logit = model(test_batch, task_selector=task,
                           params=adapted_params)
        outer_loss = F.binary_cross_entropy_with_logits(
            test_logit.squeeze(), test_link_labels)
        with torch.no_grad():
            #test_logit = torch.sigmoid(test_logit)
            test_logit = test_logit.detach().cpu().numpy()
            test_link_labels = test_link_labels.detach().cpu().numpy()
            try:
                test_acc = torch.tensor(
                    roc_auc_score(test_link_labels, test_logit.squeeze()))
            except ValueError:
                print("Problem in AUC")
                print("Test Logit: {},\n Test Link Labels: {}".format(
                    test_logit, test_link_labels))
                test_acc = torch.tensor(0.0)

    elif isinstance(task, list):
        # we are in the concurrent case
        inner_loss = {}
        if "gc" in task:
            gc_logit = train_logit["gc"]
            gc_train_targets = train_batch.y
            gc_test_targets = test_batch.y
            inner_loss["gc"] = F.cross_entropy(gc_logit, gc_train_targets)
        if "nc" in task:
            nc_logit = train_logit["nc"]
            train_node_labels = train_batch.node_y.argmax(1)
            nc_train_mask = train_batch.train_mask.squeeze()
            test_node_labels = test_batch.node_y.argmax(1)
            nc_test_mask = (test_batch.train_mask.squeeze() == 0).float()
            inner_loss["nc"] = F.cross_entropy(
                nc_logit[nc_train_mask == 1],
                train_node_labels[nc_train_mask == 1])
        if "lp" in task:
            lp_logit = train_logit["lp"]
            train_link_labels = data_utils.get_link_labels(
                train_batch.pos_edge_index, train_batch.neg_edge_index)
            test_link_labels = data_utils.get_link_labels(
                test_batch.pos_edge_index, test_batch.neg_edge_index)
            inner_loss["lp"] = F.binary_cross_entropy_with_logits(
                lp_logit.squeeze(), train_link_labels)

        inner_sum = torch.tensor(0.).to(args.device)
        if log_vars and (args.weight_unc == 1):
            for t in task:
                precision = torch.exp(-log_vars[t])
                inner_sum += torch.sum(precision * inner_loss[t] + log_vars[t],
                                       -1)
                if log_vars[t].grad:
                    log_vars[t].grad.zero_()
        else:
            for t in task:
                inner_sum += inner_loss[t]

        model.zero_grad()
        adapted_params = update_parameters_gd(model, inner_sum,
                                              step_size=args.step_size,
                                              first_order=args.first_order)
        test_logit = model(test_batch, task_selector=task,
                           params=adapted_params)

        outer_loss = {}
        if "gc" in task:
            gc_test_logit = test_logit["gc"]
            outer_loss["gc"] = F.cross_entropy(gc_test_logit, gc_test_targets)
        if "nc" in task:
            nc_test_logit = test_logit["nc"]
            outer_loss["nc"] = F.cross_entropy(
                nc_test_logit[nc_test_mask == 1],
                test_node_labels[nc_test_mask == 1])
        if "lp" in task:
            lp_test_logit = test_logit["lp"]
            outer_loss["lp"] = F.binary_cross_entropy_with_logits(
                lp_test_logit.squeeze(), test_link_labels)

        test_acc = {}
        with torch.no_grad():
            if "gc" in task:
                test_acc["gc"] = ut.get_accuracy(gc_test_logit,
                                                 gc_test_targets)
            if "nc" in task:
                test_acc["nc"] = ut.get_accuracy(
                    nc_test_logit[nc_test_mask == 1],
                    test_node_labels[nc_test_mask == 1])
            if "lp" in task:
                lp_test_logit = lp_test_logit.detach().cpu().numpy()
                test_link_labels = test_link_labels.detach().cpu().numpy()
                try:
                    test_acc["lp"] = torch.tensor(
                        roc_auc_score(test_link_labels,
                                      lp_test_logit.squeeze()))
                except ValueError:
                    print("Problem in AUC")
                    print("Test Logit: {},\n Test Link Labels: {}".format(
                        lp_test_logit, test_link_labels))
                    test_acc["lp"] = torch.tensor(0.0)

    return outer_loss, inner_loss, test_acc
def train(model, dataloader, args, val_dataloader=None):
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    if args.early_stopping:
        best_val_score = 0
        if not args.es_tmpdir:
            args.es_tmpdir = args.task + "_bst_early_stopping_tmp"

    for epoch in trange(args.epochs, desc="Epoch"):
        epoch_stats = EpochStats()
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Batch")):
            optimizer.zero_grad()
            train_batch = prepare_batch_for_task(batch, args.task, train=True)
            train_batch = train_batch.to(args.device)

            # Forward pass
            train_logit = model(train_batch)

            # Evaluate Loss and Accuracy
            if args.task == "gc":
                loss = F.cross_entropy(train_logit, train_batch.y)
                with torch.no_grad():
                    acc = ut.get_accuracy(train_logit, train_batch.y)
            elif args.task == "nc":
                node_labels = train_batch.node_y.argmax(1)
                train_mask = train_batch.train_mask.squeeze()
                loss = F.cross_entropy(train_logit[train_mask == 1],
                                       node_labels[train_mask == 1])
                with torch.no_grad():
                    acc = ut.get_accuracy(train_logit[train_mask == 1],
                                          node_labels[train_mask == 1])
            elif args.task == "lp":
                train_link_labels = data_utils.get_link_labels(
                    train_batch.pos_edge_index, train_batch.neg_edge_index)
                loss = F.binary_cross_entropy_with_logits(
                    train_logit.squeeze(), train_link_labels)
                with torch.no_grad():
                    train_labels = train_link_labels.detach().cpu().numpy()
                    train_predictions = train_logit.detach().cpu().numpy()
                    try:
                        acc = roc_auc_score(train_labels,
                                            train_predictions.squeeze())
                    except ValueError:
                        acc = 0.0

            epoch_stats.update(args.task, train_batch, loss, acc, True)

            # Backprop and update parameters
            loss.backward()
            optimizer.step()

        if args.early_stopping and epoch % 10 == 0:
            model_copy = copy.deepcopy(model)
            tqdm.write("\nTest on Validation Set")
            val_stats = test(model_copy, val_dataloader, args)
            epoch_acc = val_stats[args.task]["acc"]
            if epoch_acc > best_val_score:
                best_val_score = epoch_acc
                model_copy.to("cpu")
                args.early_stopping_stats = val_stats
                args.early_stopping_epoch_acc = epoch_acc
                args.early_stopping_epoch = epoch
                ut.save_model(model_copy, args.es_tmpdir, "best_val", args)

        task_epoch_stats = epoch_stats.get_average_stats()
        bl_ut.print_train_epoch_stats(epoch, task_epoch_stats)

    if args.early_stopping:
        ut.recover_early_stopping_best_weights(model, args.es_tmpdir)
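# Illustrative driver (not part of the original code) tying the single-task
# baseline train/test functions together; the dataloaders and args object are
# assumed to come from the surrounding experiment setup.
def run_single_task_baseline(model, train_loader, val_loader, test_loader, args):
    train(model, train_loader, args, val_dataloader=val_loader)
    return test(model, test_loader, args)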