def evaluate(
        self, eps_generator: typing.Union[OmniglotLoader, ImageFolderGenerator]
    ) -> typing.List[float]:
        """Measure the accuracy of the loaded model on every listed episode.

        Each episode is split into a support (train) and a query (validation)
        part. The model is adapted through `adapt_and_predict`, the softmax
        outputs of all returned logit tensors are averaged, and the arg-max of
        that average is compared against the query labels.

        Args:
            eps_generator: object whose `generate_episode` returns the data of
                an episode given its name

        Returns:
            the accuracy of each episode (a fraction, not a percentage), in
            the order the episodes were listed
        """
        print('Evaluation is started.\n')

        cfg = self.config

        # restore the model of the requested epoch
        model = self.load_model(resume_epoch=cfg['resume_epoch'],
                                hyper_net_class=self.hyper_net_class,
                                eps_generator=eps_generator)

        # names of the episodes to evaluate on
        episode_names = get_episodes(episode_file_path=cfg['episode_file'])

        accuracies = []
        for episode_idx, episode_name in enumerate(episode_names, start=1):
            episode_data = eps_generator.generate_episode(
                episode_name=episode_name)

            # support/query split of the episode
            xt, yt, xv, yv = train_val_split(X=episode_data,
                                             k_shot=cfg['k_shot'],
                                             shuffle=True)

            # place everything on the configured device
            device = cfg['device']
            x_t = torch.from_numpy(xt).float().to(device)
            y_t = torch.tensor(yt, dtype=torch.long, device=device)
            x_v = torch.from_numpy(xv).float().to(device)
            y_v = torch.tensor(yv, dtype=torch.long, device=device)

            # the query labels are withheld from the model (y_v=None)
            _, logits = self.adapt_and_predict(model=model,
                                               x_t=x_t,
                                               y_t=y_t,
                                               x_v=x_v,
                                               y_v=None)

            # average the class probabilities over all returned logit tensors;
            # the second dimension is sized by the number of entries in the
            # episode data
            probabilities = torch.zeros(size=(y_v.shape[0],
                                              len(episode_data)),
                                        dtype=torch.float,
                                        device=device)
            for logits_ in logits:
                probabilities += torch.softmax(input=logits_, dim=1)
            probabilities /= len(logits)

            is_correct = probabilities.argmax(dim=1) == y_v
            accuracies.append(is_correct.float().mean().item())

            # ANSI "cursor to previous line": the counter overwrites itself
            sys.stdout.write('\033[F')
            print(episode_idx)

        acc_mean = np.mean(a=accuracies)
        acc_std = np.std(a=accuracies)
        # half-width of the 95% confidence interval, in percent
        half_width = 1.96 * acc_std / np.sqrt(len(accuracies)) * 100
        print('\nAccuracy = {0:.2f} +/- {1:.2f}\n'.format(
            acc_mean * 100, half_width))
        return accuracies
Exemplo n.º 2
0
def evaluate() -> None:
    """Evaluate the checkpoint of `resume_epoch` on a set of episodes.

    Each episode is split into a support and a query part. The network is
    adapted on the support part through `adapt_to_episode`, and the accuracy
    of the adapted network on the query part is recorded. Per-episode
    accuracies are written, one per line, to `accuracy.txt` in `logdir`; the
    mean and the 1.96-sigma half-width of its confidence interval are printed
    at the end.

    The settings (`resume_epoch`, `num_episodes`, `episode_file`, `logdir`,
    `k_shot`, `v_shot`, `device`, `eps_generator`) are read from module scope.

    A manual interruption or a failure inside the episode loop ends the loop
    early; the failure is reported on stderr and the statistics of the
    episodes completed so far are still printed.

    Raises:
        ValueError: if both `num_episodes` and `episode_file` are None
    """
    assert resume_epoch > 0

    if (num_episodes is None) and (episode_file is None):
        raise ValueError('Expect exactly one of num_episodes and episode_file to be not None, receive both are None.')

    acc = []

    # load model
    net, _, _ = load_model(epoch_id=resume_epoch)
    net.eval()

    episodes = get_episodes(episode_file_path=episode_file)
    # presumably get_episodes yields None entries when no episode file is
    # given, meaning "draw random episodes" - TODO confirm against get_episodes
    if None in episodes:
        episodes = [None] * num_episodes

    file_acc = os.path.join(logdir, 'accuracy.txt')
    # the context manager closes the file on every exit path
    with open(file=file_acc, mode='w') as f_acc:
        try:
            for i, episode_name in enumerate(episodes):
                X = eps_generator.generate_episode(episode_name=episode_name)

                # split into train and validation
                xt, yt, xv, yv = train_val_split(X=X, k_shot=k_shot, shuffle=True)

                # move data to gpu
                x_t = torch.from_numpy(xt).float().to(device)
                y_t = torch.tensor(yt, dtype=torch.long, device=device)
                x_v = torch.from_numpy(xv).float().to(device)
                y_v = torch.tensor(yv, dtype=torch.long, device=device)

                # adapt on the support data
                fnet = adapt_to_episode(x=x_t, y=y_t, net=net)

                # evaluate on the query data
                logits_v = fnet(x_v)
                # assumes every one of the len(X) classes contributes v_shot
                # query samples - TODO confirm
                episode_acc = (logits_v.argmax(dim=1) == y_v).sum().item() / (len(X) * v_shot)

                acc.append(episode_acc)
                f_acc.write('{}\n'.format(episode_acc))

                # ANSI "cursor to previous line": the counter overwrites itself
                sys.stdout.write('\033[F')
                print(i)
        except KeyboardInterrupt:
            # stopping by hand is a legitimate way to end the evaluation early
            print('\nEvaluation interrupted after {0:d} episode(s).'.format(len(acc)))
        except Exception as err:
            # keep the best-effort summary, but never hide the failure
            print('\nEvaluation aborted after {0:d} episode(s): {1!r}'.format(len(acc), err),
                  file=sys.stderr)

    if not acc:
        # mean/std of an empty list would only produce NaN and warnings
        print('No episode was evaluated, accuracy is unavailable.')
        return None

    mean = np.mean(a=acc)
    std = np.std(a=acc)
    n = len(acc)
    print('Accuracy = {0:.4f} +/- {1:.4f}'.format(mean, 1.96 * std / np.sqrt(n)))
