def alternative_ewc_loss(self, task_count):
    """Compute the EWC regularization loss over all previously seen tasks.

    For each prior task t (1 <= t < task_count) and each network parameter,
    adds F_t * (theta - theta*_t)^2, where theta*_t are the weights stored
    after training on task t and F_t is the Fisher diagonal computed for
    task t (see https://arxiv.org/pdf/1612.00796.pdf#section.2, equation 3).
    The summed penalty is scaled by lam / 2.

    Args:
        task_count: one past the index of the last previously trained task;
            tasks 1 .. task_count - 1 contribute to the loss.

    Returns:
        Scalar tensor (or 0 if there are no previous tasks):
        (lam / 2) * sum of the per-task, per-parameter quadratic penalties.

    Side effects:
        If the network has been expanded since a task was trained, the stored
        post-training weights and Fisher diagonals for that task are zero-padded
        IN PLACE to the current parameter sizes — the padded tensors persist in
        self.task_post_training_weights / self.task_fisher_diags, so later
        calls skip the padding. This appears deliberate (caching); confirmed
        only by the in-place list assignment below.
    """
    loss_prev_tasks = 0
    for task in range(1, task_count):
        # weights after training the network on `task`, and the Fisher
        # diagonal computed for `task`
        task_weights = self.task_post_training_weights.get(task)
        task_fisher = self.task_fisher_diags.get(task)
        for param_index, parameter in enumerate(self.parameters()):
            # If the network has grown since `task` was trained, zero-pad the
            # stored tensors so they match the current parameter's size in all
            # dimensions. Zero padding contributes nothing to the penalty for
            # the new units, so expansion itself is not penalized.
            # (torch.Size supports direct !=, so no throwaway tensors are
            # built just to compare shapes.)
            if task_weights[param_index].size() != parameter.size():
                pad_tuple = utils.pad_tuple(task_weights[param_index], parameter)
                task_weights[param_index] = F.pad(
                    task_weights[param_index], pad_tuple, mode='constant', value=0)
            if task_fisher[param_index].size() != parameter.size():
                pad_tuple = utils.pad_tuple(task_fisher[param_index], parameter)
                task_fisher[param_index] = F.pad(
                    task_fisher[param_index], pad_tuple, mode='constant', value=0)
            # quadratic penalty for this (task, parameter) pair
            loss_prev_tasks += (
                ((parameter - task_weights[param_index]) ** 2)
                * task_fisher[param_index]
            ).sum()
    # scale the summed penalty by the EWC importance multiplier divided by 2
    return loss_prev_tasks * (self.lam / 2.0)
def expand_ewc_sums(self):
    """Zero-pad the running EWC sums to match current parameter sizes.

    After the network has been expanded, entries of the running sums
    (self.sum_Fx, self.sum_Fx_Wx, self.sum_Fx_Wx_sq) may be smaller than the
    parameters they correspond to. Pad each such tensor with zeros, in place
    in its sum list, so it matches the corresponding parameter in all
    dimensions.
    """
    # Iterate the three sum lists directly (no index loop over a throwaway
    # container); mutation goes through the list objects, so self.sum_Fx etc.
    # see the padded tensors after this call.
    for ewc_sum in (self.sum_Fx, self.sum_Fx_Wx, self.sum_Fx_Wx_sq):
        for parameter_index, parameter in enumerate(self.parameters()):
            entry = ewc_sum[parameter_index]
            # only pad when the stored sum is out of sync with the model
            # (torch.Size compares directly — no tensor construction needed)
            if entry.size() != parameter.size():
                pad_tuple = utils.pad_tuple(entry, parameter)
                ewc_sum[parameter_index] = F.pad(
                    entry, pad_tuple, mode='constant', value=0)