Example #1
0
def train(args):
    def _get_dataloader(datasubset,
                        tokenizer,
                        device,
                        args,
                        subset_classes=True):
        """
        Build a stratified episode dataloader for a dataset subset.

        Args:
            datasubset: dataset split to sample episodes from.
            tokenizer: tokenizer used to encode the sampled text.
            device: device the produced batches should live on.
            args: hyperparameter dict ('k', 'max_classes', 'max_batch_size').
            subset_classes: when True, additionally subsample the class set.

        Returns:
            A StratifiedLoaderwClassesSubset or StratifiedLoader instance.
        """
        common = dict(k=args['k'],
                      max_batch_size=args['max_batch_size'],
                      tokenizer=tokenizer,
                      device=device,
                      shuffle=True,
                      verbose=False)

        if subset_classes:
            return StratifiedLoaderwClassesSubset(
                datasubset, max_classes=args['max_classes'], **common)

        return StratifiedLoader(datasubset, **common)

    def _adapt_and_fit(support_labels,
                       support_input,
                       query_labels,
                       query_input,
                       loss_fn,
                       model_init,
                       args,
                       mode="train"):
        """
        Adapts the init model to a support set and computes loss on query set.

        Args:
            support_labels ([type]): [description]
            support_text ([type]): [description]
            query_labels ([type]): [description]
            query_text ([type]): [description]
            model_init ([type]): [description]
            args
            mode
        """

        #####################
        # Create model_task #
        #####################
        if (not args['dropout']) and mode == "train":
            for module in model_init.modules():
                if isinstance(module, nn.Dropout):
                    module.eval()
                else:
                    module.train()
        elif mode != "train":
            model_init.eval()
        else:
            model_init.train()

        model_task = deepcopy(model_init)

        for name, param in model_task.encoder.model.named_parameters():
            transformer_layer = re.search("(?:encoder\.layer\.)([0-9]+)", name)
            if transformer_layer and (int(transformer_layer.group(1)) >
                                      args['inner_nu']):
                param.requires_grad = True
            elif 'pooler' in name:
                param.requires_grad = False
            elif args['inner_nu'] < 0:
                param.requires_grad = True
            else:
                param.requires_grad = False

        model_task_optimizer = optim.SGD(model_task.parameters(),
                                         lr=args['inner_lr'])
        model_task.zero_grad()

        #######################
        # Generate prototypes #
        #######################

        labs = torch.sort(torch.unique(support_labels))[0]

        if (not args['kill_prototypes']):

            y = model_init(support_input)

            prototypes = torch.stack(
                [torch.mean(y[support_labels == c], dim=0) for c in labs])

            W_init = 2 * prototypes
            b_init = -torch.norm(prototypes, p=2, dim=1)**2

        else:

            W_init = torch.empty(
                (labs.size()[0],
                 model_init.out_dim)).to(model_task.get_device())
            nn.init.kaiming_normal_(W_init)

            b_init = torch.zeros((labs.size()[0])).to(model_task.get_device())

        W_task, b_task = W_init.detach(), b_init.detach()
        W_task.requires_grad, b_task.requires_grad = True, True

        #################
        # Adapt to data #
        #################
        for _ in range(args['n_inner']):

            y = model_task(support_input)
            logits = F.linear(y, W_task, b_task)

            inner_loss = loss_fn(logits, support_labels)

            W_task_grad, b_task_grad = torch.autograd.grad(inner_loss,\
                [W_task, b_task], retain_graph=True)

            inner_loss.backward()

            if args['clip_val'] > 0:
                torch.nn.utils.clip_grad_norm_(model_task.parameters(),
                                               args['clip_val'])

            model_task_optimizer.step()

            W_task = W_task - args['output_lr'] * W_task_grad
            b_task = b_task - args['output_lr'] * b_task_grad

            if args['print_inner_loss']:
                print(f"\tInner Loss: {inner_loss.detach().cpu().item()}")

        #########################
        # Validate on query set #
        #########################
        if mode == "train":
            for module in model_task.modules():
                if isinstance(module, nn.Dropout):
                    module.eval()

            W_task = W_init + (W_task - W_init).detach()
            b_task = b_init + (b_task - b_init).detach()

        y = model_task(query_input)
        logits = F.linear(y, W_task, b_task)

        outer_loss = loss_fn(logits, query_labels)

        if mode == "train":
            model_task_params = [
                param for param in model_task.parameters()
                if param.requires_grad
            ]
            model_task_grads = torch.autograd.grad(outer_loss,
                                                   model_task_params,
                                                   retain_graph=True)

            model_init_params = [
                param for param in model_init.parameters()
                if param.requires_grad
            ]

            model_init_grads = torch.autograd.grad(outer_loss,
                                                   model_init_params,
                                                   retain_graph=False,
                                                   allow_unused=True)

            model_init_grads = model_init_grads + model_task_grads

            for param, grad in zip(model_init_params, model_init_grads):
                if param.grad != None and grad != None:
                    param.grad += grad.detach()
                elif grad != None:
                    param.grad = grad.detach()
                else:
                    param.grad = None
        else:
            del model_task, W_task, b_task, W_task_grad, b_task_grad, W_init, b_init

        if outer_loss.detach().cpu().item() > 10:
            print(outer_loss.detach().cpu().item(),
                  inner_loss.detach().cpu().item())

        return logits.detach(), outer_loss.detach()

    #######################
    # Logging Directories #
    #######################
    log_dir = os.path.join(args['checkpoint_path'], args['version'])

    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(os.path.join(log_dir, 'tensorboard'), exist_ok=True)
    os.makedirs(os.path.join(log_dir, 'checkpoint'), exist_ok=True)
    #print(f"Saving models and logs to {log_dir}")

    checkpoint_save_path = os.path.join(log_dir, 'checkpoint')

    # Persist the hyperparameter dict next to the checkpoints so the run can
    # be reloaded later (eval() reads this file back).
    with open(os.path.join(log_dir, 'checkpoint', 'hparams.pickle'),
              'wb') as file:
        pickle.dump(args, file)

    ##########################
    # Device, Logging, Timer #
    ##########################

    set_seed(args['seed'])

    timer = Timer()

    device = torch.device('cuda' if (
        torch.cuda.is_available() and args['gpu']) else 'cpu')

    # Build the tensorboard writer
    writer = SummaryWriter(os.path.join(log_dir, 'tensorboard'))

    ###################
    # Load in dataset #
    ###################
    print("Data Prep")
    dataset = meta_dataset(include=args['include'], verbose=True)
    dataset.prep(text_tokenizer=manual_tokenizer)
    print("")

    ####################
    # Init models etc. #
    ####################
    model_init = SeqTransformer(args)
    tokenizer = AutoTokenizer.from_pretrained(args['encoder_name'])

    # Register extra special tokens and grow the embedding matrix to match
    # the enlarged vocabulary.
    tokenizer.add_special_tokens({'additional_special_tokens': specials()})
    model_init.encoder.model.resize_token_embeddings(len(tokenizer.vocab))

    # NOTE(review): no fallback branch — an unrecognised args['optimizer']
    # leaves meta_optimizer undefined and fails below.
    if args['optimizer'] == "Adam":
        meta_optimizer = optim.Adam(model_init.parameters(),
                                    lr=args['meta_lr'])
    elif args['optimizer'] == "SGD":
        meta_optimizer = optim.SGD(model_init.parameters(), lr=args['meta_lr'])

    # Constant LR after linear warmup, additionally reduced when the scaled
    # validation accuracy plateaus (mode='max').
    meta_scheduler = get_constant_schedule_with_warmup(meta_optimizer,
                                                       args['warmup_steps'])
    reduceOnPlateau = optim.lr_scheduler.ReduceLROnPlateau(
        meta_optimizer,
        mode='max',
        factor=args['lr_reduce_factor'],
        patience=args['patience'],
        verbose=True)

    model_init = model_init.to(device)

    loss_fn = nn.CrossEntropyLoss()

    #################
    # Training loop #
    #################

    # Best chance-scaled macro accuracy seen so far (checkpoint criterion).
    best_overall_acc_s = 0.0

    for episode in range(1, args['max_episodes'] + 1):

        # Running means over this episode's n_outer tasks (raw and
        # chance-scaled variants).
        outer_loss_agg, acc_agg, f1_agg = 0.0, 0.0, 0.0
        outer_loss_s_agg, acc_s_agg, f1_s_agg = 0.0, 0.0, 0.0

        for ii in range(1, args['n_outer'] + 1):
            #################
            # Sample a task #
            #################
            task = dataset_sampler(dataset, sampling_method='sqrt')

            datasubset = dataset.datasets[task]['train']

            dataloader = _get_dataloader(datasubset,
                                         tokenizer,
                                         device,
                                         args,
                                         subset_classes=args['subset_classes'])

            support_labels, support_input, query_labels, query_input = next(
                dataloader)

            # Side effect: accumulates this task's meta-gradient into
            # model_init's .grad fields (mode="train").
            logits, outer_loss = _adapt_and_fit(support_labels,
                                                support_input,
                                                query_labels,
                                                query_input,
                                                loss_fn,
                                                model_init,
                                                args,
                                                mode="train")

            ######################
            # Inner Loop Logging #
            ######################
            with torch.no_grad():
                mets = logging_metrics(logits.detach().cpu(),
                                       query_labels.detach().cpu())
                outer_loss_ = outer_loss.detach().cpu().item()
                acc = mets['acc']
                f1 = mets['f1']

                # Scale by task difficulty: log(n_classes) is the loss of a
                # uniform predictor, 1/n_classes the chance accuracy.
                outer_loss_s = outer_loss_ / np.log(dataloader.n_classes)
                acc_s = acc / (1 / dataloader.n_classes)
                f1_s = f1 / (1 / dataloader.n_classes)

                outer_loss_agg += outer_loss_ / args['n_outer']
                acc_agg += acc / args['n_outer']
                f1_agg += f1 / args['n_outer']

                outer_loss_s_agg += outer_loss_s / args['n_outer']
                acc_s_agg += acc_s / args['n_outer']
                f1_s_agg += f1_s / args['n_outer']

            print(
                "{:} | Train | Episode {:04}.{:02} | Task {:^20s}, N={:} | Loss {:5.2f}, Acc {:5.2f}, F1 {:5.2f} | Mem {:5.2f} GB"
                .format(
                    timer.dt(), episode, ii, task, dataloader.n_classes,
                    outer_loss_s if args['print_scaled'] else outer_loss_,
                    acc_s if args['print_scaled'] else acc,
                    f1_s if args['print_scaled'] else f1,
                    psutil.Process(os.getpid()).memory_info().rss / 1024**3))

            writer.add_scalars('Loss/Train', {task: outer_loss_}, episode)
            writer.add_scalars('Accuracy/Train', {task: acc}, episode)
            writer.add_scalars('F1/Train', {task: f1}, episode)

            writer.add_scalars('LossScaled/Train', {task: outer_loss_s},
                               episode)
            writer.add_scalars('AccuracyScaled/Train', {task: acc_s}, episode)
            writer.add_scalars('F1Scaled/Train', {task: f1_s}, episode)

            writer.flush()

        ############################
        # Init Model Backward Pass #
        ############################
        model_init_params = [
            param for param in model_init.parameters() if param.requires_grad
        ]
        #for param in model_init_params:
        #    param.grad = param.grad #/ args['n_outer']

        if args['clip_val'] > 0:
            torch.nn.utils.clip_grad_norm_(model_init_params, args['clip_val'])

        meta_optimizer.step()
        meta_scheduler.step()

        # NOTE(review): gradients are only cleared once warmup is reached, so
        # they accumulate across episodes during warmup — confirm intended.
        if args['warmup_steps'] <= episode + 1:
            meta_optimizer.zero_grad()

        #####################
        # Aggregate Logging #
        #####################
        print(
            "{:} | MACRO-AGG | Train | Episode {:04} | Loss {:5.2f}, Acc {:5.2f}, F1 {:5.2f}\n"
            .format(
                timer.dt(), episode,
                outer_loss_s_agg if args['print_scaled'] else outer_loss_agg,
                acc_s_agg if args['print_scaled'] else acc_agg,
                f1_s_agg if args['print_scaled'] else f1_agg))

        writer.add_scalar('Loss/MacroTrain', outer_loss_agg, episode)
        writer.add_scalar('Accuracy/MacroTrain', acc_agg, episode)
        writer.add_scalar('F1/MacroTrain', f1_agg, episode)

        writer.add_scalar('LossScaled/MacroTrain', outer_loss_s_agg, episode)
        writer.add_scalar('AccuracyScaled/MacroTrain', acc_s_agg, episode)
        writer.add_scalar('F1Scaled/MacroTrain', f1_s_agg, episode)

        writer.flush()

        ##############
        # Evaluation #
        ##############
        if (episode % args['eval_every_n']) == 0 or episode == 1:

            overall_loss, overall_acc, overall_f1 = [], [], []
            overall_loss_s, overall_acc_s, overall_f1_s = [], [], []
            ###################
            # Individual Task #
            ###################
            for task in dataset.lens.keys():
                datasubset = dataset.datasets[task]['validation']

                task_loss, task_acc, task_f1 = [], [], []
                task_loss_s, task_acc_s, task_f1_s = [], [], []
                for _ in range(args['n_eval_per_task']):

                    dataloader = _get_dataloader(
                        datasubset,
                        tokenizer,
                        device,
                        args,
                        subset_classes=args['subset_classes'])
                    support_labels, support_input, query_labels, query_input = next(
                        dataloader)

                    logits, loss = _adapt_and_fit(support_labels,
                                                  support_input,
                                                  query_labels,
                                                  query_input,
                                                  loss_fn,
                                                  model_init,
                                                  args,
                                                  mode="eval")

                    mets = logging_metrics(logits.detach().cpu(),
                                           query_labels.detach().cpu())

                    task_loss.append(loss.detach().cpu().item())
                    task_acc.append(mets['acc'])
                    task_f1.append(mets['f1'])

                    # Same chance-scaling as in the training loop.
                    task_loss_s.append(loss.detach().cpu().item() /
                                       np.log(dataloader.n_classes))
                    task_acc_s.append(mets['acc'] / (1 / dataloader.n_classes))
                    task_f1_s.append(mets['f1'] / (1 / dataloader.n_classes))

                overall_loss.append(np.mean(task_loss))
                overall_acc.append(np.mean(task_acc))
                overall_f1.append(np.mean(task_f1))

                overall_loss_s.append(np.mean(task_loss_s))
                overall_acc_s.append(np.mean(task_acc_s))
                overall_f1_s.append(np.mean(task_f1_s))

                print(
                    "{:} | Eval  | Episode {:04} | Task {:^20s} | Loss {:5.2f}, Acc {:5.2f}, F1 {:5.2f} | Mem {:5.2f} GB"
                    .format(
                        timer.dt(), episode, task, overall_loss_s[-1]
                        if args['print_scaled'] else overall_loss[-1],
                        overall_acc_s[-1] if args['print_scaled'] else
                        overall_acc[-1], overall_f1_s[-1]
                        if args['print_scaled'] else overall_f1[-1],
                        psutil.Process(os.getpid()).memory_info().rss /
                        1024**3))

                writer.add_scalars('Loss/Eval', {task: overall_loss[-1]},
                                   episode)
                writer.add_scalars('Accuracy/Eval', {task: overall_acc[-1]},
                                   episode)
                writer.add_scalars('F1/Eval', {task: overall_f1[-1]}, episode)

                writer.add_scalars('LossScaled/Eval',
                                   {task: overall_loss_s[-1]}, episode)
                writer.add_scalars('AccuracyScaled/Eval',
                                   {task: overall_acc_s[-1]}, episode)
                writer.add_scalars('F1Scaled/Eval', {task: overall_f1_s[-1]},
                                   episode)

                writer.flush()

            #######################
            # All Tasks Aggregate #
            #######################
            # Per-task lists collapse to scalars from here on.
            overall_loss = np.mean(overall_loss)
            overall_acc = np.mean(overall_acc)
            overall_f1 = np.mean(overall_f1)

            overall_loss_s = np.mean(overall_loss_s)
            overall_acc_s = np.mean(overall_acc_s)
            overall_f1_s = np.mean(overall_f1_s)

            print(
                "{:} | MACRO-AGG | Eval  | Episode {:04} | Loss {:5.2f}, Acc {:5.2f}, F1 {:5.2f}\n"
                .format(
                    timer.dt(), episode,
                    overall_loss_s if args['print_scaled'] else overall_loss,
                    overall_acc_s if args['print_scaled'] else overall_acc,
                    overall_f1_s if args['print_scaled'] else overall_f1))

            writer.add_scalar('Loss/MacroEval', overall_loss, episode)
            writer.add_scalar('Accuracy/MacroEval', overall_acc, episode)
            writer.add_scalar('F1/MacroEval', overall_f1, episode)

            writer.add_scalar('LossScaled/MacroEval', overall_loss_s, episode)
            writer.add_scalar('AccuracyScaled/MacroEval', overall_acc_s,
                              episode)
            writer.add_scalar('F1Scaled/MacroEval', overall_f1_s, episode)

            writer.flush()

            #####################
            # Best Model Saving #
            #####################
            if overall_acc_s >= best_overall_acc_s:
                # Drop the previous best checkpoint before writing a new one.
                # NOTE(review): `ep` is not None-checked and
                # `float(ep.group(1))` is a truthiness test, so a stale file
                # recording accuracy 0.00 would be kept — confirm intent.
                for file in os.listdir(checkpoint_save_path):
                    if 'best_model' in file:
                        ep = re.match(r".+macroaccs_\[(.+)\]", file)
                        if float(ep.group(1)):
                            os.remove(os.path.join(checkpoint_save_path, file))

                save_name = "best_model-episode_[{:}]-macroaccs_[{:.2f}].checkpoint".format(
                    episode, overall_acc_s)

                with open(os.path.join(checkpoint_save_path, save_name),
                          'wb') as f:

                    torch.save(model_init.state_dict(), f)

                print(
                    f"New best scaled accuracy. Saving model as {save_name}\n")
                best_overall_acc_s = overall_acc_s
                curr_patience = args['patience']
            else:
                # curr_patience is tracked but early stopping below is driven
                # by the LR dropping under min_meta_lr, not by this counter.
                if episode > args['min_episodes']:
                    curr_patience -= 1
                #print(f"Model did not improve with macroaccs_={overall_acc_s}. Patience is now {curr_patience}\n")

            #######################
            # Latest Model Saving #
            #######################
            # Keep only the newest "latest" checkpoint.
            for file in os.listdir(checkpoint_save_path):
                if 'latest_model' in file:
                    ep = re.match(r".+episode_\[([a-zA-Z0-9\.]+)\].+", file)
                    if ep != None and int(ep.group(1)) <= episode:
                        os.remove(os.path.join(checkpoint_save_path, file))

            save_name = "latest_model-episode_[{:}]-macroaccs_[{:.2f}].checkpoint".format(
                episode, overall_acc_s)

            with open(os.path.join(checkpoint_save_path, save_name),
                      'wb') as f:

                torch.save(model_init.state_dict(), f)

            # Trainer state for resuming (episode counter + accuracies).
            with open(
                    os.path.join(checkpoint_save_path,
                                 "latest_trainer.pickle"), 'wb') as f:

                pickle.dump(
                    {
                        'episode': episode,
                        'overall_acc_s': overall_acc_s,
                        'best_overall_acc_s': best_overall_acc_s
                    }, f)

            if episode >= args['min_episodes']:
                reduceOnPlateau.step(overall_acc_s)

                # Early stopping: once plateau reductions push the LR below
                # min_meta_lr, abort via KeyboardInterrupt (presumably caught
                # by the caller — confirm).
                curr_lr = meta_optimizer.param_groups[0]['lr']
                if curr_lr < args['min_meta_lr']:
                    print("Patience spent.\nEarly stopping.")
                    raise KeyboardInterrupt

        writer.add_scalar('Meta-lr', meta_optimizer.param_groups[0]['lr'],
                          episode)