Exemplo n.º 3
0
def evaluate() -> None:
    """Evaluate the checkpoint of `resume_epoch` with prototype distances.

    For each episode `adapt_to_episode` turns the support part into
    `z_prototypes`. The query part is passed through the network and every
    query sample is assigned to the prototype at the smallest Euclidean
    distance (the negated distances are used as logits). Per-episode
    accuracies are written, one per line, to `accuracy.txt` in `logdir`; the
    mean and the 1.96-sigma half-width of its confidence interval are printed
    at the end. The whole loop runs under `torch.no_grad()`.

    The settings (`resume_epoch`, `num_episodes`, `episode_file`, `logdir`,
    `k_shot`, `v_shot`, `device`, `eps_generator`) are read from module scope.

    A manual interruption or a failure inside the episode loop ends the loop
    early; the failure is reported on stderr and the statistics of the
    episodes completed so far are still printed.
    """
    assert resume_epoch > 0

    acc = []

    # load model
    net, _, _ = load_model(epoch_id=resume_epoch)
    net.eval()

    episodes = get_episodes(episode_file_path=episode_file)
    # presumably get_episodes yields None entries when no episode file is
    # given, meaning "draw random episodes" - TODO confirm against get_episodes
    if None in episodes:
        episodes = [None] * num_episodes

    file_acc = os.path.join(logdir, 'accuracy.txt')
    # the context manager closes the file on every exit path
    with open(file=file_acc, mode='w') as f_acc:
        try:
            with torch.no_grad():
                for i, episode_ in enumerate(episodes):
                    X = eps_generator.generate_episode(episode_name=episode_)

                    # split into train and validation
                    xt, yt, xv, yv = train_val_split(X=X, k_shot=k_shot, shuffle=True)

                    # move data to gpu
                    x_t = torch.from_numpy(xt).float().to(device)
                    y_t = torch.tensor(yt, dtype=torch.long, device=device)
                    x_v = torch.from_numpy(xv).float().to(device)
                    y_v = torch.tensor(yv, dtype=torch.long, device=device)

                    # adapt on the support data
                    z_prototypes = adapt_to_episode(x=x_t, y=y_t, net=net)

                    # evaluate on the query data: the closest prototype wins
                    z_v = net.forward(x_v)
                    distance_matrix = euclidean_distance(matrixN=z_v, matrixM=z_prototypes)
                    logits_v = -distance_matrix
                    # assumes every one of the len(X) classes contributes
                    # v_shot query samples - TODO confirm
                    episode_acc = (logits_v.argmax(dim=1) == y_v).sum().item() / (len(X) * v_shot)

                    acc.append(episode_acc)
                    f_acc.write('{0}\n'.format(episode_acc))

                    # ANSI "cursor to previous line": the counter overwrites itself
                    sys.stdout.write('\033[F')
                    print(i)
        except KeyboardInterrupt:
            # stopping by hand is a legitimate way to end the evaluation early
            print('\nEvaluation interrupted after {0:d} episode(s).'.format(len(acc)))
        except Exception as err:
            # keep the best-effort summary, but never hide the failure
            print('\nEvaluation aborted after {0:d} episode(s): {1!r}'.format(len(acc), err),
                  file=sys.stderr)

    if not acc:
        # mean/std of an empty list would only produce NaN and warnings
        print('No episode was evaluated, accuracy is unavailable.')
        return None

    mean = np.mean(a=acc)
    std = np.std(a=acc)
    n = len(acc)
    print('Accuracy = {0:.4f} +/- {1:.4f}'.format(mean, 1.96 * std / np.sqrt(n)))
