def loss_function(self, recon_x, x, y_hat=None, y_target=None, scores=None, mu=None, logvar=None):
    '''Calculate and return various losses that could be used for training and/or evaluating the model.

    INPUT:  - [recon_x]     <4D-tensor> reconstructed image in same shape as [x]
            - [x]           <4D-tensor> original image
            - [y_hat]       <2D-tensor> with predicted "logits" for each class
            - [y_target]    <1D-tensor> with target-classes (as integers)
            - [scores]      <2D-tensor> with target "logits" for each class
            - [mu]          <2D-tensor> with either [z] or the estimated mean of [z]
            - [logvar]      None or <2D-tensor> with estimated log(SD^2) of [z]

    OUTPUT: - [reconL]      reconstruction loss indicating how well [x] and [recon_x] match
            - [variatL]     variational (KL-divergence) loss "indicating how normally distributed [z] is"
            - [predL]       prediction loss indicating how well targets [y] are predicted
            - [distilL]     knowledge distillation (KD) loss indicating how well the predicted "logits"
                            ([y_hat]) match the target "logits" ([scores])'''

    batch_size = x.size(0)

    ###-----Reconstruction loss-----###
    reconL = self.recon_criterion(recon_x.view(batch_size, -1), x.view(batch_size, -1))

    ###-----Variational loss-----###
    if logvar is not None:
        #---- see Appendix B from: Kingma & Welling. Auto-Encoding Variational Bayes. ICLR, 2014 ----#
        variatL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
        # -normalise by the same number of elements as in the reconstruction
        variatL /= (self.image_channels * self.image_size ** 2)
        # --> because self.recon_criterion averages not only over batch-size but also over all pixels/elements in the reconstruction
    else:
        variatL = torch.tensor(0., device=self._device())

    ###-----Prediction loss-----###
    if y_target is not None:
        predL = F.cross_entropy(y_hat, y_target, reduction='mean')
    else:
        predL = torch.tensor(0., device=self._device())

    ###-----Distillation loss-----###
    if scores is not None:
        n_classes_to_consider = y_hat.size(1)  # --> zeroes will be added to [scores] to make its size match [y_hat]
        distilL = utils.loss_fn_kd(scores=y_hat[:, :n_classes_to_consider], target_scores=scores, T=self.KD_temp)
    else:
        distilL = torch.tensor(0., device=self._device())

    # Return a tuple of the calculated losses
    return reconL, variatL, predL, distilL
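# The function above delegates the distillation term to utils.loss_fn_kd, which is not shown in
# this file. Below is a minimal sketch of such a temperature-scaled KD loss; the signature
# (scores, target_scores, T) matches the call above, but the body is an assumption based on the
# standard soft-target cross-entropy, not the repository's own implementation.
import torch
import torch.nn.functional as F

def loss_fn_kd_sketch(scores, target_scores, T=2.):
    """Soft-target cross-entropy between student [scores] and teacher [target_scores] (both logits)."""
    log_p = F.log_softmax(scores / T, dim=1)      # student log-probabilities at temperature T
    q = F.softmax(target_scores / T, dim=1)       # teacher probabilities at temperature T
    # if the teacher covers fewer classes than the student, pad its distribution with zeros
    if q.size(1) < log_p.size(1):
        pad = torch.zeros(q.size(0), log_p.size(1) - q.size(1), device=q.device)
        q = torch.cat([q, pad], dim=1)
    kd_loss = -(q * log_p).sum(dim=1).mean()      # cross-entropy, averaged over the batch
    return kd_loss * T ** 2                       # scale by T^2 so gradient magnitudes are preserved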
def train_kd(loader, model, loss_fn_kd, optimizer, epoch, alpha, temperature, use_cuda):
    global BEST_ACC, LR_STATE

    # switch to train mode (unless batch-norm layers are frozen)
    if not cfg.CLS.fix_bn:
        model.train()
    else:
        model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for batch_idx, (inputs, targets) in enumerate(loader):
        # adjust learning rate
        adjust_learning_rate(optimizer, epoch, batch=batch_idx, batch_per_epoch=len(loader))

        # measure data loading time
        data_time.update(time.time() - end)
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)

        # forward pass: a single forward returns both student and teacher outputs
        model_out = model(inputs)
        outputs, teacher_outputs = model_out[0], model_out[1]

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss = loss_fn_kd(outputs, targets, teacher_outputs, alpha, temperature)
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))

        if (batch_idx + 1) % cfg.CLS.disp_iter == 0:
            print('Training: [{}/{}][{}/{}] | Best_Acc: {:4.2f}% | Time: {:.2f} | Data: {:.2f} | '
                  'LR: {:.8f} | Top1: {:.4f}% | Top5: {:.4f}% | Loss: {:.4f} | Total: {:.2f}'
                  .format(epoch + 1, cfg.CLS.epochs, batch_idx + 1, len(loader), BEST_ACC,
                          batch_time.average(), data_time.average(), LR_STATE,
                          top1.avg, top5.avg, losses.avg, batch_time.sum))

    return (losses.avg, top1.avg)
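# train_kd above is agnostic about the exact loss_fn_kd it receives; it only assumes the call
# signature loss_fn_kd(outputs, targets, teacher_outputs, alpha, temperature). Below is a minimal
# sketch of the usual Hinton-style combination of hard-label cross-entropy and soft-target
# KL-divergence, given as an illustration; it is an assumption, not the loss actually passed in.
import torch.nn.functional as F

def hinton_kd_loss(outputs, targets, teacher_outputs, alpha, temperature):
    """Weighted sum of soft-target KL-divergence (weight alpha) and hard-label CE (weight 1 - alpha)."""
    T = temperature
    soft_loss = F.kl_div(F.log_softmax(outputs / T, dim=1),
                         F.softmax(teacher_outputs / T, dim=1),
                         reduction='batchmean') * (alpha * T * T)
    hard_loss = F.cross_entropy(outputs, targets) * (1. - alpha)
    return soft_loss + hard_loss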
def loss_function(self, recon_x, x, y_hat=None, y_target=None, scores=None, mu=None, logvar=None):
    """Calculate and return various losses that could be used for training and/or evaluating the model.

    INPUT:  - [recon_x]     <4D-tensor> reconstructed image in same shape as [x]
            - [x]           <4D-tensor> original image
            - [y_hat]       <2D-tensor> with predicted "logits" for each class
            - [y_target]    <1D-tensor> with target-classes (as integers)
            - [scores]      <2D-tensor> with target "logits" for each class
            - [mu]          <2D-tensor> with either [z] or the estimated mean of [z]
            - [logvar]      None or <2D-tensor> with estimated log(SD^2) of [z]

    SETTING:- [self.average] <bool>; if True, both [reconL] and [variatL] are divided by the number of input elements

    OUTPUT: - [reconL]      reconstruction loss indicating how well [x] and [recon_x] match
            - [variatL]     variational (KL-divergence) loss "indicating how normally distributed [z] is"
            - [predL]       prediction loss indicating how well targets [y] are predicted
            - [distilL]     knowledge distillation (KD) loss indicating how well the predicted "logits"
                            ([y_hat]) match the target "logits" ([scores])"""

    ###-----Reconstruction loss-----###
    reconL = self.calculate_recon_loss(x=x, x_recon=recon_x, average=self.average)  # -> possibly average over pixels
    reconL = torch.mean(reconL)  # -> average over batch

    ###-----Variational loss-----###
    if logvar is not None:
        variatL = self.calculate_variat_loss(mu=mu, logvar=logvar)
        variatL = torch.mean(variatL)  # -> average over batch
        if self.average:
            variatL /= (self.image_channels * self.image_size ** 2)  # -> divide by # of input-pixels, if [self.average]
    else:
        variatL = torch.tensor(0., device=self._device())

    ###-----Prediction loss-----###
    if y_target is not None:
        predL = F.cross_entropy(y_hat, y_target, reduction='mean')  # -> average over batch
    else:
        predL = torch.tensor(0., device=self._device())

    ###-----Distillation loss-----###
    if scores is not None:
        n_classes_to_consider = y_hat.size(1)  # --> zeroes will be added to [scores] to make its size match [y_hat]
        distilL = utils.loss_fn_kd(scores=y_hat[:, :n_classes_to_consider], target_scores=scores, T=self.KD_temp)
    else:
        distilL = torch.tensor(0., device=self._device())

    # Return a tuple of the calculated losses
    return reconL, variatL, predL, distilL
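# The variant above delegates the per-sample terms to self.calculate_recon_loss and
# self.calculate_variat_loss, neither of which appears in this file. The sketches below show what
# such helpers could look like, assuming a binary cross-entropy reconstruction criterion and a
# standard-normal prior on [z]; both the names and the details are assumptions.
import torch
import torch.nn.functional as F

def calculate_recon_loss_sketch(x, x_recon, average=False):
    """Per-sample reconstruction loss; summed (or averaged) over the pixels of each image."""
    batch_size = x.size(0)
    reconL = F.binary_cross_entropy(x_recon.view(batch_size, -1), x.view(batch_size, -1),
                                    reduction='none')
    return reconL.mean(dim=1) if average else reconL.sum(dim=1)

def calculate_variat_loss_sketch(mu, logvar):
    """Per-sample KL-divergence between q(z|x)=N(mu, exp(logvar)) and the prior N(0, I).
    See Appendix B of Kingma & Welling (2014): -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)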
def train_a_batch(self, x, y, x_=None, y_=None, scores_=None, rnt=0.5, active_classes=None, task=1):
    '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_/scores_]).

    [x]               <tensor> batch of inputs (could be None, in which case only 'replayed' data is used)
    [y]               <tensor> batch of corresponding labels
    [x_]              None or (<list> of) <tensor> batch of replayed inputs
    [y_]              None or (<list> of) <tensor> batch of corresponding "replayed" labels
    [scores_]         None or (<list> of) <2D-tensor> [batch]x[classes] predicted "scores"/"logits" for [x_]
    [rnt]             <number> in [0,1], relative importance of new task
    [active_classes]  None or (<list> of) <list> with "active" classes'''

    # Set model to training-mode
    self.train()

    ##--(1)-- CURRENT DATA --##
    if x is not None:
        # If requested, apply correct task-specific mask
        if self.mask_dict is not None:
            self.apply_XdGmask(task=task)

        # Run model
        y_hat = self(x)
        # -if needed, remove predictions for classes not in current task
        if active_classes is not None:
            class_entries = active_classes[-1] if type(active_classes[0]) == list else active_classes
            y_hat = y_hat[:, class_entries]

        # Calculate prediction loss
        predL = None if y is None else F.cross_entropy(y_hat, y)
        # Calculate training-precision
        precision = None if y is None else (y == y_hat.max(1)[1]).sum().item() / x.size(0)
    else:
        precision = predL = None
        # -> it's possible there is only "replay" [i.e., for offline with incremental task learning]

    ##--(2)-- REPLAYED DATA --##
    if x_ is not None:
        # If [x_] is a list, perform separate replay for each entry
        n_replays = len(x_) if type(x_) == list else 1
        if not type(x_) == list:
            x_ = [x_]
            y_ = [y_]
            scores_ = [scores_]
            active_classes = [active_classes] if (active_classes is not None) else None

        # Prepare lists to store losses for each replay
        loss_replay = [None] * n_replays
        predL_r = [None] * n_replays
        distilL_r = [None] * n_replays

        # Loop to perform each replay
        for replay_id in range(n_replays):
            # Run model
            y_hat = self(x_[replay_id])
            # -if needed (e.g., incremental/multihead set-up), remove predictions for classes not in replayed task
            if active_classes is not None:
                y_hat = y_hat[:, active_classes[replay_id]]

            # Calculate losses
            if (y_ is not None) and (y_[replay_id] is not None):
                predL_r[replay_id] = F.cross_entropy(y_hat, y_[replay_id])
            if (scores_ is not None) and (scores_[replay_id] is not None):
                distilL_r[replay_id] = utils.loss_fn_kd(scores=y_hat, target_scores=scores_[replay_id],
                                                        T=self.KD_temp)

            # Weigh losses
            if self.replay_targets == "hard":
                loss_replay[replay_id] = predL_r[replay_id]
            elif self.replay_targets == "soft":
                loss_replay[replay_id] = distilL_r[replay_id]

    # Calculate total loss
    if x is None:
        loss_total = sum(loss_replay) / n_replays
    elif x_ is None:
        loss_total = predL
    else:
        loss_replay = sum(loss_replay) / n_replays
        loss_total = rnt * predL + (1 - rnt) * loss_replay

    ##--(3)-- ALLOCATION LOSSES --##
    # Add SI-loss (Zenke et al., 2017)
    surrogate_loss = self.surrogate_loss()
    if self.si_c > 0:
        loss_total += self.si_c * surrogate_loss
    # Add EWC-loss
    ewc_loss = self.ewc_loss()
    if self.ewc_lambda > 0:
        loss_total += self.ewc_lambda * ewc_loss

    # Reset optimizer
    self.optimizer.zero_grad()
    # Backpropagate errors
    loss_total.backward()
    # Take optimization-step
    self.optimizer.step()

    # Return a dictionary with the different training-losses split into categories
    return {
        'loss_total': loss_total.item(),
        'pred': predL.item() if predL is not None else 0,
        'pred_r': sum(predL_r).item() / n_replays if (x_ is not None and predL_r[0] is not None) else 0,
        'distil_r': sum(distilL_r).item() / n_replays if (x_ is not None and distilL_r[0] is not None) else 0,
        'ewc': ewc_loss.item(),
        'si_loss': surrogate_loss.item(),
        'precision': precision if precision is not None else 0.,
    }
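# A hypothetical call illustrating how train_a_batch above combines current and replayed data;
# the variable names and the rnt choice are placeholders for illustration, not taken from the source.
loss_dict = model.train_a_batch(
    x, y,                            # current-task batch
    x_=x_replay,                     # replayed inputs (stored or generated from previous tasks)
    y_=None,                         # no hard replay labels ...
    scores_=previous_model_logits,   # ... but soft targets from the previous model ("soft" replay)
    rnt=1. / task,                   # weigh the new task by 1/task (a common choice)
    active_classes=active_classes,
    task=task)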
def train_kd(model, teacher_model, optimizer, loss_fn_kd, T, alpha):
    # set student model to training mode, teacher to evaluation mode
    model.train()
    teacher_model.eval()

    lr = cfg.LR
    batch_size = cfg.BATCH_SIZE

    # number of batches per epoch
    epoch_size = len(train_datasets) // batch_size
    # train for MAX_EPOCH epochs
    max_iter = cfg.MAX_EPOCH * epoch_size
    start_iter = cfg.RESUME_EPOCH * epoch_size
    epoch = cfg.RESUME_EPOCH

    # parameters for the cosine learning-rate schedule (only used by the commented-out call below)
    warmup_epoch = 5
    warmup_steps = warmup_epoch * epoch_size
    global_step = 0

    # parameters for the step learning-rate schedule
    stepvalues = (10 * epoch_size, 20 * epoch_size, 30 * epoch_size)
    step_index = 0

    for iteration in range(start_iter, max_iter):
        global_step += 1

        # refresh the batch iterator at the start of each epoch
        if iteration % epoch_size == 0:
            # create batch iterator
            batch_iterator = iter(train_dataloader)
            loss = 0
            epoch += 1
            # save a checkpoint every 5 epochs
            if epoch % 5 == 0 and epoch > 0:
                if cfg.GPUS > 1:
                    checkpoint = {'model': model.module,
                                  'model_state_dict': model.module.state_dict(),
                                  # 'optimizer_state_dict': optimizer.state_dict(),
                                  'epoch': epoch}
                else:
                    checkpoint = {'model': model,
                                  'model_state_dict': model.state_dict(),
                                  # 'optimizer_state_dict': optimizer.state_dict(),
                                  'epoch': epoch}
                torch.save(checkpoint, os.path.join(save_folder, 'epoch_{}.pth'.format(epoch)))

        if iteration in stepvalues:
            step_index += 1
        # adjust learning rate (step schedule)
        lr = adjust_learning_rate_step(optimizer, cfg.LR, 0.1, epoch, step_index, iteration, epoch_size)
        # alternative: cosine schedule with warm-up
        # lr = adjust_learning_rate_cosine(optimizer, global_step=global_step,
        #                                  learning_rate_base=cfg.LR,
        #                                  total_steps=max_iter,
        #                                  warmup_steps=warmup_steps)

        # fetch images and labels
        # try:
        images, labels = next(batch_iterator)
        # except StopIteration:
        #     continue

        # since PyTorch 0.4, Variable has been merged into Tensor, so no Variable wrapping is needed
        if torch.cuda.is_available():
            images, labels = images.cuda(), labels.cuda()

        # teacher forward pass (no gradients needed for the frozen teacher)
        with torch.no_grad():
            teacher_outputs = teacher_model(images)
        out = model(images)
        loss = loss_fn_kd(out, labels, teacher_outputs, T, alpha)

        optimizer.zero_grad()  # clear gradients, otherwise they accumulate across backward passes
        loss.backward()        # backpropagate the loss
        optimizer.step()       # update the parameters

        prediction = torch.max(out, 1)[1]
        train_correct = (prediction == labels).sum()
        # train_correct is a LongTensor and has to be converted to float
        train_acc = train_correct.float() / batch_size

        if iteration % 10 == 0:
            print('Epoch:' + repr(epoch) + ' || epochiter: ' + repr(iteration % epoch_size) + '/' +
                  repr(epoch_size) + ' || Total iter ' + repr(iteration) +
                  ' || Loss: %.6f ||' % (loss.item()) + ' ACC: %.3f ||' % (train_acc * 100) +
                  ' LR: %.8f' % (lr))
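# A hypothetical driver for the function above; the factory functions and the KD hyper-parameters
# (T=4.0, alpha=0.9) are placeholders chosen for illustration, not values taken from the source.
if __name__ == '__main__':
    student = build_student_model()        # assumed factory functions, not defined in this file
    teacher = load_pretrained_teacher()
    if torch.cuda.is_available():
        student, teacher = student.cuda(), teacher.cuda()
    optimizer = torch.optim.SGD(student.parameters(), lr=cfg.LR, momentum=0.9, weight_decay=5e-4)
    train_kd(student, teacher, optimizer, loss_fn_kd, T=4.0, alpha=0.9)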