import torch

from avalanche.models import MultiTaskModule, avalanche_forward


def penalty(self, out, x, alpha, curr_model):
    """
    Compute weighted distillation loss.
    """
    if self.prev_model is None:
        return 0
    else:
        with torch.no_grad():
            if isinstance(self.prev_model, MultiTaskModule):
                # output from previous output heads.
                y_prev = avalanche_forward(self.prev_model, x, None)
            else:  # no task labels
                y_prev = {"0": self.prev_model(x)}

        if isinstance(self.prev_model, MultiTaskModule):
            # in a multitask scenario we need to compute the output
            # from all the heads, so we need to call forward again.
            # NOTE: this runs outside `torch.no_grad()` so that the
            # distillation loss can backpropagate through the current
            # model. TODO: can we avoid this second forward pass?
            y_curr = avalanche_forward(curr_model, x, None)
        else:  # no task labels
            y_curr = {"0": out}

        dist_loss = 0
        for task_id in y_prev.keys():
            # compute kd only for previous heads.
            if task_id in self.prev_classes:
                yp = y_prev[task_id]
                yc = y_curr[task_id]
                au = self.prev_classes[task_id]
                dist_loss += self._distillation_loss(yc, yp, au)
        return alpha * dist_loss
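# For reference, a minimal sketch of what `_distillation_loss` typically
# computes for LwF-style training: a KL divergence between
# temperature-softened outputs, restricted to the output units of the
# previously seen classes. The name `active_units` (a list of class
# indices) and the `temperature` default are illustrative assumptions,
# not the library's actual signature.
import torch.nn.functional as F


def _distillation_loss_sketch(y_curr, y_prev, active_units, temperature=2.0):
    # keep only the logits of previously seen classes.
    yc = y_curr[:, active_units] / temperature
    yp = y_prev[:, active_units] / temperature
    # KL(prev || curr) on softened distributions; the T**2 factor keeps
    # gradient magnitudes comparable across temperatures.
    return (temperature ** 2) * F.kl_div(
        F.log_softmax(yc, dim=1),
        F.softmax(yp, dim=1),
        reduction="batchmean",
    )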
def before_training_iteration(self, strategy, **kwargs):
    """
    Compute gradient constraints on previous memory samples from all
    experiences.
    """
    if strategy.clock.train_exp_counter > 0:
        G = []
        strategy.model.train()
        for t in range(strategy.clock.train_exp_counter):
            strategy.optimizer.zero_grad()
            xref = self.memory_x[t].to(strategy.device)
            yref = self.memory_y[t].to(strategy.device)
            out = avalanche_forward(strategy.model, xref, self.memory_tid[t])
            loss = strategy._criterion(out, yref)
            loss.backward()

            # flatten the gradient of each parameter into one vector;
            # parameters without a gradient (e.g. unused heads)
            # contribute zeros.
            G.append(
                torch.cat(
                    [
                        p.grad.flatten()
                        if p.grad is not None
                        else torch.zeros(p.numel(), device=strategy.device)
                        for p in strategy.model.parameters()
                    ],
                    dim=0,
                )
            )

        self.G = torch.stack(G)  # (experiences, parameters)
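# The matrix `self.G` is consumed after the backward pass on the current
# mini-batch: the current gradient is projected so that it does not
# conflict with the gradient of any stored experience. A minimal sketch
# of that quadratic-program projection, assuming the `quadprog` solver
# and a `memory_strength` slack term (both names are assumptions here):
import numpy as np
import quadprog


def solve_quadprog_sketch(g, G, memory_strength=0.5):
    # minimize ||v_star - g||^2 subject to <v_star, G[t]> >= 0 for each
    # stored experience t, by solving the dual QP over multipliers v.
    memories = G.cpu().double().numpy()
    gradient = g.cpu().contiguous().view(-1).double().numpy()
    t = memories.shape[0]
    P = memories @ memories.T
    P = 0.5 * (P + P.T) + np.eye(t) * 1e-3  # symmetrize and regularize
    q = -(memories @ gradient)
    C = np.eye(t)
    h = np.zeros(t) + memory_strength
    v = quadprog.solve_qp(P, q, C, h)[0]
    v_star = memories.T @ v + gradient
    return torch.from_numpy(v_star).float()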
def before_training_iteration(self, strategy, **kwargs):
    """
    Compute reference gradient on memory sample.
    """
    if len(self.buffers) > 0:
        strategy.model.train()
        strategy.optimizer.zero_grad()
        mb = self.sample_from_memory()
        xref, yref, tid = mb[0], mb[1], mb[-1]
        xref, yref = xref.to(strategy.device), yref.to(strategy.device)

        out = avalanche_forward(strategy.model, xref, tid)
        loss = strategy._criterion(out, yref)
        loss.backward()
        # gradient can be None for some head on multi-headed models,
        # so fill the missing slots with zeros.
        self.reference_gradients = [
            p.grad.view(-1)
            if p.grad is not None
            else torch.zeros(p.numel(), device=strategy.device)
            for n, p in strategy.model.named_parameters()
        ]
        self.reference_gradients = torch.cat(self.reference_gradients)
        strategy.optimizer.zero_grad()
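# The reference gradient computed above is used after the backward pass
# on the current mini-batch: when the two gradients conflict, the
# current one is projected onto the half-space of non-interfering
# directions. A minimal sketch of that A-GEM projection step, assuming
# an `after_backward` hook with the same attributes as above:
@torch.no_grad()
def after_backward_sketch(self, strategy, **kwargs):
    if len(self.buffers) == 0:
        return
    current = torch.cat(
        [
            p.grad.view(-1)
            if p.grad is not None
            else torch.zeros(p.numel(), device=strategy.device)
            for n, p in strategy.model.named_parameters()
        ]
    )
    dotg = torch.dot(current, self.reference_gradients)
    if dotg < 0:
        # g' = g - (g . g_ref / ||g_ref||^2) * g_ref
        alpha2 = dotg / torch.dot(
            self.reference_gradients, self.reference_gradients
        )
        grad_proj = current - alpha2 * self.reference_gradients
        # copy the projected gradient back into each parameter.
        count = 0
        for n, p in strategy.model.named_parameters():
            n_param = p.numel()
            if p.grad is not None:
                p.grad.copy_(grad_proj[count:count + n_param].view_as(p))
            count += n_param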
def forward(self):
    """Compute the model's output given the current mini-batch."""
    return avalanche_forward(self.model, self.mb_x, self.mb_task_id)
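# All of the snippets above rely on `avalanche_forward` to abstract over
# single-head and multi-head models. A minimal sketch of that dispatch,
# assuming `MultiTaskModule.forward` accepts `(x, task_labels)`:
def avalanche_forward_sketch(model, x, task_labels):
    if isinstance(model, MultiTaskModule):
        # task-aware model: route each sample through its task head.
        return model(x, task_labels)
    else:
        # task-agnostic model: plain forward pass.
        return model(x)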