Exemplo n.º 4
0
def train() -> None:
    """Meta-train the network and save a checkpoint after every epoch.

    For each sampled episode the network is adapted on the support data via
    `adapt_to_episode`, and the cross-entropy of the adapted network on the
    query data is back-propagated. Gradients are accumulated over `minibatch`
    episodes before the meta-optimizer takes a step. The mean query loss is
    logged to TensorBoard every `lcm(minibatch, 100)` episodes.

    Hyper-parameters come from the module-level `args`; `resume_epoch`,
    `episode_file`, `logdir`, `k_shot`, `device` and `eps_generator` are read
    from module scope as well. A KeyboardInterrupt stops training quietly.
    """
    # stays None until the writer really exists, so the cleanup in `finally`
    # cannot fail with an unbound name and mask the original error
    tb_writer = None

    try:
        # parse training parameters
        meta_lr = args.meta_lr
        minibatch = args.minibatch
        # logging period: a common multiple of the minibatch size and 100
        minibatch_print = np.lcm(minibatch, 100)
        decay_lr = args.decay_lr

        num_episodes_per_epoch = args.num_episodes_per_epoch
        num_epochs = args.num_epochs

        # initialize/load model
        net, meta_optimizer, schdlr = load_model(epoch_id=resume_epoch, meta_lr=meta_lr, decay_lr=decay_lr)

        # zero grad
        meta_optimizer.zero_grad()

        # get episode list if not None -> generator of episode names, each episode name consists of classes
        episodes = get_episodes(episode_file_path=episode_file)

        # initialize a tensorboard summary writer for logging; when resuming,
        # events logged at or after the resume step are purged
        tb_writer = SummaryWriter(
            log_dir=logdir,
            purge_step=resume_epoch * num_episodes_per_epoch // minibatch_print if resume_epoch > 0 else None
        )

        for epoch_id in range(resume_epoch, resume_epoch + num_epochs, 1):
            episode_count = 0
            loss_monitor = 0

            while (episode_count < num_episodes_per_epoch):
                # get episode from the given csv file, or just return None
                episode_name = random.sample(population=episodes, k=1)[0]

                X = eps_generator.generate_episode(episode_name=episode_name)

                # split into train and validation
                xt, yt, xv, yv = train_val_split(X=X, k_shot=k_shot, shuffle=True)

                # move data to gpu
                x_t = torch.from_numpy(xt).float().to(device)
                y_t = torch.tensor(yt, dtype=torch.long, device=device)
                x_v = torch.from_numpy(xv).float().to(device)
                y_v = torch.tensor(yv, dtype=torch.long, device=device)

                # adapt on the support data
                fnet = adapt_to_episode(x=x_t, y=y_t, net=net)

                # evaluate on the query data
                logits_v = fnet.forward(x_v)
                cls_loss = torch.nn.functional.cross_entropy(input=logits_v, target=y_v)
                # the unscaled loss is what gets monitored
                loss_monitor += cls_loss.item()

                # scale so that the accumulated gradient is the minibatch mean
                cls_loss = cls_loss / minibatch
                cls_loss.backward()

                episode_count += 1

                # update the meta-model
                if (episode_count % minibatch == 0):
                    meta_optimizer.step()
                    meta_optimizer.zero_grad()

                # monitor losses
                if (episode_count % minibatch_print == 0):
                    loss_monitor /= minibatch_print
                    global_step = (epoch_id * num_episodes_per_epoch + episode_count) // minibatch_print
                    tb_writer.add_scalar(
                        tag='Loss',
                        scalar_value=loss_monitor,
                        global_step=global_step
                    )
                    loss_monitor = 0

            # decay learning rate
            schdlr.step()

            # save model
            checkpoint = {
                'net_state_dict': net.state_dict(),
                'op_state_dict': meta_optimizer.state_dict(),
                'lr_schdlr_state_dict': schdlr.state_dict()
            }
            checkpoint_filename = 'Epoch_{0:d}.pt'.format(epoch_id + 1)
            torch.save(checkpoint, os.path.join(logdir, checkpoint_filename))
            # drop the reference to the state dictionaries right away
            del checkpoint
            print('SAVING parameters into {0:s}\n'.format(checkpoint_filename))

    except KeyboardInterrupt:
        # a manual stop is an accepted way to end training
        pass
    finally:
        if tb_writer is not None:
            print('\nClose tensorboard summary writer')
            tb_writer.close()
