def evaluate(
        self, eps_generator: typing.Union[OmniglotLoader, ImageFolderGenerator]
    ) -> typing.List[float]:
        """Evaluate the performance
        """
        print('Evaluation started.\n')
        # load model
        model = self.load_model(resume_epoch=self.config['resume_epoch'],
                                hyper_net_class=self.hyper_net_class,
                                eps_generator=eps_generator)

        # get list of episode names, each episode name consists of classes
        eps = get_episodes(episode_file_path=self.config['episode_file'])

        accuracies = [None] * len(eps)

        for i, eps_name in enumerate(eps):
            eps_data = eps_generator.generate_episode(episode_name=eps_name)
            # split data into train and validation
            xt, yt, xv, yv = train_val_split(X=eps_data,
                                             k_shot=self.config['k_shot'],
                                             shuffle=True)

            # move data to GPU (if there is a GPU)
            x_t = torch.from_numpy(xt).float().to(self.config['device'])
            y_t = torch.tensor(yt,
                               dtype=torch.long,
                               device=self.config['device'])
            x_v = torch.from_numpy(xv).float().to(self.config['device'])
            y_v = torch.tensor(yv,
                               dtype=torch.long,
                               device=self.config['device'])

            # adapt on the support set and predict logits for the query set;
            # query labels are not needed for prediction, hence y_v=None
            _, logits = self.adapt_and_predict(model=model,
                                               x_t=x_t,
                                               y_t=y_t,
                                               x_v=x_v,
                                               y_v=None)

            # average the softmax probabilities over all returned logit sets
            # (one per model sample) to form the ensemble prediction
            y_pred = torch.zeros(size=(y_v.shape[0], len(eps_data)),
                                 dtype=torch.float,
                                 device=self.config['device'])
            for logits_ in logits:
                y_pred += torch.softmax(input=logits_, dim=1)
            y_pred /= len(logits)

            accuracies[i] = (y_pred.argmax(dim=1) == y_v).float().mean().item()

            # move the cursor up one line and overwrite the progress counter
            sys.stdout.write('\033[F')
            print(i + 1)

        acc_mean = np.mean(a=accuracies)
        acc_std = np.std(a=accuracies)
        print('\nAccuracy = {0:.2f} +/- {1:.2f}\n'.format(
            acc_mean * 100, 1.96 * acc_std / np.sqrt(len(accuracies)) * 100))
        return accuracies
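The loop above averages softmax probabilities rather than raw logits across the
model samples returned by adapt_and_predict (e.g. Monte Carlo draws in
VAMPIRE). A minimal standalone illustration of that ensemble step, with
hypothetical shapes:

import torch

num_samples, n_query, n_way = 4, 10, 5
logits = [torch.randn(n_query, n_way) for _ in range(num_samples)]

# average predictive probabilities across model samples, not raw logits
y_pred = torch.stack([torch.softmax(l, dim=1) for l in logits]).mean(dim=0)
labels = y_pred.argmax(dim=1)  # shape: (n_query,)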
Example #2
def evaluate() -> None:
    assert resume_epoch > 0

    acc = []

    if (num_episodes is None) and (episode_file is None):
        raise ValueError('Expected exactly one of num_episodes and episode_file to be not None, but received None for both.')

    # load model
    net, _, _ = load_model(epoch_id=resume_epoch)
    net.eval()

    episodes = get_episodes(episode_file_path=episode_file)
    if None in episodes:
        episodes = [None] * num_episodes

    file_acc = os.path.join(logdir, 'accuracy.txt')
    f_acc = open(file=file_acc, mode='w')
    try:
        for i, episode_name in enumerate(episodes):
            X = eps_generator.generate_episode(episode_name=episode_name)
                    
            # split into train and validation
            xt, yt, xv, yv = train_val_split(X=X, k_shot=k_shot, shuffle=True)

            # move data to gpu
            x_t = torch.from_numpy(xt).float().to(device)
            y_t = torch.tensor(yt, dtype=torch.long, device=device)
            x_v = torch.from_numpy(xv).float().to(device)
            y_v = torch.tensor(yv, dtype=torch.long, device=device)

            # adapt on the support data
            fnet = adapt_to_episode(x=x_t, y=y_t, net=net)

            # evaluate on the query data
            logits_v = fnet(x_v)
            episode_acc = (logits_v.argmax(dim=1) == y_v).sum().item() / (len(X) * v_shot)

            acc.append(episode_acc)
            f_acc.write('{}\n'.format(episode_acc))

            # move the cursor up one line and overwrite the progress counter
            sys.stdout.write('\033[F')
            print(i + 1)
    finally:
        f_acc.close()
    
    mean = np.mean(a=acc)
    std = np.std(a=acc)
    n = len(acc)
    print('Accuracy = {0:.4f} +/- {1:.4f}'.format(mean, 1.96 * std / np.sqrt(n)))
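The 1.96 * std / sqrt(n) term printed above is the half-width of a 95%
normal-approximation confidence interval for the mean accuracy. As a small
self-contained helper (a sketch; not part of the repository):

import numpy as np

def confidence_interval_95(values):
    """Return (mean, half_width) of a 95% normal-approximation CI."""
    a = np.asarray(values, dtype=float)
    return a.mean(), 1.96 * a.std() / np.sqrt(len(a))

# confidence_interval_95([0.61, 0.58, 0.64]) -> (0.61, ~0.028)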
Example #3
def evaluate() -> None:
    assert resume_epoch > 0

    acc = []

    # load model
    net, _, _ = load_model(epoch_id=resume_epoch)
    net.eval()

    episodes = get_episodes(episode_file_path=episode_file)
    if None in episodes:
        episodes = [None] * num_episodes

    file_acc = os.path.join(logdir, 'accuracy.txt')
    f_acc = open(file=file_acc, mode='w')
    try:
        with torch.no_grad():
            for i, episode_ in enumerate(episodes):
                X = eps_generator.generate_episode(episode_name=episode_)
                        
                # split into train and validation
                xt, yt, xv, yv = train_val_split(X=X, k_shot=k_shot, shuffle=True)

                # move data to gpu
                x_t = torch.from_numpy(xt).float().to(device)
                y_t = torch.tensor(yt, dtype=torch.long, device=device)
                x_v = torch.from_numpy(xv).float().to(device)
                y_v = torch.tensor(yv, dtype=torch.long, device=device)

                # adapt on the support data
                z_prototypes = adapt_to_episode(x=x_t, y=y_t, net=net)

                # evaluate on the query data: classify each query embedding by
                # its negative Euclidean distance to the class prototypes
                z_v = net(x_v)
                distance_matrix = euclidean_distance(matrixN=z_v, matrixM=z_prototypes)
                logits_v = -distance_matrix
                episode_acc = (logits_v.argmax(dim=1) == y_v).sum().item() / (len(X) * v_shot)

                acc.append(episode_acc)
                f_acc.write('{0}\n'.format(episode_acc))

                # move the cursor up one line and overwrite the progress counter
                sys.stdout.write('\033[F')
                print(i + 1)
    finally:
        f_acc.close()
    
    mean = np.mean(a=acc)
    std = np.std(a=acc)
    n = len(acc)
    print('Accuracy = {0:.4f} +/- {1:.4f}'.format(mean, 1.96 * std / np.sqrt(n)))
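The helpers adapt_to_episode and euclidean_distance are not shown on this
page. A hypothetical minimal version of each, assuming class labels 0..n_way-1
and row-wise feature matrices (a sketch, not the repository's actual code):

import torch

def adapt_to_episode(x, y, net):
    """Compute one prototype (mean embedding) per class from the support set."""
    z = net(x)                                    # (n_support, dim)
    n_way = int(y.max().item()) + 1
    return torch.stack([z[y == c].mean(dim=0) for c in range(n_way)])

def euclidean_distance(matrixN, matrixM):
    """Pairwise squared Euclidean distances between rows of N and rows of M."""
    return torch.cdist(matrixN, matrixM, p=2) ** 2    # (N, M)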
Example #4
def train() -> None:
    """Train the meta-learning model.

    Hyper-parameters are read from the global `args`; a checkpoint is written
    to `logdir` after every epoch.
    """

    try:
        # parse training parameters
        meta_lr = args.meta_lr
        minibatch = args.minibatch
        minibatch_print = np.lcm(minibatch, 100)  # log every this many episodes
        decay_lr = args.decay_lr

        num_episodes_per_epoch = args.num_episodes_per_epoch
        num_epochs = args.num_epochs

        # initialize/load model
        net, meta_optimizer, schdlr = load_model(epoch_id=resume_epoch, meta_lr=meta_lr, decay_lr=decay_lr)
        
        # zero grad
        meta_optimizer.zero_grad()

        # get episode list if not None -> generator of episode names, each episode name consists of classes
        episodes = get_episodes(episode_file_path=episode_file)

        # initialize a tensorboard summary writer for logging
        tb_writer = SummaryWriter(
            log_dir=logdir,
            purge_step=resume_epoch * num_episodes_per_epoch // minibatch_print if resume_epoch > 0 else None
        )

        for epoch_id in range(resume_epoch, resume_epoch + num_epochs, 1):
            episode_count = 0
            loss_monitor = 0
            
            while (episode_count < num_episodes_per_epoch):
                # get episode from the given csv file, or just return None
                episode_name = random.sample(population=episodes, k=1)[0]

                X = eps_generator.generate_episode(episode_name=episode_name)
                
                # split into train and validation
                xt, yt, xv, yv = train_val_split(X=X, k_shot=k_shot, shuffle=True)

                # move data to gpu
                x_t = torch.from_numpy(xt).float().to(device)
                y_t = torch.tensor(yt, dtype=torch.long, device=device)
                x_v = torch.from_numpy(xv).float().to(device)
                y_v = torch.tensor(yv, dtype=torch.long, device=device)

                # adapt on the support data
                fnet = adapt_to_episode(x=x_t, y=y_t, net=net)

                # evaluate on the query data
                logits_v = fnet.forward(x_v)
                cls_loss = torch.nn.functional.cross_entropy(input=logits_v, target=y_v)
                loss_monitor += cls_loss.item()

                cls_loss = cls_loss / minibatch
                cls_loss.backward()

                episode_count += 1

                # update the meta-model
                if (episode_count % minibatch == 0):
                    meta_optimizer.step()
                    meta_optimizer.zero_grad()

                # monitor losses
                if (episode_count % minibatch_print == 0):
                    loss_monitor /= minibatch_print
                    global_step = (epoch_id * num_episodes_per_epoch + episode_count) // minibatch_print
                    tb_writer.add_scalar(
                        tag='Loss',
                        scalar_value=loss_monitor,
                        global_step=global_step
                    )
                    loss_monitor = 0

            # decay learning rate
            schdlr.step()

            # save model
            checkpoint = {
                'net_state_dict': net.state_dict(),
                'op_state_dict': meta_optimizer.state_dict(),
                'lr_schdlr_state_dict': schdlr.state_dict()
            }
            checkpoint_filename = 'Epoch_{0:d}.pt'.format(epoch_id + 1)
            torch.save(checkpoint, os.path.join(logdir, checkpoint_filename))
            del checkpoint  # release the reference before the next epoch
            print('SAVING parameters into {0:s}\n'.format(checkpoint_filename))

    except KeyboardInterrupt:
        # allow a manual interrupt; clean-up happens in the finally block
        pass
    finally:
        print('\nClose tensorboard summary writer')
        tb_writer.close()
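The loop above performs gradient accumulation: each per-episode loss is
divided by minibatch before backward(), so the gradients summed over
minibatch episodes equal the gradient of their average, and the meta-optimizer
steps once per minibatch episodes. A standalone sketch of the pattern, with a
hypothetical model and random data:

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
minibatch = 4

opt.zero_grad()
for step, x in enumerate(torch.randn(16, 4)):
    loss = model(x).sum()          # stand-in for the per-episode meta-loss
    (loss / minibatch).backward()  # scale so accumulated gradients average out
    if (step + 1) % minibatch == 0:
        opt.step()                 # one meta-update per `minibatch` episodes
        opt.zero_grad()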
Example #5
    def load_model(
        self,
        resume_epoch: int = None,
        **kwargs
    ) -> typing.Tuple[torch.nn.Module,
                      typing.Optional[higher.patch._MonkeyPatchBase],
                      torch.optim.Optimizer]:
        """Initialize or load the protonet and its optimizer

        Args:
            resume_epoch: the index of the file containing the saved model

        Returns: a tuple consisting of
            protonet: the prototypical network
            base_net: a placeholder (always None) to match the MAML and VAMPIRE interface
            opt: the optimizer for the prototypical network
        """
        if resume_epoch is None:
            resume_epoch = self.config['resume_epoch']

        if self.config['network_architecture'] == 'CNN':
            protonet = CNN(dim_output=None, bn_affine=self.config['batchnorm'])
        elif self.config['network_architecture'] == 'ResNet18':
            protonet = ResNet18(dim_output=None,
                                bn_affine=self.config['batchnorm'])
        else:
            raise NotImplementedError(
                'Network architecture is unknown. Please implement it in the CommonModels.py.'
            )

        # ---------------------------------------------------------------
        # run a dummy task to initialize lazy modules defined in base_net
        # ---------------------------------------------------------------
        eps_data = kwargs['eps_generator'].generate_episode(episode_name=None)
        # split data into train and validation
        xt, _, _, _ = train_val_split(X=eps_data,
                                      k_shot=self.config['k_shot'],
                                      shuffle=True)
        # convert numpy data into torch tensor
        x_t = torch.from_numpy(xt).float()
        # run to initialize lazy modules
        protonet(x_t)

        # move to device
        protonet.to(self.config['device'])

        # optimizer
        opt = torch.optim.Adam(params=protonet.parameters(),
                               lr=self.config['meta_lr'])

        # load model if there is saved file
        if resume_epoch > 0:
            # path to the saved file
            checkpoint_path = os.path.join(
                self.config['logdir'], 'Epoch_{0:d}.pt'.format(resume_epoch))

            # load file
            saved_checkpoint = torch.load(
                f=checkpoint_path,
                map_location=lambda storage, loc: storage.cuda(
                    self.config['device'].index)
                if self.config['device'].type == 'cuda' else storage)

            # load state dictionaries
            protonet.load_state_dict(
                state_dict=saved_checkpoint['hyper_net_state_dict'])
            opt.load_state_dict(state_dict=saved_checkpoint['opt_state_dict'])

            # update learning rate
            for param_group in opt.param_groups:
                if param_group['lr'] != self.config['meta_lr']:
                    param_group['lr'] = self.config['meta_lr']

        return protonet, None, opt
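The dummy forward pass exists because networks built from lazy modules
(e.g. torch.nn.LazyLinear) only materialize their weights on the first call.
A minimal sketch of the mechanism, with hypothetical layer sizes:

import torch

net = torch.nn.Sequential(torch.nn.LazyLinear(out_features=64),
                          torch.nn.ReLU(),
                          torch.nn.LazyLinear(out_features=5))
x_dummy = torch.randn(2, 784)  # any batch with the real input shape
net(x_dummy)                   # first call materializes the lazy weights
print(sum(p.numel() for p in net.parameters()))  # now well-defined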
Example #6
def evaluate(hyper_net_cls, get_f_base_net_fn: typing.Callable,
             adapt_to_episode: typing.Callable,
             get_accuracy_fn: typing.Callable) -> None:
    """Evaluation
    """
    acc = []

    # initialize/load model
    hyper_net, base_net, _, _ = load_model(hyper_net_cls=hyper_net_cls,
                                           epoch_id=config['resume_epoch'],
                                           meta_lr=config['meta_lr'],
                                           decay_lr=config['decay_lr'])

    hyper_net.eval()
    base_net.eval()

    # get list of episode names, each episode name consists of classes
    episodes = get_episodes(episode_file_path=config['episode_file'],
                            num_episodes=config['num_episodes'])

    # open the accuracy log before the try block so that the finally clause
    # never references an unbound name if open() itself fails
    acc_file = open(file=os.path.join(logdir, 'accuracy.txt'), mode='w')
    try:
        for i, episode_name in enumerate(episodes):
            X = eps_generator.generate_episode(episode_name=episode_name)

            # split into train and validation
            xt, yt, xv, yv = train_val_split(X=X,
                                             k_shot=config['k_shot'],
                                             shuffle=True)

            # move data to gpu
            x_t = torch.from_numpy(xt).float().to(device)
            y_t = torch.tensor(yt, dtype=torch.long, device=device)
            x_v = torch.from_numpy(xv).float().to(device)
            y_v = torch.tensor(yv, dtype=torch.long, device=device)

            # -------------------------
            # functional base network
            # -------------------------
            f_base_net = get_f_base_net_fn(base_net=base_net)

            # -------------------------
            # adapt on the support data
            # -------------------------
            f_hyper_net = adapt_to_episode(x=x_t,
                                           y=y_t,
                                           hyper_net=hyper_net,
                                           f_base_net=f_base_net)

            # -------------------------
            # accuracy
            # -------------------------
            acc_temp = get_accuracy_fn(x=x_v,
                                       y=y_v,
                                       f_hyper_net=f_hyper_net,
                                       f_base_net=f_base_net)

            acc.append(acc_temp)
            acc_file.write('{}\n'.format(acc_temp))

            # move the cursor up one line and overwrite the progress counter
            sys.stdout.write('\033[F')
            print(i + 1)
    finally:
        acc_file.close()

    acc_mean = np.mean(acc)
    acc_std = np.std(acc)
    print('Accuracy = {0:.4f} +/- {1:.4f}'.format(
        acc_mean, 1.96 * acc_std / np.sqrt(len(episodes))))

    return None
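Example #6 receives get_f_base_net_fn as an injected callable. In MAML-style
code built on the `higher` library, a functional base network is typically
obtained along these lines (a sketch under that assumption; not necessarily
this repository's exact helper):

import higher
import torch

def get_f_base_net_fn(base_net: torch.nn.Module):
    # a monkey-patched module can be called with explicit parameter tensors:
    #   logits = f_base_net.forward(x_v, params=adapted_params)
    return higher.patch.monkeypatch(module=base_net)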
Example #7
def train(hyper_net_cls, get_f_base_net_fn: typing.Callable,
          adapt_to_episode: typing.Callable,
          loss_on_query_fn: typing.Callable) -> None:
    """Base method used for training

    Args:

    """
    # initialize/load model
    hyper_net, base_net, meta_opt, schdlr = load_model(
        hyper_net_cls=hyper_net_cls,
        epoch_id=config['resume_epoch'],
        meta_lr=config['meta_lr'],
        decay_lr=config['decay_lr'])

    # zero grad
    meta_opt.zero_grad()

    # get list of episode names, each episode name consists of classes
    episodes = get_episodes(episode_file_path=config['episode_file'])

    # initialize a tensorboard summary writer for logging
    tb_writer = SummaryWriter(
        log_dir=logdir,
        purge_step=config['resume_epoch'] * config['num_episodes_per_epoch'] //
        minibatch_print if config['resume_epoch'] > 0 else None)

    try:
        for epoch_id in range(config['resume_epoch'],
                              config['resume_epoch'] + config['num_epochs'],
                              1):
            episode_count = 0
            loss_monitor = 0
            # kl_div_monitor = 0

            while (episode_count < config['num_episodes_per_epoch']):
                # get episode from the given csv file, or just return None
                episode_name = random.sample(population=episodes, k=1)[0]

                X = eps_generator.generate_episode(episode_name=episode_name)

                # split into train and validation
                xt, yt, xv, yv = train_val_split(X=X,
                                                 k_shot=config['k_shot'],
                                                 shuffle=True)

                # move data to gpu
                x_t = torch.from_numpy(xt).float().to(device)
                y_t = torch.tensor(yt, dtype=torch.long, device=device)
                x_v = torch.from_numpy(xv).float().to(device)
                y_v = torch.tensor(yv, dtype=torch.long, device=device)

                # -------------------------
                # functional base network
                # -------------------------
                f_base_net = get_f_base_net_fn(base_net=base_net)

                # -------------------------
                # adapt on the support data
                # -------------------------
                f_hyper_net = adapt_to_episode(x=x_t,
                                               y=y_t,
                                               hyper_net=hyper_net,
                                               f_base_net=f_base_net)

                # -------------------------
                # loss on query data
                # -------------------------
                loss_meta = loss_on_query_fn(x=x_v,
                                             y=y_v,
                                             f_hyper_net=f_hyper_net,
                                             f_base_net=f_base_net,
                                             hyper_net=hyper_net)

                if torch.isnan(loss_meta):
                    raise ValueError('Validation loss is NaN.')

                loss_meta = loss_meta / config['minibatch']
                loss_meta.backward()

                # monitoring validation loss
                loss_monitor += loss_meta.item()
                # kl_div_monitor += kl_loss.item()

                episode_count += 1
                # update the meta-model
                if (episode_count % config['minibatch'] == 0):
                    # torch.nn.utils.clip_grad_norm_(parameters=hyper_net.parameters(), max_norm=10)
                    meta_opt.step()
                    meta_opt.zero_grad()

                # monitor losses
                if (episode_count % minibatch_print == 0):
                    loss_monitor /= minibatch_print
                    # kl_div_monitor /= minibatch_print

                    # print('{}, {}'.format(loss_monitor, kl_div_monitor))
                    # print(loss_monitor)

                    global_step = (epoch_id * config['num_episodes_per_epoch']
                                   + episode_count) // minibatch_print
                    tb_writer.add_scalar(tag='Loss',
                                         scalar_value=loss_monitor,
                                         global_step=global_step)

                    loss_monitor = 0
                    # kl_div_monitor = 0

            # decay learning rate
            schdlr.step()

            # save model
            checkpoint = {
                'hyper_net_state_dict': hyper_net.state_dict(),
                'op_state_dict': meta_opt.state_dict(),
                'lr_schdlr_state_dict': schdlr.state_dict()
            }
            checkpoint_filename = 'Epoch_{0:d}.pt'.format(epoch_id + 1)
            torch.save(checkpoint, os.path.join(logdir, checkpoint_filename))
            del checkpoint  # release the reference before the next epoch
            print('SAVING parameters into {0:s}\n'.format(checkpoint_filename))
    finally:
        print('\nClose tensorboard summary writer')
        tb_writer.close()

    return None
    def load_maml_like_model(
        self,
        resume_epoch: int = None,
        **kwargs
    ) -> typing.Tuple[torch.nn.Module,
                      typing.Optional[higher.patch._MonkeyPatchBase],
                      torch.optim.Optimizer]:
        """Initialize or load the hyper-net and base-net models

        Args:
            resume_epoch: the index of the file containing the saved model
            kwargs: must contain 'hyper_net_class' (IdentityNet for MAML or
                NormalVariationalNet for VAMPIRE) and 'eps_generator'

        Returns: a tuple consisting of
            hyper_net: the hyper neural network
            base_net: the base neural network
            meta_opt: the optimizer for the meta-parameters
        """
        if resume_epoch is None:
            resume_epoch = self.config['resume_epoch']

        if self.config['network_architecture'] == 'CNN':
            base_net = CNN(dim_output=self.config['min_way'],
                           bn_affine=self.config['batchnorm'])
        elif self.config['network_architecture'] == 'ResNet18':
            base_net = ResNet18(dim_output=self.config['min_way'],
                                bn_affine=self.config['batchnorm'])
        elif self.config['network_architecture'] == 'MiniCNN':
            base_net = MiniCNN(dim_output=self.config['min_way'],
                               bn_affine=self.config['batchnorm'])
        else:
            raise NotImplementedError(
                'Network architecture is unknown. Please implement it in the CommonModels.py.'
            )

        # ---------------------------------------------------------------
        # run a dummy task to initialize lazy modules defined in base_net
        # ---------------------------------------------------------------
        eps_data = kwargs['eps_generator'].generate_episode(episode_name=None)
        # split data into train and validation
        xt, _, _, _ = train_val_split(X=eps_data,
                                      k_shot=self.config['k_shot'],
                                      shuffle=True)
        # convert numpy data into torch tensor
        x_t = torch.from_numpy(xt).float()
        # run to initialize lazy modules
        base_net(x_t)
        params = torch.nn.utils.parameters_to_vector(
            parameters=base_net.parameters())
        print('Number of parameters of the base network = {0:d}.\n'.format(
            params.numel()))

        hyper_net = kwargs['hyper_net_class'](base_net=base_net)

        # move to device
        base_net.to(self.config['device'])
        hyper_net.to(self.config['device'])

        # optimizer
        meta_opt = torch.optim.Adam(params=hyper_net.parameters(),
                                    lr=self.config['meta_lr'])

        # load model if there is saved file
        if resume_epoch > 0:
            # path to the saved file
            checkpoint_path = os.path.join(
                self.config['logdir'], 'Epoch_{0:d}.pt'.format(resume_epoch))

            # load file
            saved_checkpoint = torch.load(
                f=checkpoint_path,
                map_location=lambda storage, loc: storage.cuda(
                    self.config['device'].index)
                if self.config['device'].type == 'cuda' else storage)

            # load state dictionaries
            hyper_net.load_state_dict(
                state_dict=saved_checkpoint['hyper_net_state_dict'])
            meta_opt.load_state_dict(
                state_dict=saved_checkpoint['opt_state_dict'])

            # update learning rate
            for param_group in meta_opt.param_groups:
                if param_group['lr'] != self.config['meta_lr']:
                    param_group['lr'] = self.config['meta_lr']

        return hyper_net, base_net, meta_opt
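The hyper_net_class expected in kwargs can be as simple as an identity mapping
over the base-net parameters when running plain MAML. A hypothetical minimal
version (a sketch; the repository's actual IdentityNet may differ):

import typing
import torch

class IdentityNet(torch.nn.Module):
    """Hyper-net whose parameters are the base-net's meta-parameters."""

    def __init__(self, base_net: torch.nn.Module) -> None:
        super().__init__()
        self.params = torch.nn.ParameterList(
            [torch.nn.Parameter(p.clone().detach())
             for p in base_net.parameters()])

    def forward(self) -> typing.List[torch.Tensor]:
        # return the meta-parameters; adaptation produces the task copies
        return [p for p in self.params]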
    def train(
        self, eps_generator: typing.Union[OmniglotLoader,
                                          ImageFolderGenerator]) -> None:
        """Train meta-learning model

        Args:
            eps_generator: the generator that generate episodes/tasks
        """
        print('Training started.\nLogs are stored at {0:s}.\n'.format(
            self.config['logdir']))

        # initialize/load model. Please see the load_model method implemented in each specific class for further information about the model
        model = self.load_model(resume_epoch=self.config['resume_epoch'],
                                hyper_net_class=self.hyper_net_class,
                                eps_generator=eps_generator)
        model[-1].zero_grad()

        # get list of episode names, each episode name consists of classes
        eps = get_episodes(episode_file_path=self.config['episode_file'])

        # initialize a tensorboard summary writer for logging
        tb_writer = SummaryWriter(log_dir=self.config['logdir'],
                                  purge_step=self.config['resume_epoch'] *
                                  self.config['num_episodes_per_epoch'] //
                                  self.config['minibatch_print']
                                  if self.config['resume_epoch'] > 0 else None)

        try:
            for epoch_id in range(
                    self.config['resume_epoch'],
                    self.config['resume_epoch'] + self.config['num_epochs'],
                    1):
                loss_monitor = 0.
                KL_monitor = 0.
                for eps_count in range(self.config['num_episodes_per_epoch']):
                    # -------------------------
                    # get eps from the given csv file or just random (None)
                    # -------------------------
                    eps_name = random.sample(population=eps, k=1)[0]

                    # -------------------------
                    # episode data
                    # -------------------------
                    eps_data = eps_generator.generate_episode(
                        episode_name=eps_name)

                    # split data into train and validation
                    xt, yt, xv, yv = train_val_split(
                        X=eps_data, k_shot=self.config['k_shot'], shuffle=True)

                    # move data to GPU (if there is a GPU)
                    x_t = torch.from_numpy(xt).float().to(
                        self.config['device'])
                    y_t = torch.tensor(yt,
                                       dtype=torch.long,
                                       device=self.config['device'])
                    x_v = torch.from_numpy(xv).float().to(
                        self.config['device'])
                    y_v = torch.tensor(yv,
                                       dtype=torch.long,
                                       device=self.config['device'])

                    # -------------------------
                    # adapt and predict the support data
                    # -------------------------
                    f_hyper_net, logits = self.adapt_and_predict(model=model,
                                                                 x_t=x_t,
                                                                 y_t=y_t,
                                                                 x_v=x_v,
                                                                 y_v=y_v)
                    loss_v = 0.
                    for logits_ in logits:
                        loss_v_temp = torch.nn.functional.cross_entropy(
                            input=logits_, target=y_v)
                        loss_v = loss_v + loss_v_temp
                    loss_v = loss_v / len(logits)
                    loss_monitor += loss_v.item()  # monitor validation loss

                    # calculate KL divergence
                    KL_div = self.KL_divergence(model=model,
                                                f_hyper_net=f_hyper_net)
                    KL_monitor += KL_div.item() if isinstance(
                        KL_div,
                        torch.Tensor) else KL_div  # monitor KL divergence

                    # extra loss applicable for ABML only
                    loss_extra = self.loss_extra(model=model,
                                                 f_hyper_net=f_hyper_net,
                                                 x_t=x_t,
                                                 y_t=y_t)

                    # accumulate KL divergence to loss
                    loss_v = loss_v + loss_extra + self.config[
                        'KL_weight'] * KL_div
                    loss_v = loss_v / self.config['minibatch']

                    # calculate gradients w.r.t. hyper_net's parameters
                    loss_v.backward()

                    # update meta-parameters
                    if ((eps_count + 1) % self.config['minibatch'] == 0):
                        loss_prior = self.loss_prior(model=model)
                        # back-propagate only when the prior loss is a tensor
                        # (a plain number means no prior-regularization term)
                        if isinstance(loss_prior, torch.Tensor):
                            loss_prior.backward()

                        model[-1].step()
                        model[-1].zero_grad()

                        # monitoring
                        if (eps_count +
                                1) % self.config['minibatch_print'] == 0:
                            loss_monitor /= self.config['minibatch_print']
                            KL_monitor = KL_monitor * self.config[
                                'minibatch'] / self.config['minibatch_print']

                            # calculate step for Tensorboard Summary Writer
                            global_step = (
                                epoch_id *
                                self.config['num_episodes_per_epoch'] +
                                eps_count +
                                1) // self.config['minibatch_print']

                            tb_writer.add_scalar(tag='Cls loss',
                                                 scalar_value=loss_monitor,
                                                 global_step=global_step)
                            tb_writer.add_scalar(tag='KL divergence',
                                                 scalar_value=KL_monitor,
                                                 global_step=global_step)

                            # reset monitoring variables
                            loss_monitor = 0.
                            KL_monitor = 0.

                # save model
                checkpoint = {
                    'hyper_net_state_dict': model[0].state_dict(),
                    'opt_state_dict': model[-1].state_dict()
                }
                checkpoint_path = os.path.join(
                    self.config['logdir'],
                    'Epoch_{0:d}.pt'.format(epoch_id + 1))
                torch.save(obj=checkpoint, f=checkpoint_path)
                print('State dictionaries are saved into {0:s}\n'.format(
                    checkpoint_path))

            print('Training is completed.')
        finally:
            print('\nClose tensorboard summary writer')
            tb_writer.close()

        return None
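Putting the class-based API together, a hypothetical driver (the concrete
class name Maml, the config keys, and the OmniglotLoader constructor
arguments are all assumptions, not the repository's documented interface):

import torch

config = {
    'resume_epoch': 0, 'num_epochs': 10, 'num_episodes_per_epoch': 10000,
    'k_shot': 1, 'minibatch': 5, 'minibatch_print': 500, 'KL_weight': 1e-5,
    'meta_lr': 1e-3, 'episode_file': None, 'logdir': './logs',
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
}
learner = Maml(config=config)  # hypothetical subclass implementing the API
eps_generator = OmniglotLoader(root='./data/omniglot', k_shot=config['k_shot'])
learner.train(eps_generator=eps_generator)
accuracies = learner.evaluate(eps_generator=eps_generator)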