def eval(args):
    def _get_dataloader(datasubset,
                        tokenizer,
                        device,
                        args,
                        subset_classes=True):
        """
        Construct a stratified episode loader over ``datasubset``.

        Args:
            datasubset: dataset split episodes are drawn from.
            tokenizer: tokenizer for encoding the sampled text.
            device: target device for the produced batches.
            args: hyperparameter dict ('k', 'max_classes', 'max_batch_size').
            subset_classes: sample only a subset of the classes when True.

        Returns:
            dataloader instance.
        """
        shared_kwargs = {
            'k': args['k'],
            'max_batch_size': args['max_batch_size'],
            'tokenizer': tokenizer,
            'device': device,
            'shuffle': True,
            'verbose': False,
        }

        if subset_classes:
            return StratifiedLoaderwClassesSubset(
                datasubset,
                max_classes=args['max_classes'],
                **shared_kwargs)

        return StratifiedLoader(datasubset, **shared_kwargs)

    def _adapt_and_fit(support_labels_list, support_input_list,
                       query_labels_list, query_input_list, loss_fn,
                       model_init, args, mode):
        """
        Adapts the init model to a support set and computes loss on query set.

        Evaluation-time variant: the support/query sets arrive as lists of
        sub-batches (to respect max_batch_size) and no meta-gradient
        bookkeeping is performed.

        Args:
            support_labels_list: list of support-label sub-batches.
            support_input_list: list of tokenized support sub-batches.
            query_labels_list: list of query-label sub-batches.
            query_input_list: list of tokenized query sub-batches.
            loss_fn: criterion applied to the classifier logits.
            model_init: trained model to copy and adapt.
            args: hyperparameter dict.
            mode: "baseline" routes inputs through model.encode(); anything
                else calls the model directly.

        Returns:
            (logits_list, outer_loss_list): per-query-sub-batch logits and
            losses.
        """

        #####################
        # Create model_task #
        #####################
        model_init.eval()

        model_task = deepcopy(model_init)
        model_task_optimizer = optim.SGD(model_task.parameters(),
                                         lr=args['inner_lr'])
        model_task.zero_grad()

        #######################
        # Generate prototypes #
        #######################

        with torch.no_grad():
            # Average class prototypes over all support sub-batches.
            # NOTE(review): assumes every sub-batch contains the same classes
            # so the stacked rows align across iterations — confirm the
            # loader guarantees this.
            prototypes = 0.0
            for support_labels, support_input in zip(support_labels_list,
                                                     support_input_list):
                if mode != "baseline":
                    y = model_init(support_input)
                else:
                    y = model_init.encode(support_input)

                labs = torch.sort(torch.unique(support_labels))[0]
                prototypes += torch.stack(
                    [torch.mean(y[support_labels == c], dim=0) for c in labs])

            prototypes = prototypes / len(support_labels_list)

            # ProtoMAML-style output-layer init from the prototypes.
            W_init = 2 * prototypes
            b_init = -torch.norm(prototypes, p=2, dim=1)**2

        W_task, b_task = W_init.detach(), b_init.detach()
        W_task.requires_grad, b_task.requires_grad = True, True

        #################
        # Adapt to data #
        #################
        for _ in range(args['n_inner']):
            for support_labels, support_input in zip(support_labels_list,
                                                     support_input_list):
                if mode != "baseline":
                    y = model_task(support_input)
                else:
                    y = model_task.encode(support_input)

                logits = F.linear(y, W_task, b_task)

                inner_loss = loss_fn(logits, support_labels)

                # Manual SGD step on the output layer; the optimizer updates
                # the encoder via inner_loss.backward().
                W_task_grad, b_task_grad = torch.autograd.grad(
                    inner_loss, [W_task, b_task], retain_graph=True)

                inner_loss.backward()

                if args['clip_val'] > 0:
                    torch.nn.utils.clip_grad_norm_(model_task.parameters(),
                                                   args['clip_val'])

                model_task_optimizer.step()

                W_task = W_task - args['output_lr'] * W_task_grad
                b_task = b_task - args['output_lr'] * b_task_grad

        #########################
        # Validate on query set #
        #########################
        logits_list, outer_loss_list = [], []
        for query_labels, query_input in zip(query_labels_list,
                                             query_input_list):
            with torch.no_grad():
                if mode != "baseline":
                    y = model_task(query_input)
                else:
                    y = model_task.encode(query_input)

                logits = F.linear(y, W_task, b_task)

                outer_loss = loss_fn(logits, query_labels)

                logits_list.append(logits)
                outer_loss_list.append(outer_loss)

        return logits_list, outer_loss_list

    #######################
    # Logging Directories #
    #######################
    log_dir = os.path.join(args['checkpoint_path'], args['version'])

    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(os.path.join(log_dir, args['save_version']), exist_ok=True)
    os.makedirs(os.path.join(log_dir, 'checkpoint'), exist_ok=True)
    #print(f"Saving models and logs to {log_dir}")

    checkpoint_save_path = os.path.join(log_dir, 'checkpoint')

    if args['mode'] != "baseline":
        with open(os.path.join("./", checkpoint_save_path, "hparams.pickle"),
                  mode='rb+') as f:
            hparams = pickle.load(f)
    else:
        with open(os.path.join("./", args['checkpoint_path'],
                               "hparams.pickle"),
                  mode='rb+') as f:
            hparams = pickle.load(f)

    ##########################
    # Device, Logging, Timer #
    ##########################

    set_seed(args['seed'])

    timer = Timer()

    device = torch.device('cuda' if (
        torch.cuda.is_available() and args['gpu']) else 'cpu')

    # Build the tensorboard writer
    writer = SummaryWriter(os.path.join(log_dir, args['save_version']))

    ###################
    # Load in dataset #
    ###################
    print("Data Prep")
    dataset = meta_dataset(include=args['include'], verbose=True)
    dataset.prep(text_tokenizer=manual_tokenizer)
    print("")

    ####################
    # Init models etc. #
    ####################
    if args['mode'] != "baseline":
        model_init = SeqTransformer(hparams)
        tokenizer = AutoTokenizer.from_pretrained(hparams['encoder_name'])
    else:
        model_init = CustomBERT(num_classes=task_label_dict[args['version']])
        tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    tokenizer.add_special_tokens({'additional_special_tokens': specials()})
    model_init.encoder.model.resize_token_embeddings(len(tokenizer.vocab))

    for file in os.listdir(checkpoint_save_path):
        if 'best_model' in file:
            fp = os.path.join(checkpoint_save_path, file)
            with open(fp, mode='rb+') as f:
                print(f"Found pre-trained file at {fp}")
                if args['mode'] != "baseline":
                    model_init.load_state_dict(
                        torch.load(f, map_location=device))

                    for name, param in model_init.encoder.model.named_parameters(
                    ):
                        transformer_layer = re.search(
                            "(?:encoder\.layer\.)([0-9]+)", name)
                        if transformer_layer and (int(
                                transformer_layer.group(1)) > args['nu']):
                            param.requires_grad = True
                        elif 'pooler' in name:
                            param.requires_grad = False
                        elif args['nu'] < 0:
                            param.requires_grad = True
                        else:
                            param.requires_grad = False
                else:
                    model_init.load_state_dict(
                        torch.load(f, map_location=device)["bert_state_dict"])

    model_init = model_init.to(device)

    loss_fn = nn.CrossEntropyLoss()

    ##############
    # Evaluation #
    ##############

    results_dict = defaultdict(dict)

    for split in args['splits']:

        # Accumulates one mean per task; the "_s" variants are chance-scaled:
        # loss divided by log(n_classes), acc/f1 divided by the 1/n_classes
        # chance rate (see the task_*_s appends below).
        # NOTE(review): this span is the tail of an enclosing eval routine whose
        # signature lies outside this chunk; `dataset`, `split`, `args`,
        # `tokenizer`, `device`, `timer`, `writer`, `results_dict`, `hparams`,
        # `loss_fn`, `model_init` and `logging_metrics` come from outer scope.
        overall_loss, overall_acc, overall_f1 = [], [], []
        overall_loss_s, overall_acc_s, overall_f1_s = [], [], []

        ###################
        # Individual Task #
        ###################
        for task in dataset.lens.keys():
            datasubset = dataset.datasets[task][split]

            # Per-task metric samples, one entry per evaluated sub-batch.
            task_loss, task_acc, task_f1 = [], [], []
            task_loss_s, task_acc_s, task_f1_s = [], [], []
            # Evaluate each task several times on freshly sampled episodes.
            for _ in range(args['n_eval_per_task']):
                dataloader = _get_dataloader(
                    datasubset,
                    tokenizer,
                    device,
                    args,
                    subset_classes=args['subset_classes'])

                # Split the k-shot episode into sub-batches that respect
                # args['max_batch_size']:
                #   total_size    - total support examples (k per class)
                #   n_sub_batches - fractional count of max-size batches
                #   reg_k         - shots per class in each regular sub-batch
                #   left_over     - support examples not covered by the
                #                   regular sub-batches
                #   last_k        - shots per class for one final, smaller
                #                   sub-batch covering the remainder
                total_size = args['k'] * dataloader.n_classes
                n_sub_batches = total_size / args['max_batch_size']
                reg_k = int(args['k'] // n_sub_batches)
                left_over = args['k'] * dataloader.n_classes - \
                    int(n_sub_batches) * reg_k * dataloader.n_classes
                last_k = int(left_over / dataloader.n_classes)


                support_labels_list, support_input_list, query_labels_list, query_input_list = [], [], [], []

                # Draw the regular-size sub-batches. Mutating `dataloader.k`
                # changes how many shots per class the next() call yields.
                dataloader.k = reg_k
                for _ in range(int(n_sub_batches)):

                    support_labels, support_text, query_labels, query_text = next(
                        dataloader)

                    support_labels_list.append(support_labels)
                    support_input_list.append(support_text)
                    query_labels_list.append(query_labels)
                    query_input_list.append(query_text)

                # Draw one smaller sub-batch for the remainder, if any.
                if last_k > 0.0:
                    dataloader.k = last_k
                    support_labels, support_text, query_labels, query_text = next(
                        dataloader)

                    support_labels_list.append(support_labels)
                    support_input_list.append(support_text)
                    query_labels_list.append(query_labels)
                    query_input_list.append(query_text)

                # Adapt the init model on the support sets and score the query
                # sets; presumably returns one (logits, loss) per sub-batch,
                # given the zip below — TODO confirm against _adapt_and_fit.
                logits_list, loss_list = _adapt_and_fit(
                    support_labels_list, support_input_list, query_labels_list,
                    query_input_list, loss_fn, model_init, hparams,
                    args['mode'])

                for logits, query_labels, loss in zip(logits_list,
                                                      query_labels_list,
                                                      loss_list):
                    mets = logging_metrics(logits.detach().cpu(),
                                           query_labels.detach().cpu())

                    # Raw metrics for this sub-batch.
                    task_loss.append(loss.detach().cpu().item())
                    task_acc.append(mets['acc'])
                    task_f1.append(mets['f1'])

                    # Chance-scaled metrics: loss relative to the uniform
                    # baseline log(n_classes); acc/f1 relative to the
                    # 1/n_classes chance rate.
                    task_loss_s.append(loss.detach().cpu().item() /
                                       np.log(dataloader.n_classes))
                    task_acc_s.append(mets['acc'] / (1 / dataloader.n_classes))
                    task_f1_s.append(mets['f1'] / (1 / dataloader.n_classes))

            # Mean over all sub-batches of all eval runs for this task.
            overall_loss.append(np.mean(task_loss))
            overall_acc.append(np.mean(task_acc))
            overall_f1.append(np.mean(task_f1))

            overall_loss_s.append(np.mean(task_loss_s))
            overall_acc_s.append(np.mean(task_acc_s))
            overall_f1_s.append(np.mean(task_f1_s))

            # Console report: "mean (std)" per metric — scaled or raw depending
            # on args['print_scaled'] — plus current resident memory in GB.
            print(
                "{:} | Eval  | Split {:^8s} | Task {:^20s} | Loss {:5.2f} ({:4.2f}), Acc {:5.2f} ({:4.2f}), F1 {:5.2f} ({:4.2f}) | Mem {:5.2f} GB"
                .format(
                    timer.dt(), split, task, overall_loss_s[-1]
                    if args['print_scaled'] else overall_loss[-1],
                    np.std(task_loss_s) if args['print_scaled'] else
                    np.std(task_loss), overall_acc_s[-1]
                    if args['print_scaled'] else overall_acc[-1],
                    np.std(task_acc_s) if args['print_scaled'] else
                    np.std(task_acc), overall_f1_s[-1]
                    if args['print_scaled'] else overall_f1[-1],
                    np.std(task_f1_s)
                    if args['print_scaled'] else np.std(task_f1),
                    psutil.Process(os.getpid()).memory_info().rss / 1024**3))

            # TensorBoard: per-task means under this split, always at step 0
            # (a one-off evaluation snapshot, not a training curve).
            writer.add_scalars(f'Loss/{split}', {task: overall_loss[-1]}, 0)
            writer.add_scalars(f'Accuracy/{split}', {task: overall_acc[-1]}, 0)
            writer.add_scalars(f'F1/{split}', {task: overall_f1[-1]}, 0)

            writer.add_scalars(f'LossScaled/{split}',
                               {task: overall_loss_s[-1]}, 0)
            writer.add_scalars(f'AccuracyScaled/{split}',
                               {task: overall_acc_s[-1]}, 0)
            writer.add_scalars(f'F1Scaled/{split}', {task: overall_f1_s[-1]},
                               0)

            writer.flush()

            # Human-readable "mean (std)" strings for the final results file.
            results_dict[task][split] = {
                "loss":
                "{:.2f} ({:.2f})".format(overall_loss[-1], np.std(task_loss)),
                "acc":
                "{:.2f} ({:.2f})".format(overall_acc[-1], np.std(task_acc)),
                "f1":
                "{:.2f} ({:.2f})".format(overall_f1[-1], np.std(task_f1)),
                "loss_scaled":
                "{:.2f} ({:.2f})".format(overall_loss_s[-1],
                                         np.std(task_loss_s)),
                "acc_scaled":
                "{:.2f} ({:.2f})".format(overall_acc_s[-1],
                                         np.std(task_acc_s)),
                "f1_scaled":
                "{:.2f} ({:.2f})".format(overall_f1_s[-1], np.std(task_f1_s)),
            }

        #######################
        # All Tasks Aggregate #
        #######################
        # Macro-average (unweighted mean over tasks). Note this rebinds the
        # list names to scalars, so the per-task lists are unusable below.
        overall_loss = np.mean(overall_loss)
        overall_acc = np.mean(overall_acc)
        overall_f1 = np.mean(overall_f1)

        overall_loss_s = np.mean(overall_loss_s)
        overall_acc_s = np.mean(overall_acc_s)
        overall_f1_s = np.mean(overall_f1_s)

        print(
            "{:} | MACRO-AGG | Eval  | Split {:^8s} | Loss {:5.2f}, Acc {:5.2f}, F1 {:5.2f}\n"
            .format(timer.dt(), split,
                    overall_loss_s if args['print_scaled'] else overall_loss,
                    overall_acc_s if args['print_scaled'] else overall_acc,
                    overall_f1_s if args['print_scaled'] else overall_f1))

        # TensorBoard: macro aggregates, again at step 0.
        writer.add_scalar(f'Loss/Macro{split}', overall_loss, 0)
        writer.add_scalar(f'Accuracy/Macro{split}', overall_acc, 0)
        writer.add_scalar(f'F1/Macro{split}', overall_f1, 0)

        writer.add_scalar(f'LossScaled/Macro{split}', overall_loss_s, 0)
        writer.add_scalar(f'AccuracyScaled/Macro{split}', overall_acc_s, 0)
        writer.add_scalar(f'F1Scaled/Macro{split}', overall_f1_s, 0)

        writer.flush()

    # Persist the collected per-task evaluation results for later inspection.
    results_path = os.path.join(log_dir, args['save_version'], 'results.pickle')
    with open(results_path, 'wb+') as file:
        pickle.dump(results_dict, file)