Exemplo n.º 5
0
def evaluate(hyper_net_cls, get_f_base_net_fn: typing.Callable,
             adapt_to_episode: typing.Callable,
             get_accuracy_fn: typing.Callable) -> None:
    """Evaluate a hyper-network/base-network pair on a set of episodes.

    Per-episode accuracies are written, one per line, to `accuracy.txt` in
    `logdir`; the mean and the 1.96-sigma half-width of its confidence
    interval are printed at the end. `config`, `logdir`, `device` and
    `eps_generator` are read from module scope.

    Args:
        hyper_net_cls: forwarded to `load_model` as `hyper_net_cls`
        get_f_base_net_fn: called as `get_f_base_net_fn(base_net=...)` to
            obtain the functional base network of an episode
        adapt_to_episode: called with the support data `x`, `y`, the
            `hyper_net` and the `f_base_net`; returns the adapted
            `f_hyper_net`
        get_accuracy_fn: called with the query data `x`, `y`, the
            `f_hyper_net` and the `f_base_net`; returns the episode accuracy
    """
    acc = []

    # initialize/load model
    hyper_net, base_net, _, _ = load_model(hyper_net_cls=hyper_net_cls,
                                           epoch_id=config['resume_epoch'],
                                           meta_lr=config['meta_lr'],
                                           decay_lr=config['decay_lr'])

    hyper_net.eval()
    base_net.eval()

    # get list of episode names, each episode name consists of classes
    episodes = get_episodes(episode_file_path=config['episode_file'],
                            num_episodes=config['num_episodes'])

    # the context manager closes the file on every exit path and, unlike a
    # try/finally around open(), does not trip over an unbound name when the
    # file cannot be opened
    with open(file=os.path.join(logdir, 'accuracy.txt'), mode='w') as acc_file:
        for i, episode_name in enumerate(episodes):
            X = eps_generator.generate_episode(episode_name=episode_name)

            # split into train and validation
            xt, yt, xv, yv = train_val_split(X=X,
                                             k_shot=config['k_shot'],
                                             shuffle=True)

            # move data to gpu
            x_t = torch.from_numpy(xt).float().to(device)
            y_t = torch.tensor(yt, dtype=torch.long, device=device)
            x_v = torch.from_numpy(xv).float().to(device)
            y_v = torch.tensor(yv, dtype=torch.long, device=device)

            # -------------------------
            # functional base network
            # -------------------------
            f_base_net = get_f_base_net_fn(base_net=base_net)

            # -------------------------
            # adapt on the support data
            # -------------------------
            f_hyper_net = adapt_to_episode(x=x_t,
                                           y=y_t,
                                           hyper_net=hyper_net,
                                           f_base_net=f_base_net)

            # -------------------------
            # accuracy
            # -------------------------
            acc_temp = get_accuracy_fn(x=x_v,
                                       y=y_v,
                                       f_hyper_net=f_hyper_net,
                                       f_base_net=f_base_net)

            acc.append(acc_temp)
            acc_file.write('{}\n'.format(acc_temp))

            # ANSI "cursor to previous line": the counter overwrites itself
            sys.stdout.write('\033[F')
            print(i)

    acc_mean = np.mean(acc)
    acc_std = np.std(acc)
    # the sample size is the number of accuracies actually collected
    print('Accuracy = {} +/- {}'.format(
        acc_mean, 1.96 * acc_std / np.sqrt(len(acc))))

    return None
