    def alternative_ewc_loss(self, task_count):
        """Compute the EWC penalty over all tasks seen before the current task (task_count)."""

        loss_prev_tasks = 0

        # Compute the EWC loss over all previously encountered tasks (tasks 1 through
        # task_count - 1): for each previous task, take the squared difference between the
        # current network parameters and those saved after training on that task, weight it
        # elementwise by the fisher diagonal computed for that task, and sum everything together.
        for task in range(1, task_count):

            task_weights = self.task_post_training_weights.get(task) # weights after training network on task
            task_fisher = self.task_fisher_diags.get(task) # fisher diagonals computed for task

            for param_index, parameter in enumerate(self.parameters()):

                # If the saved weights for this parameter no longer match its current size
                # (e.g. because the network has been expanded since training on this task),
                # zero-pad the saved weights so that they match the current parameter in
                # every dimension.
                if task_weights[param_index].size() != parameter.size():
                    pad_tuple = utils.pad_tuple(task_weights[param_index], parameter)
                    task_weights[param_index] = F.pad(task_weights[param_index], pad_tuple, mode='constant', value=0)

                # Likewise, zero-pad the fisher diagonal computed for this task if it no
                # longer matches the current parameter size.
                if task_fisher[param_index].size() != parameter.size():
                    pad_tuple = utils.pad_tuple(task_fisher[param_index], parameter)
                    task_fisher[param_index] = F.pad(task_fisher[param_index], pad_tuple, mode='constant', value=0)

                # Accumulate this (task, parameter) term of the EWC loss:
                # fisher_task * (theta - theta*_task) ** 2, summed over all entries of the tensor
                # (see: https://arxiv.org/pdf/1612.00796.pdf#section.2, equation 3)
                loss_prev_tasks += (((parameter - task_weights[param_index]) ** 2) * task_fisher[param_index]).sum()

        # scale the summed loss by the fisher multiplier (lambda in the paper) divided by 2
        return loss_prev_tasks * (self.lam / 2.0)
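
    # NOTE: utils.pad_tuple is defined elsewhere in this repo; the sketch below is only what it
    # is assumed to compute, given how it is used here. F.pad takes padding amounts ordered from
    # the LAST dimension to the first, two entries (before, after) per dimension, and expanding
    # the network is assumed to append new units, so each dimension is padded only at the end:
    #
    #     def pad_tuple(small, large):
    #         pads = []
    #         for small_dim, large_dim in zip(reversed(small.size()), reversed(large.size())):
    #             pads.extend([0, large_dim - small_dim])  # nothing before, zeros after
    #         return tuple(pads)
    #
    # e.g. pad_tuple(torch.zeros(2, 3), torch.zeros(4, 5)) gives (0, 2, 0, 2), and
    # F.pad(torch.zeros(2, 3), (0, 2, 0, 2)) has size (4, 5).
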
    def expand_ewc_sums(self):
        """Zero-pad the running EWC sums so that they match the current network parameter sizes."""

        ewc_sums = [self.sum_Fx, self.sum_Fx_Wx, self.sum_Fx_Wx_sq]

        for ewc_sum in ewc_sums:
            for parameter_index, parameter in enumerate(self.parameters()):

                # If the running sum for this parameter no longer matches the parameter's
                # current size (e.g. because the network has been expanded), zero-pad the
                # sum tensor so that it matches the corresponding parameter in every dimension.
                if ewc_sum[parameter_index].size() != parameter.size():
                    pad_tuple = utils.pad_tuple(ewc_sum[parameter_index], parameter)
                    ewc_sum[parameter_index] = F.pad(ewc_sum[parameter_index], pad_tuple, mode='constant', value=0)
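
    # Usage sketch (hypothetical names; `model` is an instance of this class, and `optimizer`,
    # `criterion`, `data`, `target`, and `current_task` come from the surrounding training code):
    #
    #     output = model(data)
    #     loss = criterion(output, target) + model.alternative_ewc_loss(current_task)
    #     optimizer.zero_grad()
    #     loss.backward()
    #     optimizer.step()
    #
    # After expanding the network to make room for a new task, the stored sums are realigned
    # with the new parameter sizes via:
    #
    #     model.expand_ewc_sums()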