예제 #1
0
    def penalty(self, out, x, alpha, curr_model):
        """
        Compute weighted distillation loss.

        :param out: output already computed by the current model for ``x``;
            used as the current prediction when the previous model is not
            a ``MultiTaskModule``.
        :param x: input mini-batch, forwarded through the previous model
            (and, for multi-task models, through ``curr_model`` again).
        :param alpha: multiplier applied to the summed distillation loss.
        :param curr_model: the model currently being trained.
        :return: ``0`` when no previous model is stored, otherwise
            ``alpha`` times the sum of ``self._distillation_loss`` over the
            heads whose id appears in ``self.prev_classes``.
        """

        if self.prev_model is None:
            return 0

        if isinstance(self.prev_model, MultiTaskModule):
            # output from previous output heads. The previous model is
            # only a fixed target, so no autograd graph is needed for it.
            with torch.no_grad():
                y_prev = avalanche_forward(self.prev_model, x, None)
            # in a multitask scenario we need to compute the output
            # from all the heads, so we need to call forward again.
            # This pass must stay OUTSIDE no_grad: otherwise the penalty is
            # a constant w.r.t. curr_model and never influences training.
            # TODO: can we avoid this?
            y_curr = avalanche_forward(curr_model, x, None)
        else:  # no task labels
            with torch.no_grad():
                y_prev = {"0": self.prev_model(x)}
            # `out` was produced by the caller and keeps its own graph.
            y_curr = {"0": out}

        dist_loss = 0
        for task_id in y_prev.keys():
            # compute kd only for previous heads.
            if task_id in self.prev_classes:
                yp = y_prev[task_id]
                yc = y_curr[task_id]
                au = self.prev_classes[task_id]
                dist_loss += self._distillation_loss(yc, yp, au)
        return alpha * dist_loss
예제 #2
0
    def before_training_iteration(self, strategy, **kwargs):
        """
        Build the matrix of per-experience gradients from the memory samples.

        For every past experience ``t`` (``0 .. train_exp_counter - 1``) the
        stored samples ``self.memory_x[t]`` / ``self.memory_y[t]`` are run
        through the model, the criterion is back-propagated, and all
        parameter gradients are flattened into one vector. The vectors are
        stacked into ``self.G``. Nothing happens during the first experience.
        """

        n_past = strategy.clock.train_exp_counter
        if not n_past > 0:
            return

        model = strategy.model
        device = strategy.device
        rows = []
        model.train()
        for exp_id in range(n_past):
            model.train()
            strategy.optimizer.zero_grad()
            mem_x = self.memory_x[exp_id].to(device)
            mem_y = self.memory_y[exp_id].to(device)
            prediction = avalanche_forward(model, mem_x,
                                           self.memory_tid[exp_id])
            strategy._criterion(prediction, mem_y).backward()

            pieces = []
            for param in model.parameters():
                if param.grad is None:
                    # parameters without a gradient contribute zeros so
                    # that every row has the same length
                    pieces.append(
                        torch.zeros(param.numel(), device=strategy.device))
                else:
                    pieces.append(param.grad.flatten())
            rows.append(torch.cat(pieces, dim=0))

        self.G = torch.stack(rows)  # (experiences, parameters)
예제 #3
0
파일: agem.py 프로젝트: gab709/avalanche
    def before_training_iteration(self, strategy, **kwargs):
        """
        Compute reference gradient on memory sample.

        Does nothing while ``self.buffers`` is empty. Otherwise a batch is
        drawn with ``self.sample_from_memory()``, the criterion is
        back-propagated through ``strategy.model``, and the gradients of all
        parameters with ``requires_grad`` set are concatenated into the flat
        tensor ``self.reference_gradients``. The optimizer gradients are
        zeroed before and after so this pass leaves no residue behind.
        """
        if len(self.buffers) > 0:
            strategy.model.train()
            strategy.optimizer.zero_grad()
            mb = self.sample_from_memory()
            xref, yref, tid = mb[0], mb[1], mb[-1]
            xref, yref = xref.to(strategy.device), yref.to(strategy.device)

            out = avalanche_forward(strategy.model, xref, tid)
            loss = strategy._criterion(out, yref)
            loss.backward()
            # A trainable parameter may still have no gradient (e.g. a head
            # of a multi-headed model that was not used for this batch).
            # Substitute zeros of matching dtype/device instead of crashing
            # on `None.view`, keeping the vector layout unchanged.
            self.reference_gradients = torch.cat([
                p.grad.view(-1) if p.grad is not None
                else p.new_zeros(p.numel())
                for _, p in strategy.model.named_parameters()
                if p.requires_grad
            ])
            strategy.optimizer.zero_grad()
예제 #4
0
    def before_training_iteration(self, strategy, **kwargs):
        """
        Store the flattened gradient obtained on a batch drawn from memory.

        When ``self.buffers`` is empty nothing is done. Otherwise the
        criterion is back-propagated on a batch from
        ``self.sample_from_memory()`` and the gradients of every model
        parameter are concatenated into ``self.reference_gradients``.
        """
        if not len(self.buffers) > 0:
            return

        model = strategy.model
        model.train()
        strategy.optimizer.zero_grad()

        batch = self.sample_from_memory()
        xref, yref, tid = batch[0], batch[1], batch[-1]
        xref = xref.to(strategy.device)
        yref = yref.to(strategy.device)

        prediction = avalanche_forward(model, xref, tid)
        strategy._criterion(prediction, yref).backward()

        flat_grads = []
        for _, param in model.named_parameters():
            if param.grad is None:
                # gradient can be None for some head on multi-headed models
                flat_grads.append(
                    torch.zeros(param.numel(), device=strategy.device))
            else:
                flat_grads.append(param.grad.view(-1))
        self.reference_gradients = torch.cat(flat_grads)

        strategy.optimizer.zero_grad()
예제 #5
0
 def forward(self):
     """Return ``avalanche_forward`` applied to the model and current mini-batch.

     The call receives ``self.model``, ``self.mb_x`` and ``self.mb_task_id``,
     in that order.
     """
     model, inputs, task_id = self.model, self.mb_x, self.mb_task_id
     return avalanche_forward(model, inputs, task_id)