def loss_function(self,
                      recon_x,
                      x,
                      y_hat=None,
                      y_target=None,
                      scores=None,
                      mu=None,
                      logvar=None):
        '''Calculate and return various losses that could be used for training and/or evaluating the model.

        INPUT:  - [x_recon]         <4D-tensor> reconstructed image in same shape as [x]
                - [x]               <4D-tensor> original image
                - [y_hat]           <2D-tensor> with predicted "logits" for each class
                - [y_target]        <1D-tensor> with target-classes (as integers)
                - [scores]          <2D-tensor> with target "logits" for each class
                - [mu]              <2D-tensor> with either [z] or the estimated mean of [z]
                - [logvar]          None or <2D-tensor> with estimated log(SD^2) of [z]

        OUTPUT: - [reconL]       reconstruction loss indicating how well [x] and [recon_x] match
                - [variatL]      variational (KL-divergence) loss "indicating how normally distributed [z] is"
                - [predL]        prediction loss indicating how well targets [y] are predicted
                - [distilL]      knowledge distillation (KD) loss indicating how well the predicted "logits" ([y_hat])
                                     match the target "logits" ([scores])'''

        batch_size = x.size(0)

        ###-----Reconstruction loss-----###
        reconL = self.recon_criterion(recon_x.view(batch_size, -1),
                                      x.view(batch_size, -1))

        ###-----Variational loss-----###
        if logvar is not None:
            #---- see Appendix B from: Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 ----#
            variatL = -0.5 * torch.sum(1 + logvar - mu.pow(2) -
                                       logvar.exp()) / batch_size
            # -normalise by same number of elements as in reconstruction
            variatL /= (self.image_channels * self.image_size**2)
            # --> because self.recon_criterion averages over batch-size but also over all pixels/elements in recon!!
        else:
            variatL = torch.tensor(0., device=self._device())

        ###-----Prediction loss-----###
        if y_target is not None:
            predL = F.cross_entropy(y_hat, y_target, reduction='mean')  # -> average over batch
        else:
            predL = torch.tensor(0., device=self._device())

        ###-----Distillation loss-----###
        if scores is not None:
            n_classes_to_consider = y_hat.size(1)
            # --> zeroes will be added to [scores] to make its size match [y_hat]
            distilL = utils.loss_fn_kd(scores=y_hat[:, :n_classes_to_consider],
                                       target_scores=scores,
                                       T=self.KD_temp)
        else:
            distilL = torch.tensor(0., device=self._device())

        # Return a tuple of the calculated losses
        return reconL, variatL, predL, distilL
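# NOTE (illustrative sketch): utils.loss_fn_kd, used for the distillation term above, is not
# shown in this listing. A minimal temperature-scaled soft-target loss in the same spirit
# could look like the following; the zero-padding of [target_scores] and the exact reduction
# are assumptions, not necessarily the repo's implementation.
import torch
import torch.nn.functional as F

def loss_fn_kd_sketch(scores, target_scores, T=2.):
    """Soft-target cross-entropy between temperature-scaled logits, averaged over the batch."""
    # pad [target_scores] with zeros if it covers fewer classes than [scores] (assumption)
    n_extra = scores.size(1) - target_scores.size(1)
    if n_extra > 0:
        zeros = torch.zeros(target_scores.size(0), n_extra, device=target_scores.device)
        target_scores = torch.cat([target_scores, zeros], dim=1)
    log_p = F.log_softmax(scores / T, dim=1)        # "student" log-probabilities
    q = F.softmax(target_scores / T, dim=1)         # "teacher" probabilities (targets)
    kd_loss = -(q * log_p).sum(dim=1).mean()        # soft-target cross-entropy, averaged over batch
    return kd_loss * T ** 2                         # rescale by T^2 (Hinton et al., 2015)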
Example #2
def train_kd(loader, model, loss_fn_kd, optimizer, epoch, alpha, temperature, use_cuda):
    global BEST_ACC, LR_STATE
    # switch to train mode
    if not cfg.CLS.fix_bn:
        model.train()
    else:
        model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    for batch_idx, (inputs, targets) in enumerate(loader):
        # adjust learning rate
        adjust_learning_rate(optimizer, epoch, batch=batch_idx, batch_per_epoch=len(loader))

        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)  # 'async' is a reserved word in Python 3.7+
        # (no Variable wrapping needed: Variable and Tensor were merged in PyTorch 0.4)

        # forward pass: a single call gives both student and teacher outputs
        model_out = model(inputs)
        outputs, teacher_outputs = model_out[0], model_out[1]
        # compute the distillation loss, then backpropagate and take an SGD step
        optimizer.zero_grad()
        loss = loss_fn_kd(outputs, targets, teacher_outputs, alpha, temperature)
        # backward
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))  # .item() replaces the deprecated loss.data[0]
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        if (batch_idx + 1) % cfg.CLS.disp_iter == 0:
            print('Training: [{}/{}][{}/{}] | Best_Acc: {:4.2f}% | Time: {:.2f} | Data: {:.2f} | '
                  'LR: {:.8f} | Top1: {:.4f}% | Top5: {:.4f}% | Loss: {:.4f} | Total: {:.2f}'
                  .format(epoch + 1, cfg.CLS.epochs, batch_idx + 1, len(loader), BEST_ACC, batch_time.average(),
                          data_time.average(), LR_STATE, top1.avg, top5.avg, losses.avg, batch_time.sum))

    return (losses.avg, top1.avg)
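# NOTE (illustrative sketch): the accuracy() helper used above is not shown in this listing.
# A common top-k accuracy routine in the style of the official PyTorch ImageNet example
# (an assumption, not necessarily this repo's exact implementation) looks like:
def accuracy_sketch(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k, as percentages."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)  # [batch] x [maxk] class indices
    pred = pred.t()                                                # [maxk] x [batch]
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res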
Example #3
    def loss_function(self, recon_x, x, y_hat=None, y_target=None, scores=None, mu=None, logvar=None):
        """Calculate and return various losses that could be used for training and/or evaluating the model.

        INPUT:  - [recon_x]         <4D-tensor> reconstructed image in same shape as [x]
                - [x]               <4D-tensor> original image
                - [y_hat]           <2D-tensor> with predicted "logits" for each class
                - [y_target]        <1D-tensor> with target-classes (as integers)
                - [scores]          <2D-tensor> with target "logits" for each class
                - [mu]              <2D-tensor> with either [z] or the estimated mean of [z]
                - [logvar]          None or <2D-tensor> with estimated log(SD^2) of [z]

        SETTING:- [self.average]    <bool>, if True, both [reconL] and [variatL] are divided by number of input elements

        OUTPUT: - [reconL]       reconstruction loss indicating how well [x] and [recon_x] match
                - [variatL]      variational (KL-divergence) loss "indicating how normally distributed [z] is"
                - [predL]        prediction loss indicating how well targets [y] are predicted
                - [distilL]      knowledge distillation (KD) loss indicating how well the predicted "logits" ([y_hat])
                                     match the target "logits" ([scores])"""

        ###-----Reconstruction loss-----###
        reconL = self.calculate_recon_loss(x=x, x_recon=recon_x,
                                           average=self.average)  # -> possibly average over pixels
        reconL = torch.mean(reconL)  # -> average over batch

        ###-----Variational loss-----###
        if logvar is not None:
            variatL = self.calculate_variat_loss(mu=mu, logvar=logvar)
            variatL = torch.mean(variatL)  # -> average over batch
            if self.average:
                # -> divide by the number of input pixels, if [self.average]
                variatL /= (self.image_channels * self.image_size ** 2)
        else:
            variatL = torch.tensor(0., device=self._device())

        ###-----Prediction loss-----###
        if y_target is not None:
            predL = F.cross_entropy(y_hat, y_target, reduction='mean')  # -> average over batch
        else:
            predL = torch.tensor(0., device=self._device())

        ###-----Distillation loss-----###
        if scores is not None:
            n_classes_to_consider = y_hat.size(1)  # --> zeroes will be added to [scores] to make its size match [y_hat]
            distilL = utils.loss_fn_kd(scores=y_hat[:, :n_classes_to_consider], target_scores=scores, T=self.KD_temp)
        else:
            distilL = torch.tensor(0., device=self._device())

        # Return a tuple of the calculated losses
        return reconL, variatL, predL, distilL
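# NOTE (illustrative sketch): calculate_recon_loss and calculate_variat_loss are called above
# but not shown. Per-sample versions consistent with that usage, assuming a binary
# cross-entropy reconstruction criterion and the analytical KL term from Kingma & Welling
# (2014), might look like the following; the exact criteria are assumptions.
import torch
import torch.nn.functional as F

def calculate_recon_loss_sketch(x, x_recon, average=False):
    batch_size = x.size(0)
    recon_loss = F.binary_cross_entropy(x_recon.view(batch_size, -1),
                                        x.view(batch_size, -1), reduction='none')
    # sum (or average) over all pixels/elements, keeping one value per sample
    return recon_loss.mean(dim=1) if average else recon_loss.sum(dim=1)

def calculate_variat_loss_sketch(mu, logvar):
    # analytical KL( q(z|x) || N(0, I) ), summed over latent dimensions, per sample
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)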
Example #4
    def train_a_batch(self,
                      x,
                      y,
                      x_=None,
                      y_=None,
                      scores_=None,
                      rnt=0.5,
                      active_classes=None,
                      task=1):
        '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_/scores_]).

        [x]               <tensor> batch of inputs (could be None, in which case only 'replayed' data is used)
        [y]               <tensor> batch of corresponding labels
        [x_]              None or (<list> of) <tensor> batch of replayed inputs
        [y_]              None or (<list> of) <tensor> batch of corresponding "replayed" labels
        [scores_]         None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_]
        [rnt]             <number> in [0,1], relative importance of new task
        [active_classes]  None or (<list> of) <list> with "active" classes'''

        # Set model to training-mode
        self.train()

        ##--(1)-- CURRENT DATA --##

        if x is not None:
            # If requested, apply correct task-specific mask
            if self.mask_dict is not None:
                self.apply_XdGmask(task=task)

            # Run model
            y_hat = self(x)
            # -if needed, remove predictions for classes not in current task
            if active_classes is not None:
                class_entries = active_classes[-1] if type(active_classes[0]) == list else active_classes
                y_hat = y_hat[:, class_entries]

            # Calculate prediction loss
            predL = None if y is None else F.cross_entropy(y_hat, y)

            # Calculate training-precision
            precision = None if y is None else (
                y == y_hat.max(1)[1]).sum().item() / x.size(0)
        else:
            precision = predL = None
            # -> it's possible there is only "replay" [i.e., for offline with incremental task learning]

        ##--(2)-- REPLAYED DATA --##

        if x_ is not None:
            # If [x_] is a list, perform separate replay for each entry
            n_replays = len(x_) if type(x_) == list else 1
            if not type(x_) == list:
                x_ = [x_]
                y_ = [y_]
                scores_ = [scores_]
                active_classes = [active_classes] if (active_classes is not None) else None

            # Prepare lists to store losses for each replay
            loss_replay = [None] * n_replays
            predL_r = [None] * n_replays
            distilL_r = [None] * n_replays

            # Loop to perform each replay
            for replay_id in range(n_replays):

                # Run model
                y_hat = self(x_[replay_id])
                # -if needed (e.g., incremental/multihead set-up), remove predictions for classes not in replayed task
                if active_classes is not None:
                    y_hat = y_hat[:, active_classes[replay_id]]

                # Calculate losses
                if (y_ is not None) and (y_[replay_id] is not None):
                    predL_r[replay_id] = F.cross_entropy(y_hat, y_[replay_id])
                if (scores_ is not None) and (scores_[replay_id] is not None):
                    distilL_r[replay_id] = utils.loss_fn_kd(
                        scores=y_hat,
                        target_scores=scores_[replay_id],
                        T=self.KD_temp)
                # Weigh losses
                if self.replay_targets == "hard":
                    loss_replay[replay_id] = predL_r[replay_id]
                elif self.replay_targets == "soft":
                    loss_replay[replay_id] = distilL_r[replay_id]

        # Calculate total loss
        if x is None:
            loss_total = sum(loss_replay) / n_replays
        elif x_ is None:
            loss_total = predL
        else:
            loss_replay = sum(loss_replay) / n_replays
            loss_total = rnt * predL + (1 - rnt) * loss_replay

        ##--(3)-- ALLOCATION LOSSES --##

        # Add SI-loss (Zenke et al., 2017)
        surrogate_loss = self.surrogate_loss()
        if self.si_c > 0:
            loss_total += self.si_c * surrogate_loss

        # Add EWC-loss
        ewc_loss = self.ewc_loss()
        if self.ewc_lambda > 0:
            loss_total += self.ewc_lambda * ewc_loss

        # Reset optimizer
        self.optimizer.zero_grad()
        # Backpropagate errors
        loss_total.backward()
        # Take optimization-step
        self.optimizer.step()

        # Return the dictionary with different training-loss split in categories
        return {
            'loss_total': loss_total.item(),
            'pred': predL.item() if predL is not None else 0,
            'pred_r': sum(predL_r).item() / n_replays if (x_ is not None and predL_r[0] is not None) else 0,
            'distil_r': sum(distilL_r).item() / n_replays if (x_ is not None and distilL_r[0] is not None) else 0,
            'ewc': ewc_loss.item(),
            'si_loss': surrogate_loss.item(),
            'precision': precision if precision is not None else 0.,
        }
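# NOTE (illustrative usage sketch): a hypothetical training step that combines a current-task
# batch with a batch of replayed examples and their target "logits". All names other than
# train_a_batch itself (the loaders, previous_model, active_classes, task) are assumptions
# for illustration only, not part of the code above.
import torch

x, y = next(iter(current_task_loader))            # batch from the current task
x_, y_ = next(iter(replay_buffer_loader))         # batch of stored examples from earlier tasks
with torch.no_grad():
    scores_ = previous_model(x_)                  # target "logits" for the "soft" replay targets
loss_dict = model.train_a_batch(
    x, y, x_=x_, y_=y_, scores_=scores_,
    rnt=1. / task, active_classes=active_classes, task=task)
print(loss_dict['loss_total'], loss_dict['pred'], loss_dict['distil_r'])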
def train_kd(model, teacher_model, optimizer, loss_fn_kd, T, alpha):

    # set student model to training mode
    model.train()
    teacher_model.eval()

    lr = cfg.LR

    batch_size = cfg.BATCH_SIZE
    # number of batches in each epoch
    epoch_size = len(train_datasets) // batch_size
    # train for cfg.MAX_EPOCH epochs in total
    max_iter = cfg.MAX_EPOCH * epoch_size

    start_iter = cfg.RESUME_EPOCH * epoch_size

    epoch = cfg.RESUME_EPOCH

    # cosine learning-rate schedule (used by the commented-out call below)
    warmup_epoch = 5
    warmup_steps = warmup_epoch * epoch_size
    global_step = 0

    # step learning-rate schedule parameters
    stepvalues = (10 * epoch_size, 20 * epoch_size, 30 * epoch_size)
    step_index = 0

    for iteration in range(start_iter, max_iter):
        global_step += 1

        ## refresh the batch iterator at the start of every epoch
        if iteration % epoch_size == 0:
            # create batch iterator
            batch_iterator = iter(train_dataloader)
            loss = 0
            epoch += 1
            ### save a checkpoint
            if epoch % 5 == 0 and epoch > 0:
                if cfg.GPUS > 1:
                    checkpoint = {
                        'model': model.module,
                        'model_state_dict': model.module.state_dict(),
                        # 'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch
                    }
                    torch.save(
                        checkpoint,
                        os.path.join(save_folder,
                                     'epoch_{}.pth'.format(epoch)))
                else:
                    checkpoint = {
                        'model': model,
                        'model_state_dict': model.state_dict(),
                        # 'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch
                    }
                    torch.save(
                        checkpoint,
                        os.path.join(save_folder,
                                     'epoch_{}.pth'.format(epoch)))

        if iteration in stepvalues:
            step_index += 1
        lr = adjust_learning_rate_step(optimizer, cfg.LR, 0.1, epoch,
                                       step_index, iteration, epoch_size)

        ## adjust the learning rate (cosine alternative to the step schedule above, commented out)
        # lr = adjust_learning_rate_cosine(optimizer, global_step=global_step,
        #                           learning_rate_base=cfg.LR,
        #                           total_steps=max_iter,
        #                           warmup_steps=warmup_steps)

        ## fetch a batch of images and labels
        # try:
        images, labels = next(batch_iterator)
        # except:
        #     continue

        ## since PyTorch 0.4, Variable has been merged into Tensor, so no Variable wrapping is needed here
        if torch.cuda.is_available():
            images, labels = images.cuda(), labels.cuda()
        teacher_outputs = teacher_model(images)
        out = model(images)
        loss = loss_fn_kd(out, labels, teacher_outputs, T, alpha)

        optimizer.zero_grad()  # clear gradients; otherwise they accumulate across backward passes
        loss.backward()  # backpropagate the loss
        optimizer.step()  ## update the parameters

        prediction = torch.max(out, 1)[1]
        train_correct = (prediction == labels).sum()
        ## train_correct is a LongTensor here, so it has to be converted to float
        # print(train_correct.type())
        train_acc = (train_correct.float()) / batch_size

        if iteration % 10 == 0:
            print('Epoch:' + repr(epoch) + ' || epochiter: ' +
                  repr(iteration % epoch_size) + '/' + repr(epoch_size) +
                  '|| Total iter ' + repr(iteration) + ' || Loss: %.6f||' %
                  (loss.item()) + 'ACC: %.3f ||' % (train_acc * 100) +
                  'LR: %.8f' % (lr))
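# NOTE (illustrative sketch): adjust_learning_rate_step() is called above but not defined in
# this listing. A plain step-decay schedule matching the call site (base LR decayed by [gamma]
# each time a boundary in [stepvalues] is passed) could be written as follows; the unused
# arguments and the absence of warm-up handling are assumptions.
def adjust_learning_rate_step_sketch(optimizer, base_lr, gamma, epoch, step_index,
                                     iteration, epoch_size):
    # epoch / iteration / epoch_size are unused in this simple variant (they would matter for warm-up)
    lr = base_lr * (gamma ** step_index)   # decay by gamma at each step boundary
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr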