Exemplo n.º 6
0
def train(hyper_net_cls, get_f_base_net_fn: typing.Callable,
          adapt_to_episode: typing.Callable,
          loss_on_query_fn: typing.Callable) -> None:
    """Base method used for training.

    For each sampled episode the hyper-network is adapted on the support
    data, and the loss returned by `loss_on_query_fn` on the query data is
    back-propagated. Gradients are accumulated over `config['minibatch']`
    episodes before the meta-optimizer takes a step. The mean per-episode
    loss is logged to TensorBoard every `minibatch_print` episodes, and a
    checkpoint is written after every epoch. `config`, `minibatch_print`,
    `logdir`, `device` and `eps_generator` are read from module scope.

    Args:
        hyper_net_cls: forwarded to `load_model` as `hyper_net_cls`
        get_f_base_net_fn: called as `get_f_base_net_fn(base_net=...)` to
            obtain the functional base network of an episode
        adapt_to_episode: called with the support data `x`, `y`, the
            `hyper_net` and the `f_base_net`; returns the adapted
            `f_hyper_net`
        loss_on_query_fn: called with the query data `x`, `y`, the
            `f_hyper_net`, the `f_base_net` and the `hyper_net`; returns the
            loss to back-propagate

    Raises:
        ValueError: if the loss on the query data is NaN
    """
    # initialize/load model
    hyper_net, base_net, meta_opt, schdlr = load_model(
        hyper_net_cls=hyper_net_cls,
        epoch_id=config['resume_epoch'],
        meta_lr=config['meta_lr'],
        decay_lr=config['decay_lr'])

    # zero grad
    meta_opt.zero_grad()

    # get list of episode names, each episode name consists of classes
    episodes = get_episodes(episode_file_path=config['episode_file'])

    # initialize a tensorboard summary writer for logging; when resuming,
    # events logged at or after the resume step are purged
    tb_writer = SummaryWriter(
        log_dir=logdir,
        purge_step=config['resume_epoch'] * config['num_episodes_per_epoch'] //
        minibatch_print if config['resume_epoch'] > 0 else None)

    try:
        for epoch_id in range(config['resume_epoch'],
                              config['resume_epoch'] + config['num_epochs'],
                              1):
            episode_count = 0
            loss_monitor = 0

            while (episode_count < config['num_episodes_per_epoch']):
                # get episode from the given csv file, or just return None
                episode_name = random.sample(population=episodes, k=1)[0]

                X = eps_generator.generate_episode(episode_name=episode_name)

                # split into train and validation
                xt, yt, xv, yv = train_val_split(X=X,
                                                 k_shot=config['k_shot'],
                                                 shuffle=True)

                # move data to gpu
                x_t = torch.from_numpy(xt).float().to(device)
                y_t = torch.tensor(yt, dtype=torch.long, device=device)
                x_v = torch.from_numpy(xv).float().to(device)
                y_v = torch.tensor(yv, dtype=torch.long, device=device)

                # -------------------------
                # functional base network
                # -------------------------
                f_base_net = get_f_base_net_fn(base_net=base_net)

                # -------------------------
                # adapt on the support data
                # -------------------------
                f_hyper_net = adapt_to_episode(x=x_t,
                                               y=y_t,
                                               hyper_net=hyper_net,
                                               f_base_net=f_base_net)

                # -------------------------
                # loss on query data
                # -------------------------
                loss_meta = loss_on_query_fn(x=x_v,
                                             y=y_v,
                                             f_hyper_net=f_hyper_net,
                                             f_base_net=f_base_net,
                                             hyper_net=hyper_net)

                if torch.isnan(loss_meta):
                    raise ValueError('Validation loss is NaN.')

                # monitor the loss BEFORE it is scaled for gradient
                # accumulation; otherwise the logged value shrinks with the
                # minibatch size instead of being the mean per-episode loss
                loss_monitor += loss_meta.item()

                # scale so that the accumulated gradient is the minibatch mean
                loss_meta = loss_meta / config['minibatch']
                loss_meta.backward()

                episode_count += 1
                # update the meta-model
                if (episode_count % config['minibatch'] == 0):
                    meta_opt.step()
                    meta_opt.zero_grad()

                # monitor losses
                if (episode_count % minibatch_print == 0):
                    loss_monitor /= minibatch_print

                    global_step = (epoch_id * config['num_episodes_per_epoch']
                                   + episode_count) // minibatch_print
                    tb_writer.add_scalar(tag='Loss',
                                         scalar_value=loss_monitor,
                                         global_step=global_step)

                    loss_monitor = 0

            # decay learning rate
            schdlr.step()

            # save model
            checkpoint = {
                'hyper_net_state_dict': hyper_net.state_dict(),
                'op_state_dict': meta_opt.state_dict(),
                'lr_schdlr_state_dict': schdlr.state_dict()
            }
            checkpoint_filename = 'Epoch_{0:d}.pt'.format(epoch_id + 1)
            torch.save(checkpoint, os.path.join(logdir, checkpoint_filename))
            # drop the reference to the state dictionaries right away
            del checkpoint
            print('SAVING parameters into {0:s}\n'.format(checkpoint_filename))
    finally:
        print('\nClose tensorboard summary writer')
        tb_writer.close()

    return None
    def train(
        self, eps_generator: typing.Union[OmniglotLoader,
                                          ImageFolderGenerator]) -> None:
        """Run meta-training and write a checkpoint after every epoch.

        For each sampled episode the model is adapted and evaluated through
        `adapt_and_predict`. The query loss is the cross-entropy averaged
        over all returned logit tensors, plus `loss_extra` and the weighted
        KL divergence. Gradients are accumulated over `minibatch` episodes;
        before each optimizer step `loss_prior` is back-propagated as well
        when it exposes `requires_grad`.

        Args:
            eps_generator: object whose `generate_episode` returns the data of
                an episode given its name
        """
        cfg = self.config
        print('Training is started.\nLog is stored at {0:s}.\n'.format(
            cfg['logdir']))

        # the last element of the loaded model is used as the optimizer
        model = self.load_model(resume_epoch=cfg['resume_epoch'],
                                hyper_net_class=self.hyper_net_class,
                                eps_generator=eps_generator)
        meta_opt = model[-1]
        meta_opt.zero_grad()

        # names of the episodes to sample from
        eps = get_episodes(episode_file_path=cfg['episode_file'])

        # when resuming, events logged at or after the resume step are purged
        if cfg['resume_epoch'] > 0:
            purge_step = (cfg['resume_epoch'] *
                          cfg['num_episodes_per_epoch'] //
                          cfg['minibatch_print'])
        else:
            purge_step = None
        tb_writer = SummaryWriter(log_dir=cfg['logdir'],
                                  purge_step=purge_step)

        try:
            first_epoch = cfg['resume_epoch']
            last_epoch = first_epoch + cfg['num_epochs']
            for epoch_id in range(first_epoch, last_epoch):
                loss_monitor = 0.
                KL_monitor = 0.
                for eps_count in range(cfg['num_episodes_per_epoch']):
                    # draw one episode name at random
                    eps_name = random.sample(population=eps, k=1)[0]
                    eps_data = eps_generator.generate_episode(
                        episode_name=eps_name)

                    # support/query split of the episode
                    xt, yt, xv, yv = train_val_split(X=eps_data,
                                                     k_shot=cfg['k_shot'],
                                                     shuffle=True)

                    # place everything on the configured device
                    device = cfg['device']
                    x_t = torch.from_numpy(xt).float().to(device)
                    y_t = torch.tensor(yt, dtype=torch.long, device=device)
                    x_v = torch.from_numpy(xv).float().to(device)
                    y_v = torch.tensor(yv, dtype=torch.long, device=device)

                    f_hyper_net, logits = self.adapt_and_predict(model=model,
                                                                 x_t=x_t,
                                                                 y_t=y_t,
                                                                 x_v=x_v,
                                                                 y_v=y_v)

                    # cross-entropy averaged over the returned logit tensors
                    loss_v = 0.
                    for logits_ in logits:
                        loss_v = loss_v + torch.nn.functional.cross_entropy(
                            input=logits_, target=y_v)
                    loss_v = loss_v / len(logits)
                    loss_monitor += loss_v.item()

                    # the KL term may be a tensor or a plain number
                    KL_div = self.KL_divergence(model=model,
                                                f_hyper_net=f_hyper_net)
                    if isinstance(KL_div, torch.Tensor):
                        KL_monitor += KL_div.item()
                    else:
                        KL_monitor += KL_div

                    # extra loss applicable for ABML only
                    loss_extra = self.loss_extra(model=model,
                                                 f_hyper_net=f_hyper_net,
                                                 x_t=x_t,
                                                 y_t=y_t)

                    # total loss, scaled for gradient accumulation
                    loss_v = loss_v + loss_extra + cfg['KL_weight'] * KL_div
                    loss_v = loss_v / cfg['minibatch']
                    loss_v.backward()

                    episodes_done = eps_count + 1
                    if episodes_done % cfg['minibatch'] != 0:
                        continue

                    # a full minibatch has been accumulated: update
                    loss_prior = self.loss_prior(model=model)
                    if hasattr(loss_prior, 'requires_grad'):
                        loss_prior.backward()
                    meta_opt.step()
                    meta_opt.zero_grad()

                    if episodes_done % cfg['minibatch_print'] != 0:
                        continue

                    # logging happens only on update steps that also hit the
                    # printing period
                    loss_monitor /= cfg['minibatch_print']
                    KL_monitor = (KL_monitor * cfg['minibatch'] /
                                  cfg['minibatch_print'])
                    global_step = (
                        epoch_id * cfg['num_episodes_per_epoch'] +
                        episodes_done) // cfg['minibatch_print']

                    tb_writer.add_scalar(tag='Cls loss',
                                         scalar_value=loss_monitor,
                                         global_step=global_step)
                    tb_writer.add_scalar(tag='KL divergence',
                                         scalar_value=KL_monitor,
                                         global_step=global_step)

                    loss_monitor = 0.
                    KL_monitor = 0.

                # checkpoint at the end of the epoch
                checkpoint_path = os.path.join(
                    cfg['logdir'], 'Epoch_{0:d}.pt'.format(epoch_id + 1))
                checkpoint = {
                    'hyper_net_state_dict': model[0].state_dict(),
                    'opt_state_dict': meta_opt.state_dict()
                }
                torch.save(obj=checkpoint, f=checkpoint_path)
                print('State dictionaries are saved into {0:s}\n'.format(
                    checkpoint_path))

            print('Training is completed.')
        finally:
            print('\nClose tensorboard summary writer')
            tb_writer.close()

        return None
