def _update_grad(self, strategy):
    """Compute and store the gradients of the loss on the current minibatch."""
    model = strategy.model
    batch = strategy.mbatch

    model.eval()

    # Set RNN-like modules on GPU to training mode to avoid CUDA error
    if strategy.device == "cuda":
        for module in model.modules():
            if isinstance(module, torch.nn.RNNBase):
                warnings.warn(
                    "RNN-like modules do not support "
                    "backward calls while in `eval` mode on CUDA "
                    "devices. Setting all `RNNBase` modules to "
                    "`train` mode. May produce inconsistent "
                    "output if such modules have `dropout` > 0.")
                module.train()

    x, y, task_labels = batch[0], batch[1], batch[-1]

    strategy.optimizer.zero_grad()
    out = avalanche_forward(model, x, task_labels)
    loss = strategy._criterion(out, y)  # noqa
    loss.backward()

    self.iter_grad = copy_params_dict(model, copy_grad=True)
def after_training_exp(self, strategy, *args, **kwargs):
    self.exp_importance = self.iter_importance
    self.exp_params = copy_params_dict(strategy.model)

    if self.exp_scores is None:
        self.exp_scores = self.checkpoint_scores
    else:
        exp_scores = []

        for (k1, p_score), (k2, p_cp_score) in zip(
                self.exp_scores, self.checkpoint_scores):
            assert k1 == k2, "Error in RWalk score computation."
            exp_scores.append((k1, 0.5 * (p_score + p_cp_score)))

        self.exp_scores = exp_scores

    # Compute weight penalties once for all successive iterations
    # (t_k+1 variables remain constant in Eq. 8 in the paper)
    self.exp_penalties = []

    # Normalize terms in [0,1] interval, as suggested in the paper
    # (the importance is already > 0, while negative scores are relu-ed
    # out, hence we scale only the max-values of both terms)
    max_score = max(map(lambda x: x[1].max(), self.exp_scores))
    max_imp = max(map(lambda x: x[1].max(), self.exp_importance))

    for (k1, imp), (k2, score) in zip(self.exp_importance, self.exp_scores):
        assert k1 == k2, "Error in RWalk penalties computation."
        self.exp_penalties.append(
            (k1, imp / max_imp + F.relu(score) / max_score))

    self.checkpoint_scores = zerolike_params_dict(strategy.model)
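# What follows is a minimal sketch, not the plugin's actual code, of how the
# `exp_penalties` / `exp_params` computed above could enter the training loss
# as the quadratic drift penalty of Eq. 8 in the paper. It assumes a
# `before_backward` hook, an `ewc_lambda` coefficient, and that
# `copy_params_dict` returns (name, tensor) pairs; none of these names are
# confirmed by the code above.
def before_backward(self, strategy, *args, **kwargs):
    if self.exp_penalties is None:
        return

    penalties = dict(self.exp_penalties)
    old_params = dict(self.exp_params)

    penalty = torch.zeros((), device=strategy.device)
    for name, param in strategy.model.named_parameters():
        if name not in penalties:
            continue
        # (importance + score) weighted squared drift from the parameters
        # stored at the end of the previous experience
        penalty = penalty + (
            penalties[name] * (param - old_params[name]) ** 2
        ).sum()

    strategy.loss += self.ewc_lambda * penalty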
def before_training(self, strategy: BaseSGDTemplate, **kwargs):
    # Parameters before the first task starts
    if not self.params:
        self.params = dict(copy_params_dict(strategy.model))

    # Initialize Fisher information weight importance
    if not self.importance:
        self.importance = dict(zerolike_params_dict(strategy.model))
def after_training_iteration(self, strategy, *args, **kwargs):
    self._update_loss(strategy)

    if self._is_checkpoint_iter(strategy):
        self._update_score(strategy)

        # Reset the checkpoint accumulators after each score update
        self.checkpoint_loss = zerolike_params_dict(strategy.model)
        self.checkpoint_params = copy_params_dict(strategy.model)
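# A minimal sketch of the checkpoint test used above, under the assumption
# that the plugin keeps a `delta_t` hyper-parameter (number of training
# iterations between score checkpoints) and that `strategy.clock` exposes a
# running iteration counter; both names are assumptions.
def _is_checkpoint_iter(self, strategy):
    return (strategy.clock.train_iterations + 1) % self.delta_t == 0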
def after_training_exp(self, strategy: BaseSGDTemplate, **kwargs):
    self.params = dict(copy_params_dict(strategy.model))

    # Check if previous importance is available
    if not self.importance:
        raise ValueError("Importance is not available")

    # Get importance
    curr_importance = self._get_importance(strategy)

    # Update importance
    for name in self.importance.keys():
        self.importance[name] = (
            self.alpha * self.importance[name]
            + (1 - self.alpha) * curr_importance[name]
        )
def after_training_exp(self, strategy, **kwargs):
    """
    Compute importances of parameters after each experience.
    """
    exp_counter = strategy.clock.train_exp_counter
    importances = self.compute_importances(
        strategy.model,
        strategy._criterion,
        strategy.optimizer,
        strategy.experience.dataset,
        strategy.device,
        strategy.train_mb_size,
    )
    self.update_importances(importances, exp_counter)
    self.saved_params[exp_counter] = copy_params_dict(strategy.model)

    # clear previous parameter values
    if exp_counter > 0 and (not self.keep_importance_data):
        del self.saved_params[exp_counter - 1]
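# A minimal sketch of what `compute_importances` might do, assuming the
# standard diagonal-Fisher approximation used by EWC: accumulate the squared
# gradients of the loss over the experience dataset and average over the
# number of minibatches. It also assumes
# `from torch.utils.data import DataLoader` and that `zerolike_params_dict`
# yields (name, tensor) pairs; batching and collation details are simplified.
def compute_importances(self, model, criterion, optimizer,
                        dataset, device, batch_size):
    model.eval()

    importances = zerolike_params_dict(model)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    for batch in dataloader:
        x, y, task_labels = batch[0], batch[1], batch[-1]
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = avalanche_forward(model, x, task_labels)
        loss = criterion(out, y)
        loss.backward()

        # accumulate squared gradients as a diagonal-Fisher estimate
        for (k1, p), (k2, imp) in zip(model.named_parameters(), importances):
            assert k1 == k2
            if p.grad is not None:
                imp += p.grad.detach().clone() ** 2

    # average over the number of minibatches
    for _, imp in importances:
        imp /= float(len(dataloader))

    return importances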
def before_training_iteration(self, strategy, *args, **kwargs):
    # Refresh gradients and running importance, then snapshot the
    # parameters used by the per-iteration score update.
    self._update_grad(strategy)
    self._update_importance(strategy)
    self.iter_params = copy_params_dict(strategy.model)
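# A minimal sketch of the `_update_importance` step invoked above: an
# exponential moving average of the squared gradients produced by
# `_update_grad` (a Fisher-style importance estimate). The `ewc_alpha`
# attribute name and the reuse of `self.iter_grad` are assumptions.
def _update_importance(self, strategy):
    importance = [(k, g.clone() ** 2) for k, g in self.iter_grad]

    if self.iter_importance is None:
        self.iter_importance = importance
    else:
        old_importance = self.iter_importance
        self.iter_importance = []

        for (k1, old_imp), (k2, curr_imp) in zip(old_importance, importance):
            assert k1 == k2, "Error in importance computation."
            self.iter_importance.append(
                (k1, self.ewc_alpha * curr_imp
                     + (1 - self.ewc_alpha) * old_imp))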
def before_training(self, strategy, *args, **kwargs):
    # Reset the checkpoint accumulators before training starts
    self.checkpoint_loss = zerolike_params_dict(strategy.model)
    self.checkpoint_scores = zerolike_params_dict(strategy.model)
    self.checkpoint_params = copy_params_dict(strategy.model)