Example #1
    def train_learner(self, x_train, y_train):
        self.before_train(x_train, y_train)

        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)

        # set up model
        self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                logits = self.forward(batch_x)
                # distillation loss against the previous model vs. cross-entropy on the new batch,
                # weighted by how many tasks have been seen so far
                loss_old = self.kd_manager.get_kd_loss(logits, batch_x)
                loss_new = self.criterion(logits, batch_y)
                loss = 1 / (self.task_seen + 1) * loss_new + (1 - 1 / (self.task_seen + 1)) * loss_old
                _, pred_label = torch.max(logits, 1)
                correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                # update tracker
                acc_batch.update(correct_cnt, batch_y.size(0))
                losses_batch.update(loss.item(), batch_y.size(0))
                # backward
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

                if i % 100 == 1 and self.verbose:
                    print('==>>> it: {}, avg. loss: {:.6f}, '
                          'running train acc: {:.3f}'.format(
                              i, losses_batch.avg(), acc_batch.avg()))
        self.after_train()
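
These snippets lean on a few small helpers from the surrounding repository. As a point of reference, here is a minimal sketch of what `AverageMeter` and `maybe_cuda` are assumed to do; the names match the snippets, but the exact implementations in the source repository may differ (`dataset_transform` and `transforms_match` are repo-specific and not sketched here):

import torch

class AverageMeter(object):
    """Tracks a running weighted average, e.g. of per-batch loss or accuracy."""

    def __init__(self):
        self.sum = 0.0   # weighted sum of observed values
        self.count = 0   # total weight (e.g. number of samples)

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    def avg(self):
        return self.sum / self.count if self.count > 0 else 0.0

def maybe_cuda(tensor, use_cuda=True):
    """Move a tensor to the GPU only when CUDA is requested and available."""
    if use_cuda and torch.cuda.is_available():
        return tensor.cuda()
    return tensor
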
Example #2
    def evaluate(self, test_loaders):
        self.model.eval()
        acc_array = np.zeros(len(test_loaders))
        if self.params.trick['nmc_trick'] or self.params.agent == 'ICARL':
            # compute a normalized mean feature (class prototype) per class from the buffer
            exemplar_means = {}
            cls_exemplar = {cls: [] for cls in self.old_labels}
            buffer_filled = self.buffer.current_index
            for x, y in zip(self.buffer.buffer_img[:buffer_filled], self.buffer.buffer_label[:buffer_filled]):
                cls_exemplar[y.item()].append(x)
            for cls, exemplar in cls_exemplar.items():
                features = []
                # extract the feature of each exemplar of this class
                for ex in exemplar:
                    feature = self.model.features(ex.unsqueeze(0)).detach().clone()
                    feature = feature.squeeze()
                    feature.data = feature.data / feature.data.norm()  # normalize
                    features.append(feature)
                features = torch.stack(features)
                mu_y = features.mean(0).squeeze()
                mu_y.data = mu_y.data / mu_y.data.norm()  # normalize
                exemplar_means[cls] = mu_y
        with torch.no_grad():
            if self.params.error_analysis:
                # error counters: no = new-task samples predicted as old classes,
                # nn = new predicted as (other) new, oo = old predicted as (other) old,
                # on = old-task samples predicted as new classes
                error = 0
                no = 0
                nn = 0
                oo = 0
                on = 0
                new_class_score = AverageMeter()
                old_class_score = AverageMeter()
            for task, test_loader in enumerate(test_loaders):
                acc = AverageMeter()
                for i, (batch_x, batch_y) in enumerate(test_loader):
                    batch_x = maybe_cuda(batch_x, self.cuda)
                    batch_y = maybe_cuda(batch_y, self.cuda)
                    if self.params.trick['nmc_trick'] or self.params.agent == 'ICARL':
                        # nearest-class-mean prediction: assign each sample to the closest prototype
                        feature = self.model.features(batch_x)  # (batch_size, feature_size)
                        for j in range(feature.size(0)):  # normalize
                            feature.data[j] = feature.data[j] / feature.data[j].norm()
                        feature = feature.unsqueeze(2)  # (batch_size, feature_size, 1)
                        means = torch.stack([exemplar_means[cls] for cls in self.old_labels])  # (n_classes, feature_size)
                        means = torch.stack([means] * batch_x.size(0))  # (batch_size, n_classes, feature_size)
                        means = means.transpose(1, 2)  # (batch_size, feature_size, n_classes)
                        feature = feature.expand_as(means)  # (batch_size, feature_size, n_classes)
                        dists = (feature - means).pow(2).sum(1).squeeze()  # (batch_size, n_classes)
                        _, preds = dists.min(1)
                        correct_cnt = (np.array(self.old_labels)[preds.tolist()]
                                       == batch_y.cpu().numpy()).sum().item() / batch_y.size(0)
                    else:
                        logits = self.model.forward(batch_x)
                        _, pred_label = torch.max(logits, 1)
                        correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                        if self.params.error_analysis:
                            if task < self.task_seen - 1:
                                # test data from an old task
                                total = (pred_label != batch_y).sum().item()
                                wrong = pred_label[pred_label != batch_y]
                                error += total
                                on_tmp = sum([(wrong == i).sum().item() for i in self.new_labels_zombie])
                                oo += total - on_tmp
                                on += on_tmp
                                old_class_score.update(logits[:, list(set(self.old_labels) - set(self.new_labels_zombie))].mean().item(), batch_y.size(0))
                            elif task == self.task_seen - 1:
                                # test data from the newest task
                                total = (pred_label != batch_y).sum().item()
                                error += total
                                wrong = pred_label[pred_label != batch_y]
                                no_tmp = sum([(wrong == i).sum().item() for i in list(set(self.old_labels) - set(self.new_labels_zombie))])
                                no += no_tmp
                                nn += total - no_tmp
                                new_class_score.update(logits[:, self.new_labels_zombie].mean().item(), batch_y.size(0))
                            else:
                                pass
                    acc.update(correct_cnt, batch_y.size(0))
                acc_array[task] = acc.avg()
        print(acc_array)
        if self.params.error_analysis:
            self.error_list.append((no, nn, oo, on))
            self.new_class_score.append(new_class_score.avg())
            self.old_class_score.append(old_class_score.avg())
            print("no ratio: {}\non ratio: {}".format(no / (no + nn + 0.1), on / (oo + on + 0.1)))
            print(self.error_list)
            print(self.new_class_score)
            print(self.old_class_score)
            self.fc_norm_new.append(self.model.linear.weight[self.new_labels_zombie].mean().item())
            self.fc_norm_old.append(self.model.linear.weight[list(set(self.old_labels) - set(self.new_labels_zombie))].mean().item())
            self.bias_norm_new.append(self.model.linear.bias[self.new_labels_zombie].mean().item())
            self.bias_norm_old.append(self.model.linear.bias[list(set(self.old_labels) - set(self.new_labels_zombie))].mean().item())
            print(self.fc_norm_old)
            print(self.fc_norm_new)
            print(self.bias_norm_old)
            print(self.bias_norm_new)
        return acc_array
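
The NMC/iCaRL branch above classifies by distance to class prototypes. As a standalone illustration, here is a hedged sketch of the nearest-class-mean rule; `nmc_predict` is a name introduced here for this example, not the repo's API:

import torch

def nmc_predict(features, class_means, class_labels):
    """Nearest-class-mean prediction.

    features:     (batch_size, feature_size), assumed L2-normalized
    class_means:  (n_classes, feature_size), assumed L2-normalized
    class_labels: list mapping row index in class_means to a class id
    """
    # squared Euclidean distance from every sample to every class mean
    dists = torch.cdist(features, class_means).pow(2)  # (batch_size, n_classes)
    idx = dists.argmin(dim=1)                          # closest prototype per sample
    return torch.tensor([class_labels[i] for i in idx.tolist()])

Because both the features and the means are L2-normalized, minimizing Euclidean distance is equivalent to maximizing cosine similarity with the class prototype.
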
Example #3
    def train_learner(self, x_train, y_train):
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(
            x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=self.batch,
                                       shuffle=True,
                                       num_workers=0,
                                       drop_last=True)
        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        # set up model
        self.model.train()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                # update the running fisher
                if (ep * len(train_loader) + i + 1) % self.fisher_update_after == 0:
                    self.update_running_fisher()

                out = self.forward(batch_x)
                loss = self.total_loss(out, batch_y)
                if self.params.trick['kd_trick']:
                    loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                                   self.kd_manager.get_kd_loss(out, batch_x)
                if self.params.trick['kd_trick_star']:
                    loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \
                           (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(out, batch_x)
                # update tracker
                losses_batch.update(loss.item(), batch_y.size(0))
                _, pred_label = torch.max(out, 1)
                acc = (pred_label == batch_y).sum().item() / batch_y.size(0)
                acc_batch.update(acc, batch_y.size(0))
                # backward
                self.opt.zero_grad()
                loss.backward()

                # accumulate the fisher of current batch
                self.accum_fisher()
                self.opt.step()

                if i % 100 == 1 and self.verbose:
                    print('==>>> it: {}, avg. loss: {:.6f}, '
                          'running train acc: {:.3f}'.format(
                              i, losses_batch.avg(), acc_batch.avg()))

        # save params for current task
        for n, p in self.weights.items():
            self.prev_params[n] = p.clone().detach()

        # update normalized fisher of current task
        max_fisher = max([torch.max(m) for m in self.running_fisher.values()])
        min_fisher = min([torch.min(m) for m in self.running_fisher.values()])
        for n, p in self.running_fisher.items():
            self.normalized_fisher[n] = (p - min_fisher) / (max_fisher -
                                                            min_fisher + 1e-32)
        self.after_train()
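
Several of these examples share the same `kd_trick` weighting: as more tasks are seen, weight shifts from the new-task loss toward the distillation loss. A minimal sketch of that blend (`combine_kd_loss` is a helper name introduced here; it assumes `get_kd_loss` returns a distillation loss against a frozen copy of the previous model):

def combine_kd_loss(loss_new, loss_kd, task_seen, star=False):
    """Blend the new-task loss with a distillation loss, as in the snippets above.

    With star=True ('kd_trick_star') the schedule decays as 1/sqrt(t + 1)
    instead of 1/(t + 1).
    """
    w = 1.0 / ((task_seen + 1) ** 0.5) if star else 1.0 / (task_seen + 1)
    return w * loss_new + (1.0 - w) * loss_kd
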
Example #4
    def train_learner(self, x_train, y_train):
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(
            x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=self.batch,
                                       shuffle=True,
                                       num_workers=0,
                                       drop_last=True)
        # set up model
        self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        losses_mem = AverageMeter()
        acc_batch = AverageMeter()
        acc_mem = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                for j in range(self.mem_iters):
                    logits = self.model.forward(batch_x)
                    loss = self.criterion(logits, batch_y)
                    if self.params.trick['kd_trick']:
                        loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                                   self.kd_manager.get_kd_loss(logits, batch_x)
                    if self.params.trick['kd_trick_star']:
                        loss = 1/((self.task_seen + 1) ** 0.5) * loss + \
                               (1 - 1/((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x)
                    _, pred_label = torch.max(logits, 1)
                    correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                    # update tracker
                    acc_batch.update(correct_cnt, batch_y.size(0))
                    losses_batch.update(loss.item(), batch_y.size(0))
                    # backward
                    self.opt.zero_grad()
                    loss.backward()

                    # mem update
                    mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y)
                    if mem_x.size(0) > 0:
                        mem_x = maybe_cuda(mem_x, self.cuda)
                        mem_y = maybe_cuda(mem_y, self.cuda)
                        mem_logits = self.model.forward(mem_x)
                        loss_mem = self.criterion(mem_logits, mem_y)
                        if self.params.trick['kd_trick']:
                            loss_mem = 1 / (self.task_seen + 1) * loss_mem + (1 - 1 / (self.task_seen + 1)) * \
                                       self.kd_manager.get_kd_loss(mem_logits, mem_x)
                        if self.params.trick['kd_trick_star']:
                            loss_mem = 1 / ((self.task_seen + 1) ** 0.5) * loss_mem + \
                                   (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(mem_logits,
                                                                                                         mem_x)
                        # update tracker
                        losses_mem.update(loss_mem.item(), mem_y.size(0))
                        _, pred_label = torch.max(mem_logits, 1)
                        correct_cnt = (pred_label == mem_y).sum().item() / mem_y.size(0)
                        acc_mem.update(correct_cnt, mem_y.size(0))

                        loss_mem.backward()

                    if self.params.agent == 'ASER':
                        # opt update: ASER re-computes the loss on the combined
                        # current + memory batch (assumes a non-empty memory batch)
                        self.opt.zero_grad()
                        combined_batch = torch.cat((mem_x, batch_x))
                        combined_labels = torch.cat((mem_y, batch_y))
                        combined_logits = self.model.forward(combined_batch)
                        loss_combined = self.criterion(combined_logits,
                                                       combined_labels)
                        loss_combined.backward()
                        self.opt.step()
                    else:
                        self.opt.step()

                # update mem
                self.buffer.update(batch_x, batch_y)

                if i % 100 == 1 and self.verbose:
                    print('==>>> it: {}, avg. loss: {:.6f}, '
                          'running train acc: {:.3f}'.format(
                              i, losses_batch.avg(), acc_batch.avg()))
                    print('==>>> it: {}, mem avg. loss: {:.6f}, '
                          'running mem acc: {:.3f}'.format(
                              i, losses_mem.avg(), acc_mem.avg()))
        self.after_train()
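
The `buffer.update` / `buffer.retrieve` calls above abstract an experience-replay memory. As a point of reference, here is a hedged sketch of a reservoir-sampling buffer with random retrieval, a common choice for this kind of replay; `ReservoirBuffer` is illustrative, and the actual buffer in the repository may use different update and retrieval strategies:

import random
import torch

class ReservoirBuffer(object):
    """Fixed-size replay memory filled by reservoir sampling."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.x, self.y = [], []
        self.n_seen = 0  # total number of examples offered so far

    def update(self, batch_x, batch_y):
        for x, y in zip(batch_x, batch_y):
            self.n_seen += 1
            if len(self.x) < self.capacity:
                self.x.append(x)
                self.y.append(y)
            else:
                # keep each seen example with probability capacity / n_seen
                j = random.randrange(self.n_seen)
                if j < self.capacity:
                    self.x[j], self.y[j] = x, y

    def retrieve(self, size=10):
        if not self.x:
            return torch.empty(0), torch.empty(0)
        idx = random.sample(range(len(self.x)), min(size, len(self.x)))
        return torch.stack([self.x[i] for i in idx]), torch.stack([self.y[i] for i in idx])
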
Example #5
    def train_learner(self, x_train, y_train):
        self.before_train(x_train, y_train)

        # set up loader
        train_dataset = dataset_transform(
            x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=self.batch,
                                       shuffle=True,
                                       num_workers=0,
                                       drop_last=True)
        # set up model
        self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                for j in range(self.mem_iters):
                    logits = self.forward(batch_x)
                    loss = self.criterion(logits, batch_y)
                    if self.params.trick['kd_trick']:
                        loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                                    self.kd_manager.get_kd_loss(logits, batch_x)
                    if self.params.trick['kd_trick_star']:
                        loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \
                               (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x)
                    _, pred_label = torch.max(logits, 1)
                    correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                    # update tracker
                    acc_batch.update(correct_cnt, batch_y.size(0))
                    losses_batch.update(loss.item(), batch_y.size(0))
                    # backward
                    self.opt.zero_grad()
                    loss.backward()

                    if self.task_seen > 0:
                        # sample from memory of previous tasks
                        mem_x, mem_y = self.buffer.retrieve()
                        if mem_x.size(0) > 0:
                            params = [
                                p for p in self.model.parameters()
                                if p.requires_grad
                            ]
                            # gradient computed using current batch
                            grad = [p.grad.clone() for p in params]
                            mem_x = maybe_cuda(mem_x, self.cuda)
                            mem_y = maybe_cuda(mem_y, self.cuda)
                            mem_logits = self.forward(mem_x)
                            loss_mem = self.criterion(mem_logits, mem_y)
                            self.opt.zero_grad()
                            loss_mem.backward()
                            # gradient computed using memory samples
                            grad_ref = [p.grad.clone() for p in params]

                            # inner product of grad and grad_ref
                            prod = sum([
                                torch.sum(g * g_r)
                                for g, g_r in zip(grad, grad_ref)
                            ])
                            if prod < 0:
                                prod_ref = sum(
                                    [torch.sum(g_r**2) for g_r in grad_ref])
                                # project grad so it no longer conflicts with the memory gradient
                                grad = [
                                    g - prod / prod_ref * g_r
                                    for g, g_r in zip(grad, grad_ref)
                                ]
                            # replace params' grad
                            for g, p in zip(grad, params):
                                p.grad.data.copy_(g)
                    self.opt.step()
                # update mem
                self.buffer.update(batch_x, batch_y)

                if i % 100 == 1 and self.verbose:
                    print('==>>> it: {}, avg. loss: {:.6f}, '
                          'running train acc: {:.3f}'.format(
                              i, losses_batch.avg(), acc_batch.avg()))
        self.after_train()
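
The projection step above keeps the update from increasing the loss on memory samples: when the current gradient g conflicts with the reference gradient g_ref (negative inner product), the component of g along g_ref is removed. A self-contained sketch of that rule on flat tensors (`project_gradient` is a name introduced here for illustration):

import torch

def project_gradient(grad, grad_ref):
    """If grad conflicts with grad_ref, remove the conflicting component:
    g <- g - (g . g_ref) / (g_ref . g_ref) * g_ref
    """
    dot = torch.dot(grad, grad_ref)
    if dot < 0:
        grad = grad - dot / torch.dot(grad_ref, grad_ref) * grad_ref
    return grad

# tiny usage example with made-up gradients
g = torch.tensor([1.0, -2.0])
g_ref = torch.tensor([0.0, 1.0])
g_proj = project_gradient(g, g_ref)  # -> tensor([1., 0.]): conflict removed
assert torch.dot(g_proj, g_ref) >= 0
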