Example 1
    def train_learner(self, x_train, y_train):
        # set up loader
        train_dataset = dataset_transform(
            x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=self.batch,
                                       shuffle=True,
                                       num_workers=0,
                                       drop_last=True)
        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        self.model.train()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                self.model.learn(batch_x, batch_y)
                if self.params.verbose:
                    print('\r[Step {:4}] STM: {:5}/{} | #Expert: {}'.format(
                        i, len(self.model.stm_x), self.params.stm_capacity,
                        len(self.model.experts) - 1),
                          end='')
        print()
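
All of these snippets lean on a handful of small utilities that the listing itself never shows. Below is a minimal sketch of two of them, maybe_cuda and AverageMeter, with signatures inferred from the call sites (note that avg is invoked as a method, e.g. losses_batch.avg()); the real project may define them differently.

import torch

def maybe_cuda(t, use_cuda=True):
    # move a tensor (or module) to the GPU only when CUDA is requested and available
    if use_cuda and torch.cuda.is_available():
        return t.cuda()
    return t

class AverageMeter(object):
    # tracks a running weighted average, e.g. per-batch loss or accuracy
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    def avg(self):
        return self.sum / self.count if self.count > 0 else 0.0
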
Example 2
    def train_learner(self, x_train, y_train):
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(
            x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=self.batch,
                                       shuffle=True,
                                       num_workers=0,
                                       drop_last=True)
        self.model.train()
        self.update_representation(train_loader)
        self.prev_model = copy.deepcopy(self.model)
        self.after_train()
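
Every example builds its loader the same way, through dataset_transform and transforms_match. A plausible minimal sketch of dataset_transform, assuming it wraps raw (x, y) arrays as a map-style torch Dataset and applies the per-dataset torchvision transform on access (the actual class may differ):

import torch
from torch.utils import data

class dataset_transform(data.Dataset):
    # wraps raw (x, y) arrays; applies a transform to each sample on access
    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = torch.as_tensor(y, dtype=torch.long)
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        sample = self.x[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.y[idx]
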
Example 3
    def train_learner(self, x_train, y_train):
        self.before_train(x_train, y_train)

        # set up loader
        train_dataset = dataset_transform(x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset, batch_size=self.batch, shuffle=True, num_workers=0,
                                       drop_last=True)

        # set up model
        self.model = self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                logits = self.forward(batch_x)
                loss_old = self.kd_manager.get_kd_loss(logits, batch_x)
                loss_new = self.criterion(logits, batch_y)
                loss = 1/(self.task_seen + 1) * loss_new + (1 - 1/(self.task_seen + 1)) * loss_old
                _, pred_label = torch.max(logits, 1)
                correct_cnt = (pred_label == batch_y).sum().item() / batch_y.size(0)
                # update tracker
                acc_batch.update(correct_cnt, batch_y.size(0))
                losses_batch.update(loss.item(), batch_y.size(0))  # track a detached scalar, not the graph-attached tensor
                # backward
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

                if i % 100 == 1 and self.verbose:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                            .format(i, losses_batch.avg(), acc_batch.avg())
                    )
        self.after_train()
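
Example 3 blends a knowledge-distillation term into the loss, weighting the new-task loss by 1/(task_seen + 1) and the distillation term by the remainder, so the old-task term dominates as more tasks accumulate. The listing never shows kd_manager; a minimal sketch of what get_kd_loss is assumed to compute, namely temperature-softened distillation (Hinton et al., 2015) against a frozen snapshot of the previous model (the class layout, update_teacher hook, and temperature T are illustrative):

import copy
import torch
import torch.nn.functional as F

class KdManager:
    def __init__(self, T=2.0):
        self.teacher = None  # frozen copy of the model from the previous task
        self.T = T

    def update_teacher(self, model):
        self.teacher = copy.deepcopy(model).eval()

    def get_kd_loss(self, logits, x):
        if self.teacher is None:
            return torch.zeros((), device=logits.device)
        with torch.no_grad():
            teacher_logits = self.teacher(x)
        # KL divergence between temperature-softened distributions
        return F.kl_div(F.log_softmax(logits / self.T, dim=1),
                        F.softmax(teacher_logits / self.T, dim=1),
                        reduction='batchmean') * (self.T ** 2)
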
Example 4
    def train_learner(self, x_train, y_train):
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(
            x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=self.batch,
                                       shuffle=True,
                                       num_workers=0,
                                       drop_last=True)

        for i, batch_data in enumerate(train_loader):
            # batch update
            batch_x, batch_y = batch_data
            batch_x = maybe_cuda(batch_x, self.cuda)
            batch_y = maybe_cuda(batch_y, self.cuda)
            # update mem
            for j in range(len(batch_x)):
                self.greedy_balancing_update(batch_x[j], batch_y[j].item())
        #self.early_stopping.reset()
        self.train_mem()
        self.after_train()
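
Example 4 only fills a memory and then trains on it (train_mem), which matches a GDumb-style greedy class-balancing sampler. greedy_balancing_update is not shown; a minimal sketch under that assumption (the buffer layout and quota rule are illustrative):

import random
from collections import defaultdict

class GreedyBalancingBuffer:
    # keeps the memory roughly class-balanced by evicting from the largest class
    def __init__(self, mem_size):
        self.mem_size = mem_size
        self.mem = defaultdict(list)  # class label -> list of stored samples

    def __len__(self):
        return sum(len(v) for v in self.mem.values())

    def greedy_balancing_update(self, x, y):
        # per-class quota if y were counted as a (possibly new) class
        quota = self.mem_size // (len(self.mem) + (0 if y in self.mem else 1))
        if len(self) < self.mem_size:
            self.mem[y].append(x)
        elif len(self.mem[y]) < quota:
            # evict a random sample from the currently largest class
            largest = max(self.mem, key=lambda c: len(self.mem[c]))
            self.mem[largest].pop(random.randrange(len(self.mem[largest])))
            self.mem[y].append(x)
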
Example 5
    def train_learner(self, x_train, y_train):
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(
            x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=self.batch,
                                       shuffle=True,
                                       num_workers=0,
                                       drop_last=True)
        # set up model
        self.model = self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        losses_mem = AverageMeter()
        acc_batch = AverageMeter()
        acc_mem = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                for j in range(self.mem_iters):
                    logits = self.model.forward(batch_x)
                    loss = self.criterion(logits, batch_y)
                    if self.params.trick['kd_trick']:
                        loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                                   self.kd_manager.get_kd_loss(logits, batch_x)
                    if self.params.trick['kd_trick_star']:
                        loss = 1/((self.task_seen + 1) ** 0.5) * loss + \
                               (1 - 1/((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x)
                    _, pred_label = torch.max(logits, 1)
                    correct_cnt = (pred_label
                                   == batch_y).sum().item() / batch_y.size(0)
                    # update tracker
                    acc_batch.update(correct_cnt, batch_y.size(0))
                    losses_batch.update(loss.item(), batch_y.size(0))
                    # backward
                    self.opt.zero_grad()
                    loss.backward()

                    # mem update
                    mem_x, mem_y = self.buffer.retrieve(x=batch_x, y=batch_y)
                    if mem_x.size(0) > 0:
                        mem_x = maybe_cuda(mem_x, self.cuda)
                        mem_y = maybe_cuda(mem_y, self.cuda)
                        mem_logits = self.model.forward(mem_x)
                        loss_mem = self.criterion(mem_logits, mem_y)
                        if self.params.trick['kd_trick']:
                            loss_mem = 1 / (self.task_seen + 1) * loss_mem + (1 - 1 / (self.task_seen + 1)) * \
                                       self.kd_manager.get_kd_loss(mem_logits, mem_x)
                        if self.params.trick['kd_trick_star']:
                            loss_mem = 1 / ((self.task_seen + 1) ** 0.5) * loss_mem + \
                                   (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(mem_logits,
                                                                                                         mem_x)
                        # update tracker
                        losses_mem.update(loss_mem.item(), mem_y.size(0))
                        _, pred_label = torch.max(mem_logits, 1)
                        correct_cnt = (pred_label
                                       == mem_y).sum().item() / mem_y.size(0)
                        acc_mem.update(correct_cnt, mem_y.size(0))

                        loss_mem.backward()

                    if self.params.agent == 'ASER':
                        # opt update
                        self.opt.zero_grad()
                        combined_batch = torch.cat((mem_x, batch_x))
                        combined_labels = torch.cat((mem_y, batch_y))
                        combined_logits = self.model.forward(combined_batch)
                        loss_combined = self.criterion(combined_logits,
                                                       combined_labels)
                        loss_combined.backward()
                        self.opt.step()
                    else:
                        self.opt.step()

                # update mem
                self.buffer.update(batch_x, batch_y)

                if i % 100 == 1 and self.verbose:
                    print('==>>> it: {}, avg. loss: {:.6f}, '
                          'running train acc: {:.3f}'.format(
                              i, losses_batch.avg(), acc_batch.avg()))
                    print('==>>> it: {}, mem avg. loss: {:.6f}, '
                          'running mem acc: {:.3f}'.format(
                              i, losses_mem.avg(), acc_mem.avg()))
        self.after_train()
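
Example 5 is the generic experience-replay loop: gradients from the incoming batch and from a retrieved memory batch are accumulated before one optimizer step (the ASER branch instead recomputes a single loss on the combined batch). The buffer itself is not shown; a minimal sketch assuming reservoir sampling for update() and a uniform random draw for retrieve() (eps_mem_batch and the **kwargs signature, which also lets Example 7 call retrieve() with no arguments, are inferred):

import random
import torch

class ReplayBuffer:
    def __init__(self, mem_size, eps_mem_batch):
        self.mem_size = mem_size
        self.eps_mem_batch = eps_mem_batch
        self.x, self.y = [], []
        self.n_seen = 0  # total number of examples offered to the buffer

    def update(self, batch_x, batch_y):
        # reservoir sampling: each seen example is kept with prob mem_size / n_seen
        for x, y in zip(batch_x, batch_y):
            if len(self.x) < self.mem_size:
                self.x.append(x)
                self.y.append(y)
            else:
                j = random.randrange(self.n_seen + 1)
                if j < self.mem_size:
                    self.x[j], self.y[j] = x, y
            self.n_seen += 1

    def retrieve(self, **kwargs):
        n = min(self.eps_mem_batch, len(self.x))
        if n == 0:
            return torch.empty(0), torch.empty(0, dtype=torch.long)
        idx = random.sample(range(len(self.x)), n)
        return (torch.stack([self.x[i] for i in idx]),
                torch.stack([self.y[i] for i in idx]))
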
Example 6
    def train_learner(self, x_train, y_train):
        self.before_train(x_train, y_train)
        # set up loader
        train_dataset = dataset_transform(
            x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=self.batch,
                                       shuffle=True,
                                       num_workers=0,
                                       drop_last=True)
        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        # set up model
        self.model.train()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)

                # update the running fisher
                if (ep * len(train_loader) + i +
                        1) % self.fisher_update_after == 0:
                    self.update_running_fisher()

                out = self.forward(batch_x)
                loss = self.total_loss(out, batch_y)
                if self.params.trick['kd_trick']:
                    loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                                   self.kd_manager.get_kd_loss(out, batch_x)
                if self.params.trick['kd_trick_star']:
                    loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \
                           (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(out, batch_x)
                # update tracker
                losses_batch.update(loss.item(), batch_y.size(0))
                _, pred_label = torch.max(out, 1)
                acc = (pred_label == batch_y).sum().item() / batch_y.size(0)
                acc_batch.update(acc, batch_y.size(0))
                # backward
                self.opt.zero_grad()
                loss.backward()

                # accumulate the fisher of current batch
                self.accum_fisher()
                self.opt.step()

                if i % 100 == 1 and self.verbose:
                    print('==>>> it: {}, avg. loss: {:.6f}, '
                          'running train acc: {:.3f}'.format(
                              i, losses_batch.avg(), acc_batch.avg()))

        # save params for current task
        for n, p in self.weights.items():
            self.prev_params[n] = p.clone().detach()

        # update normalized fisher of current task
        max_fisher = max([torch.max(m) for m in self.running_fisher.values()])
        min_fisher = min([torch.min(m) for m in self.running_fisher.values()])
        for n, p in self.running_fisher.items():
            self.normalized_fisher[n] = (p - min_fisher) / (max_fisher -
                                                            min_fisher + 1e-32)
        self.after_train()
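
Example 6 is an online EWC-style regularizer: it refreshes a running Fisher information estimate every fisher_update_after steps and accumulates squared gradients from each batch between refreshes. The two helpers are not shown; a minimal sketch of what they are assumed to do, written as free functions over dicts of tensors (tmp_fisher and the decay factor alpha are illustrative names):

import torch

def accum_fisher(weights, tmp_fisher):
    # accumulate the empirical Fisher of the current batch: squared gradients
    for n, p in weights.items():
        if p.grad is not None:
            tmp_fisher[n] += p.grad.detach() ** 2

def update_running_fisher(running_fisher, tmp_fisher, alpha, fisher_update_after):
    # fold the accumulated batch Fisher into an exponential moving average,
    # then reset the accumulator for the next window
    for n in running_fisher:
        running_fisher[n] = ((1 - alpha) * running_fisher[n]
                             + alpha * tmp_fisher[n] / fisher_update_after)
        tmp_fisher[n].zero_()
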
Example 7
    def train_learner(self, x_train, y_train):
        self.before_train(x_train, y_train)

        # set up loader
        train_dataset = dataset_transform(
            x_train, y_train, transform=transforms_match[self.data])
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=self.batch,
                                       shuffle=True,
                                       num_workers=0,
                                       drop_last=True)
        # set up model
        self.model = self.model.train()

        # setup tracker
        losses_batch = AverageMeter()
        acc_batch = AverageMeter()

        for ep in range(self.epoch):
            for i, batch_data in enumerate(train_loader):
                # batch update
                batch_x, batch_y = batch_data
                batch_x = maybe_cuda(batch_x, self.cuda)
                batch_y = maybe_cuda(batch_y, self.cuda)
                for j in range(self.mem_iters):
                    logits = self.forward(batch_x)
                    loss = self.criterion(logits, batch_y)
                    if self.params.trick['kd_trick']:
                        loss = 1 / (self.task_seen + 1) * loss + (1 - 1 / (self.task_seen + 1)) * \
                                    self.kd_manager.get_kd_loss(logits, batch_x)
                    if self.params.trick['kd_trick_star']:
                        loss = 1 / ((self.task_seen + 1) ** 0.5) * loss + \
                               (1 - 1 / ((self.task_seen + 1) ** 0.5)) * self.kd_manager.get_kd_loss(logits, batch_x)
                    _, pred_label = torch.max(logits, 1)
                    correct_cnt = (pred_label
                                   == batch_y).sum().item() / batch_y.size(0)
                    # update tracker
                    acc_batch.update(correct_cnt, batch_y.size(0))
                    losses_batch.update(loss.item(), batch_y.size(0))
                    # backward
                    self.opt.zero_grad()
                    loss.backward()

                    if self.task_seen > 0:
                        # sample from memory of previous tasks
                        mem_x, mem_y = self.buffer.retrieve()
                        if mem_x.size(0) > 0:
                            params = [
                                p for p in self.model.parameters()
                                if p.requires_grad
                            ]
                            # gradient computed using current batch
                            grad = [p.grad.clone() for p in params]
                            mem_x = maybe_cuda(mem_x, self.cuda)
                            mem_y = maybe_cuda(mem_y, self.cuda)
                            mem_logits = self.forward(mem_x)
                            loss_mem = self.criterion(mem_logits, mem_y)
                            self.opt.zero_grad()
                            loss_mem.backward()
                            # gradient computed using memory samples
                            grad_ref = [p.grad.clone() for p in params]

                            # inner product of grad and grad_ref
                            prod = sum([
                                torch.sum(g * g_r)
                                for g, g_r in zip(grad, grad_ref)
                            ])
                            if prod < 0:
                                prod_ref = sum(
                                    [torch.sum(g_r**2) for g_r in grad_ref])
                                # do projection
                                grad = [
                                    g - prod / prod_ref * g_r
                                    for g, g_r in zip(grad, grad_ref)
                                ]
                            # replace params' grad
                            for g, p in zip(grad, params):
                                p.grad.data.copy_(g)
                    self.opt.step()
                # update mem
                self.buffer.update(batch_x, batch_y)

                if i % 100 == 1 and self.verbose:
                    print('==>>> it: {}, avg. loss: {:.6f}, '
                          'running train acc: {:.3f}'.format(
                              i, losses_batch.avg(), acc_batch.avg()))
        self.after_train()
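
Example 7 performs the A-GEM gradient projection inline: with g the current-batch gradient and g_ref the memory gradient, whenever g . g_ref < 0 the update is replaced by g - (g . g_ref / g_ref . g_ref) * g_ref, which removes the component of g that would increase the loss on the memory samples. The same rule as a standalone helper (the name and list-of-tensors interface are chosen for illustration):

import torch

def agem_project(grad, grad_ref):
    # grad, grad_ref: lists of per-parameter gradient tensors
    dot = sum(torch.sum(g * r) for g, r in zip(grad, grad_ref))
    if dot < 0:
        ref_sq = sum(torch.sum(r * r) for r in grad_ref)
        # project out the component that interferes with the memory gradient
        return [g - (dot / ref_sq) * r for g, r in zip(grad, grad_ref)]
    return grad
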