Exemplo n.º 8
0
    def evaluate(
        self, eps_generator: typing.Union[OmniglotLoader,
                                          ImageFolderGenerator]) -> None:
        """Evaluate accuracy, ECE, AUROC and NLL, appending each to a file.

        The prediction for an episode is the average of the softmax outputs
        of all logit tensors returned by `adapt_and_predict`. Accuracy is
        pooled over all query samples; ECE and AUROC are averaged over
        episodes. The NLL is computed as the negative log of the mean
        probability assigned to the true class over all query samples.

        Every metric is appended to `<extra>_<metric>_<run>.txt` in the log
        directory, where `<extra>` is "corrupt", "oodtest" or "plain"
        depending on the configuration.

        Args:
            eps_generator: object whose `generate_episode` returns the tuple
                `(x_t, y_t, x_v, y_v)` of an episode given its name
        """
        print('Evaluation is started.\n')

        cfg = self.config

        # restore the model of the requested epoch
        model = self.load_model(resume_epoch=cfg['resume_epoch'],
                                hyper_net_class=self.hyper_net_class,
                                eps_generator=eps_generator)

        # names of the episodes to evaluate on
        eps = get_episodes(episode_file_path=cfg['episode_file'],
                           num_episodes=1000)

        ece, roc_auc, ll = 0., 0., 0.
        correct, total = 0., 0.
        for eps_idx, eps_name in enumerate(eps):
            episode = eps_generator.generate_episode(episode_name=eps_name)
            x_t, y_t, x_v, y_v = (part.cuda() for part in episode)

            # the query labels are withheld from the model (y_v=None)
            _, logits = self.adapt_and_predict(model=model,
                                               x_t=x_t,
                                               y_t=y_t,
                                               x_v=x_v,
                                               y_v=None)

            # average the class probabilities over all returned logit tensors
            y_pred = torch.zeros((y_v.shape[0], cfg["n_way"]),
                                 device=x_v.device)
            for logits_ in logits:
                y_pred += torch.softmax(input=logits_, dim=1)
            y_pred /= len(logits)

            # pooled accuracy counters
            correct += (y_pred.argmax(dim=1) == y_v).sum()
            total += y_v.numel()

            # calibration error of this episode; the first argument is
            # presumably the number of bins - TODO confirm in ece_yhat_only
            episode_ece, _, _ = ece_yhat_only(100,
                                              y_v,
                                              y_pred,
                                              device=y_v.device)
            ece += episode_ece

            # one-vs-rest AUROC against the one-hot encoded labels
            labels_column = y_v.cpu().numpy().reshape(-1, 1)
            yhot = OneHotEncoder(
                categories='auto').fit_transform(labels_column).toarray()
            roc_auc += roc_auc_score(yhot,
                                     y_pred.detach().cpu().numpy(),
                                     multi_class="ovr")

            # probability assigned to the true class of every query sample
            true_class_prob = y_pred[torch.arange(y_v.size(0)), y_v]
            ll += true_class_prob.detach().cpu().sum()

            # ANSI "cursor to previous line": the counter overwrites itself
            sys.stdout.write('\033[F')
            print(eps_idx + 1)

        acc = correct / total
        num_eps = eps_idx + 1
        ece = ece / num_eps
        roc_auc = roc_auc / num_eps
        nll = -np.log(ll) + np.log(total)

        # tag of the evaluation setting, shared by all result files
        if cfg["corrupt"]:
            extra = "corrupt"
        elif cfg["ood_test"]:
            extra = "oodtest"
        else:
            extra = "plain"

        results = [
            ("{}_acc_{}.txt", acc),
            ("{}_ece_{}.txt", ece),
            ("{}_auroc_{}.txt", roc_auc),
            ("{}_nll_{}.txt", nll),
        ]
        for file_template, metric in results:
            path = os.path.join(cfg['logdir'],
                                file_template.format(extra, cfg['run']))
            with open(path, "a+") as f:
                f.write("{}\n".format(metric))

        print(
            '\nAccuracy: {0:.2f}\nECE: {1:.2f}\nAUROC: {2:.2f}\nNLL: {3:.2f}'.
            format(acc * 100, ece, roc_auc, nll))
