import torch

# Assumed import paths (mammoth-style continual-learning codebase);
# adjust if your tree differs.
from utils.buffer import Buffer
from models.utils.continual_model import ContinualModel


def fill_buffer(self, mem_buffer: Buffer, dataset, t_idx: int) -> None:
    """
    Adds examples from the current task to the memory buffer
    by means of the herding strategy.
    :param mem_buffer: the memory buffer
    :param dataset: the dataset from which to take the examples
    :param t_idx: the task index
    """
    mode = self.net.training
    self.net.eval()
    samples_per_class = mem_buffer.buffer_size // len(self.classes_so_far)

    if t_idx > 0:
        # 1) First, subsample prior classes so every class keeps an
        # equal share of the fixed-size buffer
        buf_x, buf_y, buf_l = self.buffer.get_all_data()

        mem_buffer.empty()
        for _y in buf_y.unique():
            idx = (buf_y == _y)
            _y_x, _y_y, _y_l = buf_x[idx], buf_y[idx], buf_l[idx]
            mem_buffer.add_data(
                examples=_y_x[:samples_per_class],
                labels=_y_y[:samples_per_class],
                logits=_y_l[:samples_per_class]
            )

    # 2) Then, fill with the current task
    loader = dataset.not_aug_dataloader(self.args, self.args.batch_size)

    # 2.1 Extract all features (no gradients are needed here)
    a_x, a_y, a_f, a_l = [], [], [], []
    with torch.no_grad():
        for x, y, not_norm_x in loader:
            x, y, not_norm_x = (a.to(self.device) for a in [x, y, not_norm_x])
            a_x.append(not_norm_x.to('cpu'))
            a_y.append(y.to('cpu'))
            feats = self.net.features(x)
            a_f.append(feats.cpu())
            a_l.append(torch.sigmoid(self.net.classifier(feats)).cpu())
    a_x, a_y, a_f, a_l = (torch.cat(a_x), torch.cat(a_y),
                          torch.cat(a_f), torch.cat(a_l))

    # 2.2 Herding: for each class, greedily pick the examples whose
    # running feature mean stays closest to the true class mean
    for _y in a_y.unique():
        idx = (a_y == _y)
        _x, _y, _l = a_x[idx], a_y[idx], a_l[idx]
        feats = a_f[idx]
        mean_feat = feats.mean(0, keepdim=True)

        running_sum = torch.zeros_like(mean_feat)
        i = 0
        while i < samples_per_class and i < feats.shape[0]:
            # cost of adding each candidate to the current selection
            cost = (mean_feat - (feats + running_sum) / (i + 1)).norm(2, 1)
            idx_min = cost.argmin().item()

            mem_buffer.add_data(
                examples=_x[idx_min:idx_min + 1].to(self.device),
                labels=_y[idx_min:idx_min + 1].to(self.device),
                logits=_l[idx_min:idx_min + 1].to(self.device)
            )

            running_sum += feats[idx_min:idx_min + 1]
            # push the chosen feature far away so it is never picked again
            feats[idx_min] = feats[idx_min] + 1e6
            i += 1

    assert len(mem_buffer.examples) <= mem_buffer.buffer_size
    self.net.train(mode)
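# Illustration (added here, not part of the original file): the greedy
# herding selection above, isolated as a standalone helper so it can be
# tested on plain tensors. `herding_indices` is a hypothetical name; this
# is a minimal sketch assuming `feats` is an (N, D) feature matrix.
def herding_indices(feats: torch.Tensor, k: int) -> list:
    """Greedily pick up to k rows whose running mean tracks the full mean."""
    feats = feats.clone()                       # we mutate feats below
    mean_feat = feats.mean(0, keepdim=True)     # target mean to match
    running_sum = torch.zeros_like(mean_feat)   # sum of selected rows
    picked = []
    for i in range(min(k, feats.shape[0])):
        # distance between the target mean and the mean obtained by
        # adding each remaining candidate to the current selection
        cost = (mean_feat - (feats + running_sum) / (i + 1)).norm(2, 1)
        j = cost.argmin().item()
        picked.append(j)
        running_sum += feats[j:j + 1]
        feats[j] += 1e6                         # exclude j from future picks
    return picked

# Example: pick 3 of 10 random 5-dim features.
# herding_indices(torch.randn(10, 5), 3) -> e.g. [4, 7, 1]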
class Fdr(ContinualModel):
    NAME = 'fdr'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Fdr, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.current_task = 0
        self.i = 0
        self.soft = torch.nn.Softmax(dim=1)
        self.logsoft = torch.nn.LogSoftmax(dim=1)

    def end_task(self, dataset):
        self.current_task += 1
        examples_per_task = self.args.buffer_size // self.current_task

        if self.current_task > 1:
            # Rebalance the buffer: each previous task keeps at most
            # examples_per_task entries
            buf_x, buf_log, buf_tl = self.buffer.get_all_data()
            self.buffer.empty()
            for ttl in buf_tl.unique():
                idx = (buf_tl == ttl)
                ex, log, tasklab = buf_x[idx], buf_log[idx], buf_tl[idx]
                first = min(ex.shape[0], examples_per_task)
                self.buffer.add_data(
                    examples=ex[:first],
                    logits=log[:first],
                    task_labels=tasklab[:first]
                )

        # Store examples from the task that just ended, together with the
        # network's current responses (logits) on those examples
        counter = 0
        with torch.no_grad():
            for i, data in enumerate(dataset.train_loader):
                inputs, labels, not_aug_inputs = data
                inputs = inputs.to(self.device)
                not_aug_inputs = not_aug_inputs.to(self.device)
                outputs = self.net(inputs)
                if examples_per_task - counter <= 0:  # quota reached
                    break
                self.buffer.add_data(
                    examples=not_aug_inputs[:(examples_per_task - counter)],
                    logits=outputs.data[:(examples_per_task - counter)],
                    task_labels=(torch.ones(inputs.shape[0]) *
                                 (self.current_task - 1))
                                [:(examples_per_task - counter)])
                # use the actual batch size: the last batch may be smaller
                counter += inputs.shape[0]

    def observe(self, inputs, labels, not_aug_inputs):
        self.i += 1

        # First pass: plain classification loss on the current batch
        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        if not self.buffer.is_empty():
            # Second pass: match the network's current outputs to the
            # outputs it produced when the buffered examples were stored
            self.opt.zero_grad()
            buf_inputs, buf_logits, _ = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            buf_outputs = self.net(buf_inputs)
            loss = torch.norm(
                self.soft(buf_outputs) - self.soft(buf_logits), 2, 1).mean()
            assert not torch.isnan(loss)
            loss.backward()
            self.opt.step()

        return loss.item()
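# Illustration (added here, not part of the original file): the replay
# regularizer from Fdr.observe() on toy tensors. FDR penalizes the L2
# distance between the softmax of the current outputs and the softmax of
# the logits recorded when the examples entered the buffer. A minimal,
# self-contained sketch; the shapes (8 examples, 10 classes) are arbitrary.
def fdr_replay_loss(buf_outputs: torch.Tensor,
                    buf_logits: torch.Tensor) -> torch.Tensor:
    """Mean per-example L2 distance between two softmax distributions."""
    soft = torch.nn.Softmax(dim=1)
    return torch.norm(soft(buf_outputs) - soft(buf_logits), 2, 1).mean()


if __name__ == '__main__':
    cur = torch.randn(8, 10, requires_grad=True)  # current network outputs
    old = torch.randn(8, 10)                      # logits stored in the buffer
    loss = fdr_replay_loss(cur, old)
    loss.backward()                               # gradients reach `cur` only
    print(loss.item())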