Exemplo n.º 1
0
    def get_loss(self, forward_pass: ForwardPass, y: Tensor = None) -> Loss:
        """Build the regularization `Loss` contributed by this auxiliary task.

        See the `AuxiliaryTask` class for more information on this method.

        NOTE: This is a simplified form of EWC. The penalty is the sum, over
        the entries paired up by `dict_intersection`, of the P-norm distance
        between each current model parameter and its counterpart in
        `self.previous_model_weights`. The order of the norm is read from
        `self.options.distance_norm`.

        Neither `forward_pass` nor `y` is used by this implementation.

        Returns:
            A `Loss` carrying only this task's name while
            `self.previous_task` is None; otherwise a `Loss` with this task's
            name whose `loss` is the accumulated distance.
        """
        if self.previous_task is None:
            # First task: there is no earlier snapshot to be pulled towards.
            return Loss(name=self.name)

        snapshot: Dict[str, Tensor] = self.previous_model_weights
        current: Dict[str, Tensor] = dict(self.model.named_parameters())

        penalty = 0.0
        for _, (param, anchor) in dict_intersection(current, snapshot):
            # Cast the stored tensor to the live parameter's type first.
            anchor = anchor.type_as(param)
            penalty += torch.dist(param, anchor, p=self.options.distance_norm)

        return Loss(name=self.name, loss=penalty)
Exemplo n.º 2
0
    def get_loss(self, *args, **kwargs) -> Loss:
        """Return the regularization `Loss` for the current model weights.

        NOTE: This is a simplified form of EWC. The penalty is the sum, over
        the entries paired up by `dict_intersection`, of the P-norm distance
        between each current model parameter and its counterpart in
        `self.previous_model_weights`. The order of the norm is read from
        `self.options.distance_norm`.

        All positional and keyword arguments are accepted but ignored.

        Returns:
            A `Loss` carrying only this task's name while
            `self.previous_task` is None; otherwise a `Loss` with this task's
            name whose `loss` is the accumulated distance. In the latter case
            the internal counter `self._i` is also incremented by one.
        """
        if self.previous_task is None:
            # First task: there is no earlier snapshot to be pulled towards.
            return Loss(name=self.name)

        snapshot: Dict[str, Tensor] = self.previous_model_weights
        current: Dict[str, Tensor] = dict(self.model.named_parameters())

        penalty = 0.
        for _, (param, anchor) in dict_intersection(current, snapshot):
            # Cast the stored tensor to the live parameter's type first.
            anchor = anchor.type_as(param)
            penalty += torch.dist(param, anchor, p=self.options.distance_norm)

        # Count this call; the counter is only advanced past the first task.
        self._i += 1
        return Loss(name=self.name, loss=penalty)
    def ewc_loss(self) -> Tensor:
        """Compute an 'ewc-like' regularization penalty on the parameters.

        NOTE: This is a simplified form of EWC. The penalty is the sum, over
        the entries paired up by `dict_intersection`, of the P-norm distance
        between each of this module's current parameters and its counterpart
        in `self.previous_model_weights`. The order of the norm is read from
        `self.reg_p_norm`.

        Returns:
            The plain float `0.0` while `self._previous_task` is None (and
            also when no entries are paired up); otherwise the accumulated
            distance.
        """
        if self._previous_task is None:
            # First task: there is no earlier snapshot to be pulled towards.
            return 0.0

        snapshot: Dict[str, Tensor] = self.previous_model_weights
        current: Dict[str, Tensor] = dict(self.named_parameters())

        penalty = 0.0
        for _, (param, anchor) in dict_intersection(current, snapshot):
            # Cast the stored tensor to the live parameter's type first.
            anchor = anchor.type_as(param)
            penalty += torch.dist(param, anchor, p=self.reg_p_norm)
        return penalty