Example #1
def test(model, dataloader, args):
    model.eval()
    epoch_stats = EpochStats()
    for batch_idx, batch in enumerate(tqdm(dataloader, desc="Batch")):
        _, test_batch, _ = concurrent_multi_task_train_test_split(
            batch, False, tasks=args.tasks)
        test_batch = test_batch[0]
        test_batch = test_batch.to(args.device)
        with torch.no_grad():
            gc_test_logit, nc_test_logit, lp_test_logit = model(test_batch)
            # GC
            if "gc" in args.tasks:
                gc_loss = F.cross_entropy(gc_test_logit, test_batch.y)
                gc_acc = ut.get_accuracy(gc_test_logit, test_batch.y)
                epoch_stats.update("gc", test_batch, gc_loss, gc_acc, False)
            # NC
            if "nc" in args.tasks:
                node_labels = test_batch.node_y.argmax(1)
                train_mask = test_batch.train_mask.squeeze()
                test_mask = (train_mask == 0).float()
                nc_loss = F.cross_entropy(nc_test_logit[test_mask == 1],
                                          node_labels[test_mask == 1])
                nc_acc = ut.get_accuracy(nc_test_logit[test_mask == 1],
                                         node_labels[test_mask == 1])
                epoch_stats.update("nc", test_batch, nc_loss, nc_acc, False)
            # LP
            if "lp" in args.tasks:
                test_link_labels = data_utils.get_link_labels(
                    test_batch.pos_edge_index, test_batch.neg_edge_index)
                lp_loss = F.binary_cross_entropy_with_logits(
                    lp_test_logit.squeeze(), test_link_labels)
                test_labels = test_link_labels.detach().cpu().numpy()
                test_predictions = lp_test_logit.detach().cpu().numpy()
                lp_acc = roc_auc_score(test_labels,
                                       test_predictions.squeeze())
                epoch_stats.update("lp", test_batch, lp_loss, lp_acc, False)

    tasks_test_stats = epoch_stats.get_average_stats()
    bl_ut.print_test_stats(tasks_test_stats)
    return tasks_test_stats
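
For context, a sketch of the imports these snippets rely on. The standard-library and third-party imports follow directly from the calls used above; the project-local modules (`ut`, `bl_ut`, `data_utils`, and the `EpochStats` class) are inferred from usage, so their module names are assumptions:

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from tqdm import tqdm, trange

# Project-local helpers, inferred from usage (module names are assumptions):
# import utils as ut               # ut.get_accuracy, ut.save_model, ...
# import baseline_utils as bl_ut   # bl_ut.print_test_stats, ...
# import data_utils                # data_utils.get_link_labels, ...
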
Example #2
def eval_baseline_nn_output_model(output_model, dataloader, output_task, device="cpu"):
    output_model.eval()
    epoch_stats = EpochStats()
    for batch_idx, batch in enumerate(tqdm(dataloader, desc="Eval Batch")):
        batch = prepare_batch_for_task(batch, output_task, train=False)
        batch = batch.to(device)
        with torch.no_grad():
            # Forward pass 
            if output_task == "gc":
                test_logit = output_model(batch.node_embeddings, batch.batch)
            elif output_task == "nc":
                test_logit = output_model(batch.node_embeddings)
            elif output_task == "lp":
                test_logit = output_model(batch.node_embeddings, batch.pos_edge_index, batch.neg_edge_index)                

            # Evaluate Loss and Accuracy
            if output_task == "gc":
                loss = F.cross_entropy(test_logit, batch.y)
                acc = ut.get_accuracy(test_logit, batch.y)
            elif output_task == "nc":
                node_labels = batch.node_y.argmax(1)
                train_mask = batch.train_mask.squeeze()
                test_mask = (train_mask == 0).float()
                loss = F.cross_entropy(test_logit[test_mask == 1], node_labels[test_mask == 1])
                acc = ut.get_accuracy(test_logit[test_mask == 1], node_labels[test_mask == 1])
            elif output_task == "lp":
                test_link_labels = data_utils.get_link_labels(batch.pos_edge_index, batch.neg_edge_index)
                loss = F.binary_cross_entropy_with_logits(test_logit.squeeze(), test_link_labels)
                test_labels = test_link_labels.detach().cpu().numpy()
                test_predictions = test_logit.detach().cpu().numpy()
                acc = roc_auc_score(test_labels, test_predictions.squeeze())

            epoch_stats.update(output_task, batch, loss, acc, False)

    task_test_stats = epoch_stats.get_average_stats()
    bl_ut.print_test_stats(task_test_stats)
    return task_test_stats
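
A hypothetical invocation, assuming a link-prediction head and a loader of batches carrying precomputed `node_embeddings` (all names below are placeholders):

# Hypothetical usage; lp_head and test_loader are placeholders.
stats = eval_baseline_nn_output_model(lp_head, test_loader,
                                      output_task="lp", device="cuda")
print(stats["lp"]["acc"])  # for "lp" the "acc" entry is a ROC-AUC
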
Example #3
def get_data_for_linear_classifier(data, task, shuffle=True):
    X = []
    y = []
    for d in data:
        node_embeddings = d.node_embeddings.detach().cpu().numpy()
        if task == "gc":
            X.append(node_embeddings.mean(axis=0))   
            y.append(d.y.detach().cpu().numpy())
        elif task == "nc":
            X.append(node_embeddings)
            node_labels = d.node_y.argmax(1).detach().cpu().numpy() 
            node_labels = np.expand_dims(node_labels, axis=1)
            y.append(node_labels) 
        elif task == "lp":
            train_data_list, test_data_list = data_utils.prepare_data_for_link_prediction(
                [d], train_ratio=0.9, neg_to_pos_edge_ratio=1, rnd_labeled_edges=False)

            pos_edge_idx = test_data_list[0].pos_edge_index.detach().cpu().numpy()
            neg_edge_idx = test_data_list[0].neg_edge_index.detach().cpu().numpy()
            lp_labels = data_utils.get_link_labels(
                test_data_list[0].pos_edge_index,
                test_data_list[0].neg_edge_index).detach().cpu().numpy()
 
            node_a = np.take(node_embeddings, np.concatenate((pos_edge_idx[0], neg_edge_idx[0])), axis=0)
            node_b = np.take(node_embeddings, np.concatenate((pos_edge_idx[1], neg_edge_idx[1])), axis=0)
            X.append(np.concatenate((node_a, node_b), axis=1))
            y.append(np.expand_dims(lp_labels, axis=1))
    
    X = np.vstack(X)
    y = np.vstack(y)
    if shuffle:
        perm = np.random.permutation(X.shape[0])
        X = X[perm]
        y = y[perm]
    return X, y
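
The returned `(X, y)` pair is shaped for an off-the-shelf linear model; a minimal sketch with scikit-learn (the split and classifier settings are assumptions, not part of the snippet):

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = get_data_for_linear_classifier(data, task="nc")
X_tr, X_te, y_tr, y_te = train_test_split(X, y.ravel(), test_size=0.2)
clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
print("test accuracy:", clf.score(X_te, y_te))
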
Example #4
def test(model, dataloader, args):
    model.eval()
    epoch_stats = EpochStats()
    for batch_idx, batch in enumerate(tqdm(dataloader, desc="Batch")):
        test_batch = prepare_batch_for_task(batch, args.task, train=False)
        test_batch = test_batch.to(args.device)
        with torch.no_grad():
            test_logit = model(test_batch)
            if args.task == "gc":
                loss = F.cross_entropy(test_logit, test_batch.y)
                acc = ut.get_accuracy(test_logit, test_batch.y)
            elif args.task == "nc":
                node_labels = test_batch.node_y.argmax(1)
                train_mask = test_batch.train_mask.squeeze()
                test_mask = (train_mask == 0).float()
                loss = F.cross_entropy(test_logit[test_mask == 1],
                                       node_labels[test_mask == 1])
                acc = ut.get_accuracy(test_logit[test_mask == 1],
                                      node_labels[test_mask == 1])
            elif args.task == "lp":
                test_link_labels = data_utils.get_link_labels(
                    test_batch.pos_edge_index, test_batch.neg_edge_index)
                loss = F.binary_cross_entropy_with_logits(
                    test_logit.squeeze(), test_link_labels)
                test_labels = test_link_labels.detach().cpu().numpy()
                test_predictions = test_logit.detach().cpu().numpy()
                acc = roc_auc_score(test_labels, test_predictions.squeeze())

            epoch_stats.update(args.task, test_batch, loss, acc, False)

    task_test_stats = epoch_stats.get_average_stats()
    bl_ut.print_test_stats(task_test_stats)
    return task_test_stats
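
All of the test functions above share the same node-classification masking convention: the evaluation nodes are simply the complement of the training mask. A tiny self-contained illustration of the boolean arithmetic:

import torch

train_mask = torch.tensor([1., 1., 0., 0., 1.])
test_mask = (train_mask == 0).float()      # -> [0., 0., 1., 1., 0.]
# indexing with `mask == 1` selects the held-out nodes
print(torch.arange(5)[test_mask == 1])     # tensor([2, 3])
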
Example #5
def train_baseline_nn_output_model(output_model, dataloader, output_task, epochs, lr, early_stopping=False, es_tmpdir=None, val_dataloader=None, device="cpu"):
    output_model.train()
    optimizer = torch.optim.Adam(output_model.parameters(), lr=lr)

    if early_stopping:
        best_val_score = 0
        if not es_tmpdir:
            es_tmpdir = "emb_to_" + output_task + "_bst_early_stopping_tmp"
    for epoch in trange(epochs, desc="Epoch"):
        epoch_stats = EpochStats()
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Train Batch")):
            optimizer.zero_grad()

            batch = prepare_batch_for_task(batch, output_task, train=True)
            batch = batch.to(device)

            # Forward pass 
            if output_task == "gc":
                train_logit = output_model(batch.node_embeddings, batch.batch)
            elif output_task == "nc":
                train_logit = output_model(batch.node_embeddings)
            elif output_task == "lp":
                train_logit = output_model(batch.node_embeddings, batch.pos_edge_index, batch.neg_edge_index)                

            # Evaluate Loss and Accuracy
            if output_task == "gc":
                loss = F.cross_entropy(train_logit, batch.y)
                with torch.no_grad():
                    acc = ut.get_accuracy(train_logit, batch.y)
            elif output_task == "nc":
                node_labels = batch.node_y.argmax(1)
                train_mask = batch.train_mask.squeeze()
                loss = F.cross_entropy(train_logit[train_mask == 1], node_labels[train_mask == 1])
                with torch.no_grad():
                    acc = ut.get_accuracy(train_logit[train_mask == 1], node_labels[train_mask == 1])
            elif output_task == "lp":
                train_link_labels = data_utils.get_link_labels(batch.pos_edge_index, batch.neg_edge_index)
                loss = F.binary_cross_entropy_with_logits(train_logit.squeeze(), train_link_labels)
                with torch.no_grad():
                    train_labels = train_link_labels.detach().cpu().numpy()
                    train_predictions = train_logit.detach().cpu().numpy()
                    acc = roc_auc_score(train_labels, train_predictions.squeeze())

            epoch_stats.update(output_task, batch, loss, acc, True)
            
            # Backprop and update parameters
            loss.backward()
            optimizer.step()
            
        if early_stopping and epoch > 5 and epoch % 5 == 0:
            model_copy = copy.deepcopy(output_model)
            tqdm.write("\nTest on Validation Set")
            val_stats = eval_baseline_nn_output_model(model_copy, val_dataloader, output_task, device=device)
            epoch_acc = val_stats[output_task]["acc"]
            if epoch_acc > best_val_score:
                best_val_score = epoch_acc
                model_copy.to("cpu")
                args = type('', (), {})()  # ad-hoc empty namespace for save_model metadata
                args.early_stopping_stats = val_stats  # so the validation stats get saved to file
                args.early_stopping_epoch_acc = epoch_acc
                args.early_stopping_epoch = epoch
                ut.save_model(model_copy, es_tmpdir, "best_val", args)

        task_epoch_stats = epoch_stats.get_average_stats()
        bl_ut.print_train_epoch_stats(epoch, task_epoch_stats)

    if early_stopping:
        ut.recover_early_stopping_best_weights(output_model, es_tmpdir)
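
A hypothetical call, training a node-classification head on precomputed embeddings with early stopping (the head and loaders are placeholders):

# Hypothetical usage; nc_head, train_loader and val_loader are placeholders.
train_baseline_nn_output_model(nc_head, train_loader, output_task="nc",
                               epochs=100, lr=1e-3,
                               early_stopping=True,
                               val_dataloader=val_loader,
                               device="cuda")
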
Example #6
def train(model, dataloader, args, val_dataloader=None):
    model.train()
    if args.weight_unc:
        log_var_nc = torch.zeros((1, ), requires_grad=True, device=args.device)
        log_var_gc = torch.zeros((1, ), requires_grad=True, device=args.device)
        log_var_lp = torch.zeros((1, ), requires_grad=True, device=args.device)
        log_vars = {"nc": log_var_nc, "gc": log_var_gc, "lp": log_var_lp}
        p_list = [param for param in model.parameters()
                  ] + [log_var_nc, log_var_gc, log_var_lp]
        optimizer = torch.optim.Adam(p_list, lr=args.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    if args.early_stopping:
        best_val_score = 0
        if not args.es_tmpdir:
            args.es_tmpdir = "bmt_early_stopping_tmp"
    for epoch in trange(args.epochs, desc="Epoch"):
        epoch_stats = EpochStats()
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Batch")):
            optimizer.zero_grad()

            _, train_batch, _ = concurrent_multi_task_train_test_split(
                batch, True, tasks=args.tasks)
            train_batch = train_batch[0]
            train_batch = train_batch.to(args.device)

            # Forward pass
            gc_train_logit, nc_train_logit, lp_train_logit = model(train_batch)

            # Evaluate Loss and Accuracy
            # GC
            gc_loss = nc_loss = lp_loss = 0
            if "gc" in args.tasks:
                gc_loss = F.cross_entropy(gc_train_logit, train_batch.y)
                with torch.no_grad():
                    gc_acc = ut.get_accuracy(gc_train_logit, train_batch.y)
                epoch_stats.update("gc", train_batch, gc_loss, gc_acc, True)
            # NC
            if "nc" in args.tasks:
                node_labels = train_batch.node_y.argmax(1)
                train_mask = train_batch.train_mask.squeeze()
                nc_loss = F.cross_entropy(nc_train_logit[train_mask == 1],
                                          node_labels[train_mask == 1])
                with torch.no_grad():
                    nc_acc = ut.get_accuracy(nc_train_logit[train_mask == 1],
                                             node_labels[train_mask == 1])
                epoch_stats.update("nc", train_batch, nc_loss, nc_acc, True)
            # LP
            if "lp" in args.tasks:
                train_link_labels = data_utils.get_link_labels(
                    train_batch.pos_edge_index, train_batch.neg_edge_index)
                lp_loss = F.binary_cross_entropy_with_logits(
                    lp_train_logit.squeeze(), train_link_labels)
                with torch.no_grad():
                    train_labels = train_link_labels.detach().cpu().numpy()
                    train_predictions = lp_train_logit.detach().cpu().numpy()
                    lp_acc = roc_auc_score(train_labels,
                                           train_predictions.squeeze())
                epoch_stats.update("lp", train_batch, lp_loss, lp_acc, True)

            if args.weight_unc:
                gc_precision = torch.exp(
                    -log_vars["gc"]) if "gc" in args.tasks else 0
                nc_precision = torch.exp(
                    -log_vars["nc"]) if "nc" in args.tasks else 0
                lp_precision = torch.exp(
                    -log_vars["lp"]) if "lp" in args.tasks else 0
                loss = torch.sum(gc_precision * gc_loss + log_vars["gc"], -1) + \
                       torch.sum(nc_precision * nc_loss + log_vars["nc"], -1) + \
                       torch.sum(lp_precision * lp_loss + log_vars["lp"], -1)
            else:
                loss = gc_loss + nc_loss + lp_loss

            # Backprop and update parameters
            loss.backward()
            optimizer.step()

        if args.early_stopping and epoch % 10 == 0:
            model_copy = copy.deepcopy(model)
            tqdm.write("\nTest on Validation Set")
            val_stats = test(model_copy, val_dataloader, args)
            tot_acc = 0
            for task in val_stats:
                tot_acc += val_stats[task]["acc"]
            if tot_acc > best_val_score:
                best_val_score = tot_acc
                model_copy.to("cpu")
                args.early_stopping_stats = val_stats
                args.early_stopping_tot_acc = tot_acc
                args.early_stopping_epoch = epoch
                ut.save_model(model_copy, args.es_tmpdir, "best_val", args)

        tasks_epoch_stats = epoch_stats.get_average_stats()
        bl_ut.print_train_epoch_stats(epoch, tasks_epoch_stats)

    if args.early_stopping:
        ut.recover_early_stopping_best_weights(model, args.es_tmpdir)
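
The `weight_unc` branch appears to follow the homoscedastic-uncertainty multi-task weighting of Kendall et al. (2018) in its common simplified form: with a learned scalar s_t = log sigma_t^2 per task, each task contributes exp(-s_t) * L_t + s_t, so the optimizer can down-weight noisy tasks while the additive s_t term keeps the learned variances from growing without bound. A self-contained sketch of just that combination:

import torch

def uncertainty_weighted_loss(task_losses, log_vars):
    """task_losses and log_vars are dicts keyed by task name ("gc", "nc", "lp")."""
    total = torch.zeros(1)
    for t, task_loss in task_losses.items():
        precision = torch.exp(-log_vars[t])          # 1 / sigma_t^2
        total = total + precision * task_loss + log_vars[t]
    return total.sum()
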
Example #7
File: train.py  Project: lilleswing/SAME
def adapt_and_test(model, task, train_batch, test_batch, args, log_vars=None):
    """Adapt model on train_batch, and test it on test_batch. Returns statistics, inner loss,
    and outer loss (loss on test_batch with adapted parameters) that can be used for global 
    update (outer loop)."""
    train_logit = model(train_batch, task_selector=task)
    if task == "gc":
        train_targets = train_batch.y
        test_targets = test_batch.y

        inner_loss = F.cross_entropy(train_logit, train_targets)
        if log_vars and (args.weight_unc == 1):
            precision = torch.exp(-log_vars[task])
            inner_loss = torch.sum(precision * inner_loss + log_vars[task], -1)
            if log_vars[task].grad is not None:
                log_vars[task].grad.zero_()
        model.zero_grad()
        adapted_params = update_parameters_gd(model,
                                              inner_loss,
                                              step_size=args.step_size,
                                              first_order=args.first_order)

        test_logit = model(test_batch,
                           task_selector=task,
                           params=adapted_params)
        outer_loss = F.cross_entropy(test_logit, test_targets)
        with torch.no_grad():
            test_acc = ut.get_accuracy(test_logit, test_targets)
    elif task == "nc":
        node_labels = train_batch.node_y.argmax(1)
        train_mask = train_batch.train_mask.squeeze()
        test_mask = (train_mask == 0).float()

        inner_loss = F.cross_entropy(train_logit[train_mask == 1],
                                     node_labels[train_mask == 1])
        if log_vars and (args.weight_unc == 1):
            precision = torch.exp(-log_vars[task])
            inner_loss = torch.sum(precision * inner_loss + log_vars[task], -1)
            if log_vars[task].grad is not None:
                log_vars[task].grad.zero_()
        model.zero_grad()
        adapted_params = update_parameters_gd(model,
                                              inner_loss,
                                              step_size=args.step_size,
                                              first_order=args.first_order)

        # NC evaluates on the same graph: the held-out nodes live in
        # train_batch, so the adapted forward pass reuses train_batch.
        test_logit = model(train_batch,
                           task_selector=task,
                           params=adapted_params)
        outer_loss = F.cross_entropy(test_logit[test_mask == 1],
                                     node_labels[test_mask == 1])
        with torch.no_grad():
            test_acc = ut.get_accuracy(test_logit[test_mask == 1],
                                       node_labels[test_mask == 1])
    elif task == "lp":
        train_link_labels = data_utils.get_link_labels(
            train_batch.pos_edge_index, train_batch.neg_edge_index)
        test_link_labels = data_utils.get_link_labels(
            test_batch.pos_edge_index, test_batch.neg_edge_index)

        inner_loss = F.binary_cross_entropy_with_logits(
            train_logit.squeeze(), train_link_labels)
        if log_vars and (args.weight_unc == 1):
            precision = torch.exp(-log_vars[task])
            inner_loss = torch.sum(precision * inner_loss + log_vars[task], -1)
            if log_vars[task].grad is not None:
                log_vars[task].grad.zero_()

        model.zero_grad()
        adapted_params = update_parameters_gd(model,
                                              inner_loss,
                                              step_size=args.step_size,
                                              first_order=args.first_order)

        test_logit = model(test_batch,
                           task_selector=task,
                           params=adapted_params)
        outer_loss = F.binary_cross_entropy_with_logits(
            test_logit.squeeze(), test_link_labels)
        with torch.no_grad():
            # no sigmoid needed: ROC-AUC is rank-based and sigmoid is monotonic
            test_logit = test_logit.detach().cpu().numpy()
            test_link_labels = test_link_labels.detach().cpu().numpy()
            try:
                test_acc = torch.tensor(
                    roc_auc_score(test_link_labels, test_logit.squeeze()))
            except ValueError:
                print("Problem in AUC")
                print("Test Logit: {},\n Test Link Labels: {}".format(
                    test_logit, test_link_labels))
                test_acc = torch.tensor(0.0)
    elif isinstance(task, list):  # we are in the concurrent case
        inner_loss = {}
        if "gc" in task:
            gc_logit = train_logit["gc"]
            gc_train_targets = train_batch.y
            gc_test_targets = test_batch.y
            inner_loss["gc"] = F.cross_entropy(gc_logit, gc_train_targets)
        if "nc" in task:
            nc_logit = train_logit["nc"]
            train_node_labels = train_batch.node_y.argmax(1)
            nc_train_mask = train_batch.train_mask.squeeze()
            test_node_labels = test_batch.node_y.argmax(1)
            nc_test_mask = (test_batch.train_mask.squeeze() == 0).float()
            inner_loss["nc"] = F.cross_entropy(
                nc_logit[nc_train_mask == 1],
                train_node_labels[nc_train_mask == 1])
        if "lp" in task:
            lp_logit = train_logit["lp"]
            train_link_labels = data_utils.get_link_labels(
                train_batch.pos_edge_index, train_batch.neg_edge_index)
            test_link_labels = data_utils.get_link_labels(
                test_batch.pos_edge_index, test_batch.neg_edge_index)
            inner_loss["lp"] = F.binary_cross_entropy_with_logits(
                lp_logit.squeeze(), train_link_labels)

        inner_sum = torch.tensor(0.).to(args.device)
        if log_vars and (args.weight_unc == 1):
            for t in task:
                precision = torch.exp(-log_vars[t])
                inner_sum += torch.sum(precision * inner_loss[t] + log_vars[t],
                                       -1)
                if log_vars[t].grad is not None:
                    log_vars[t].grad.zero_()
        else:
            for t in task:
                inner_sum += inner_loss[t]

        model.zero_grad()
        adapted_params = update_parameters_gd(model,
                                              inner_sum,
                                              step_size=args.step_size,
                                              first_order=args.first_order)

        test_logit = model(test_batch,
                           task_selector=task,
                           params=adapted_params)

        outer_loss = {}
        if "gc" in task:
            gc_test_logit = test_logit["gc"]
            outer_loss["gc"] = F.cross_entropy(gc_test_logit, gc_test_targets)
        if "nc" in task:
            nc_test_logit = test_logit["nc"]
            outer_loss["nc"] = F.cross_entropy(
                nc_test_logit[nc_test_mask == 1],
                test_node_labels[nc_test_mask == 1])
        if "lp" in task:
            lp_test_logit = test_logit["lp"]
            outer_loss["lp"] = F.binary_cross_entropy_with_logits(
                lp_test_logit.squeeze(), test_link_labels)

        test_acc = {}
        with torch.no_grad():
            if "gc" in task:
                test_acc["gc"] = ut.get_accuracy(gc_test_logit,
                                                 gc_test_targets)
            if "nc" in task:
                test_acc["nc"] = ut.get_accuracy(
                    nc_test_logit[nc_test_mask == 1],
                    test_node_labels[nc_test_mask == 1])
            if "lp" in task:
                lp_test_logit = lp_test_logit.detach().cpu().numpy()
                test_link_labels = test_link_labels.detach().cpu().numpy()
                try:
                    test_acc["lp"] = torch.tensor(
                        roc_auc_score(test_link_labels,
                                      lp_test_logit.squeeze()))
                except ValueError:
                    print("Problem in AUC")
                    print("Test Logit: {},\n Test Link Labels: {}".format(
                        lp_test_logit, test_link_labels))
                    test_acc["lp"] = torch.tensor(0.0)

    return outer_loss, inner_loss, test_acc
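
`update_parameters_gd` is a project helper. Judging from its call sites (a differentiable inner-loop step whose output is fed back to the model via `params=`), it plausibly performs one SGD step over the model parameters in the style of MAML. A minimal sketch under that assumption; the real implementation may differ:

from collections import OrderedDict
import torch

def update_parameters_gd(model, loss, step_size=0.1, first_order=False):
    """One inner-loop gradient step; keep the graph unless first_order is set."""
    names, params = zip(*model.named_parameters())
    grads = torch.autograd.grad(loss, params, create_graph=not first_order)
    return OrderedDict((name, p - step_size * g)
                       for name, p, g in zip(names, params, grads))
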
Example #8
def train(model, dataloader, args, val_dataloader=None):
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    if args.early_stopping:
        best_val_score = 0
        if not args.es_tmpdir:
            args.es_tmpdir = args.task + "_bst_early_stopping_tmp"
    for epoch in trange(args.epochs, desc="Epoch"):
        epoch_stats = EpochStats()
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Batch")):
            optimizer.zero_grad()

            train_batch = prepare_batch_for_task(batch, args.task, train=True)
            train_batch = train_batch.to(args.device)

            # Forward pass
            train_logit = model(train_batch)

            # Evaluate Loss and Accuracy
            if args.task == "gc":
                loss = F.cross_entropy(train_logit, train_batch.y)
                with torch.no_grad():
                    acc = ut.get_accuracy(train_logit, train_batch.y)
            elif args.task == "nc":
                node_labels = train_batch.node_y.argmax(1)
                train_mask = train_batch.train_mask.squeeze()
                loss = F.cross_entropy(train_logit[train_mask == 1],
                                       node_labels[train_mask == 1])
                with torch.no_grad():
                    acc = ut.get_accuracy(train_logit[train_mask == 1],
                                          node_labels[train_mask == 1])
            elif args.task == "lp":
                train_link_labels = data_utils.get_link_labels(
                    train_batch.pos_edge_index, train_batch.neg_edge_index)
                loss = F.binary_cross_entropy_with_logits(
                    train_logit.squeeze(), train_link_labels)
                with torch.no_grad():
                    train_labels = train_link_labels.detach().cpu().numpy()
                    train_predictions = train_logit.detach().cpu().numpy()
                    try:
                        acc = roc_auc_score(train_labels,
                                            train_predictions.squeeze())
                    except ValueError:
                        acc = 0.0  # AUC is undefined when the batch has only one class

            epoch_stats.update(args.task, train_batch, loss, acc, True)

            # Backprop and update parameters
            loss.backward()
            optimizer.step()

        if args.early_stopping and epoch % 10 == 0:
            model_copy = copy.deepcopy(model)
            tqdm.write("\nTest on Validation Set")
            val_stats = test(model_copy, val_dataloader, args)
            epoch_acc = val_stats[args.task]["acc"]
            if epoch_acc > best_val_score:
                best_val_score = epoch_acc
                model_copy.to("cpu")
                args.early_stopping_stats = val_stats
                args.early_stopping_epoch_acc = epoch_acc
                args.early_stopping_epoch = epoch
                ut.save_model(model_copy, args.es_tmpdir, "best_val", args)

        task_epoch_stats = epoch_stats.get_average_stats()
        bl_ut.print_train_epoch_stats(epoch, task_epoch_stats)

    if args.early_stopping:
        ut.recover_early_stopping_best_weights(model, args.es_tmpdir)
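
Finally, a hypothetical driver for this trainer; the `args` fields are taken from the function body, while the model and loaders are placeholders:

from types import SimpleNamespace

# Hypothetical setup; model and dataloader construction are omitted.
args = SimpleNamespace(task="gc", lr=1e-3, epochs=200, device="cuda",
                       early_stopping=True, es_tmpdir=None)
train(model, train_loader, args, val_dataloader=val_loader)
test(model, test_loader, args)
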