Exemplo n.º 9
0
    def train(
        self, eps_generator: typing.Union[OmniglotLoader,
                                          ImageFolderGenerator]) -> None:
        """Run meta-training, overwriting one per-run checkpoint each epoch.

        For each sampled episode the model is adapted and evaluated through
        `adapt_and_predict`. The query loss is the cross-entropy averaged
        over all returned logit tensors, plus `loss_extra` and the weighted
        KL divergence. Gradients are accumulated over `minibatch` episodes;
        before each optimizer step `loss_prior` is back-propagated as well
        when it exposes `requires_grad`. The query accuracy of the epoch is
        printed, and `config['iters']` is incremented once per episode.

        Args:
            eps_generator: object whose `generate_episode` returns the tuple
                `(x_t, y_t, x_v, y_v)` of an episode given its name
        """
        cfg = self.config
        print('Training is started.\nLog is stored at {0:s}.\n'.format(
            cfg['logdir']))

        # the last element of the loaded model is used as the optimizer
        model = self.load_model(resume_epoch=cfg['resume_epoch'],
                                hyper_net_class=self.hyper_net_class,
                                eps_generator=eps_generator)
        meta_opt = model[-1]
        meta_opt.zero_grad()

        # names of the episodes to sample from
        eps = get_episodes(episode_file_path=cfg['episode_file'])

        # when resuming, events logged at or after the resume step are purged
        if cfg['resume_epoch'] > 0:
            purge_step = (cfg['resume_epoch'] *
                          cfg['num_episodes_per_epoch'] //
                          cfg['minibatch_print'])
        else:
            purge_step = None
        tb_writer = SummaryWriter(log_dir=cfg['logdir'],
                                  purge_step=purge_step)

        try:
            first_epoch = cfg['resume_epoch']
            last_epoch = first_epoch + cfg['num_epochs']
            for epoch_id in range(first_epoch, last_epoch):
                print(f"Starting epoch: {epoch_id}")
                loss_monitor = 0.
                KL_monitor = 0.
                correct, total = 0., 0.
                for eps_count in range(cfg['num_episodes_per_epoch']):
                    # draw one episode name at random
                    eps_name = random.sample(population=eps, k=1)[0]

                    episode = eps_generator.generate_episode(
                        episode_name=eps_name)
                    x_t, y_t, x_v, y_v = (part.cuda() for part in episode)

                    f_hyper_net, logits = self.adapt_and_predict(model=model,
                                                                 x_t=x_t,
                                                                 y_t=y_t,
                                                                 x_v=x_v,
                                                                 y_v=y_v)

                    # cross-entropy averaged over the returned logit tensors
                    loss_v = 0.
                    for logits_ in logits:
                        loss_v = loss_v + torch.nn.functional.cross_entropy(
                            input=logits_, target=y_v)
                    loss_v = loss_v / len(logits)
                    loss_monitor += loss_v.item()

                    # running accuracy from the mean of the logit tensors
                    mean_logits = torch.stack(logits).mean(dim=0)
                    correct += (mean_logits.argmax(dim=-1) == y_v).sum().item()
                    total += y_v.numel()

                    # the KL term may be a tensor or a plain number
                    KL_div = self.KL_divergence(model=model,
                                                f_hyper_net=f_hyper_net)
                    if isinstance(KL_div, torch.Tensor):
                        KL_monitor += KL_div.item()
                    else:
                        KL_monitor += KL_div

                    # extra loss applicable for ABML only
                    loss_extra = self.loss_extra(model=model,
                                                 f_hyper_net=f_hyper_net,
                                                 x_t=x_t,
                                                 y_t=y_t)

                    # total loss, scaled for gradient accumulation
                    loss_v = loss_v + loss_extra + cfg['KL_weight'] * KL_div
                    loss_v = loss_v / cfg['minibatch']
                    loss_v.backward()

                    cfg['iters'] += 1

                    episodes_done = eps_count + 1
                    if episodes_done % cfg['minibatch'] != 0:
                        continue

                    # a full minibatch has been accumulated: update
                    loss_prior = self.loss_prior(model=model)
                    if hasattr(loss_prior, 'requires_grad'):
                        loss_prior.backward()
                    meta_opt.step()
                    meta_opt.zero_grad()

                    if episodes_done % cfg['minibatch_print'] != 0:
                        continue

                    # logging happens only on update steps that also hit the
                    # printing period
                    loss_monitor /= cfg['minibatch_print']
                    KL_monitor = (KL_monitor * cfg['minibatch'] /
                                  cfg['minibatch_print'])
                    global_step = (
                        epoch_id * cfg['num_episodes_per_epoch'] +
                        episodes_done) // cfg['minibatch_print']

                    tb_writer.add_scalar(tag='Cls loss',
                                         scalar_value=loss_monitor,
                                         global_step=global_step)
                    tb_writer.add_scalar(tag='KL divergence',
                                         scalar_value=KL_monitor,
                                         global_step=global_step)

                    loss_monitor = 0.
                    KL_monitor = 0.

                print(f"epoch {epoch_id} acc: {correct / total}")

                # checkpoint at the end of the epoch; the file name depends
                # on the run only, so each epoch overwrites the previous one
                checkpoint_path = os.path.join(
                    cfg['logdir'], 'run_{}.pt'.format(cfg['run']))
                checkpoint = {
                    'hyper_net_state_dict': model[0].state_dict(),
                    'opt_state_dict': meta_opt.state_dict(),
                    'epoch': epoch_id,
                    'iters': cfg['iters']
                }
                torch.save(obj=checkpoint, f=checkpoint_path)
                print('State dictionaries are saved into {0:s}\n'.format(
                    checkpoint_path))

            print('Training is completed.')
        finally:
            print('\nClose tensorboard summary writer')
            tb_writer.close()

        return None