class Er(ContinualModel):
    NAME = 'er'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Er, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def observe(self, inputs, labels, not_aug_inputs):
        real_batch_size = inputs.shape[0]

        self.opt.zero_grad()
        if not self.buffer.is_empty():
            # mix the current batch with a batch drawn from the replay buffer
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        # store the non-augmented examples of the current batch
        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:real_batch_size])

        return loss.item()
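# All the models in this listing rely on a `Buffer` with the interface used
# above (is_empty / get_data / add_data). Its definition is not part of this
# listing; the following is a minimal reservoir-sampling sketch of the assumed
# behavior, not the actual implementation (`SketchBuffer` is a hypothetical
# name; the real buffer also stores logits and task labels, omitted here).
import numpy as np
import torch


class SketchBuffer:
    """Minimal reservoir buffer: keeps a uniform sample of the stream."""

    def __init__(self, buffer_size, device):
        self.buffer_size = buffer_size
        self.device = device
        self.num_seen_examples = 0
        self.examples, self.labels = None, None

    def is_empty(self):
        return self.num_seen_examples == 0

    def add_data(self, examples, labels):
        for i in range(examples.shape[0]):
            if self.examples is None:
                # lazily allocate storage with the shape of the first example
                self.examples = torch.zeros(self.buffer_size, *examples.shape[1:],
                                            device=self.device)
                self.labels = torch.zeros(self.buffer_size, dtype=torch.long,
                                          device=self.device)
            if self.num_seen_examples < self.buffer_size:
                idx = self.num_seen_examples
            else:
                # reservoir sampling: keep each seen example with equal probability
                idx = np.random.randint(0, self.num_seen_examples + 1)
            if idx < self.buffer_size:
                self.examples[idx] = examples[i]
                self.labels[idx] = labels[i]
            self.num_seen_examples += 1

    def get_data(self, size, transform=None):
        n = min(self.num_seen_examples, self.buffer_size)
        choice = np.random.choice(n, size=min(size, n), replace=False)
        batch = self.examples[choice]
        if transform is not None:
            # re-apply the train-time augmentation to the stored raw examples
            batch = torch.stack([transform(x.cpu()) for x in batch]).to(self.device)
        return batch, self.labels[choice]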
class Der(ContinualModel):
    NAME = 'der'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Der, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def observe(self, inputs, labels, not_aug_inputs):
        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)

        if not self.buffer.is_empty():
            # Dark Experience Replay: match the current outputs to the logits
            # recorded when the buffered examples were first seen
            buf_inputs, buf_logits = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            buf_outputs = self.net(buf_inputs)
            loss += self.args.alpha * F.mse_loss(buf_outputs, buf_logits)

        loss.backward()
        self.opt.step()
        self.buffer.add_data(examples=not_aug_inputs, logits=outputs.data)

        return loss.item()
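# Usage sketch (not part of the original listing): every model in this file
# exposes the same `observe` interface, so a single stream-training loop can
# drive any of them. `train_stream_sketch` is a hypothetical helper; the
# loader is assumed to yield (augmented, label, raw) triples, as in the
# dataloaders referenced by the `end_task` methods below.
def train_stream_sketch(model, train_loader):
    model.net.train()
    for inputs, labels, not_aug_inputs in train_loader:
        inputs = inputs.to(model.device)
        labels = labels.to(model.device)
        not_aug_inputs = not_aug_inputs.to(model.device)
        # one online step: forward, replay, backward, buffer update
        loss = model.observe(inputs, labels, not_aug_inputs)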
class Mer(ContinualModel):
    NAME = 'mer'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Mer, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def draw_batches(self, inp, lab):
        # build batch_num batches, each mixing buffer samples with the
        # single incoming example (MER processes the stream one example at a time)
        batches = []
        for i in range(self.args.batch_num):
            if not self.buffer.is_empty():
                buf_inputs, buf_labels = self.buffer.get_data(
                    self.args.minibatch_size, transform=self.transform)
                inputs = torch.cat((buf_inputs, inp.unsqueeze(0)))
                labels = torch.cat(
                    (buf_labels, torch.tensor([lab]).to(self.device)))
                batches.append((inputs, labels))
            else:
                batches.append((inp.unsqueeze(0),
                                torch.tensor([lab]).unsqueeze(0).to(self.device)))
        return batches

    def observe(self, inputs, labels, not_aug_inputs):
        batches = self.draw_batches(inputs, labels)
        theta_A0 = self.net.get_params().data.clone()

        for i in range(self.args.batch_num):
            theta_Wi0 = self.net.get_params().data.clone()
            batch_inputs, batch_labels = batches[i]

            # within-batch SGD step
            self.opt.zero_grad()
            outputs = self.net(batch_inputs)
            loss = self.loss(outputs, batch_labels.squeeze(-1))
            loss.backward()
            self.opt.step()

            # within-batch Reptile meta-update
            new_params = theta_Wi0 + self.args.beta * (self.net.get_params() - theta_Wi0)
            self.net.set_params(new_params)

        self.buffer.add_data(examples=not_aug_inputs.unsqueeze(0), labels=labels)

        # across-batch Reptile meta-update
        new_new_params = theta_A0 + self.args.gamma * (self.net.get_params() - theta_A0)
        self.net.set_params(new_new_params)

        return loss.item()
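# The two `set_params` lines above implement Reptile-style interpolation:
# theta <- theta_0 + rate * (theta_adapted - theta_0). A toy numeric check on
# a flat parameter vector (the beta value here is made up for illustration):
import torch

theta_0 = torch.zeros(3)
theta_adapted = torch.tensor([1.0, 2.0, 3.0])  # params after the inner SGD steps
beta = 0.5                                     # within-batch meta-rate
theta_new = theta_0 + beta * (theta_adapted - theta_0)
# the update moves theta_0 halfway towards the adapted parameters
assert torch.allclose(theta_new, torch.tensor([0.5, 1.0, 1.5]))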
class AGem(ContinualModel):
    NAME = 'agem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(AGem, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        # flat gradient buffers: grad_xy for the current data,
        # grad_er for the reference gradient on the memory buffer
        self.grad_dims = []
        for param in self.parameters():
            self.grad_dims.append(param.data.numel())
        self.grad_xy = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        self.grad_er = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        self.transform = transform if self.args.iba else None

    def end_task(self, dataset):
        # reserve an equal share of the buffer for each task
        samples_per_task = self.args.buffer_size // dataset.N_TASKS
        loader = dataset.not_aug_dataloader(self.args, samples_per_task)
        cur_x, cur_y = next(iter(loader))[:2]
        self.buffer.add_data(examples=cur_x.to(self.device),
                             labels=cur_y.to(self.device))

    def observe(self, inputs, labels, not_aug_inputs):
        self.zero_grad()
        p = self.net.forward(inputs)
        loss = self.loss(p, labels)
        loss.backward()

        if not self.buffer.is_empty():
            store_grad(self.parameters, self.grad_xy, self.grad_dims)

            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            self.net.zero_grad()
            buf_outputs = self.net.forward(buf_inputs)
            penalty = self.loss(buf_outputs, buf_labels)
            penalty.backward()
            store_grad(self.parameters, self.grad_er, self.grad_dims)

            # project the current gradient only if it conflicts
            # with the reference gradient
            dot_prod = torch.dot(self.grad_xy, self.grad_er)
            if dot_prod.item() < 0:
                g_tilde = project(gxy=self.grad_xy, ger=self.grad_er)
                overwrite_grad(self.parameters, g_tilde, self.grad_dims)
            else:
                overwrite_grad(self.parameters, self.grad_xy, self.grad_dims)

        self.opt.step()
        return loss.item()
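# `store_grad`, `overwrite_grad` and `project` are helpers defined elsewhere
# in the codebase. Consistent with how they are called above, a minimal sketch
# (assumption: `params` is a callable returning the parameter iterator, i.e.
# `self.parameters` is passed without parentheses):
import numpy as np
import torch


def store_grad(params, grads, grad_dims):
    # flatten all parameter gradients into the 1-D tensor `grads`
    grads.fill_(0.0)
    count = 0
    for param in params():
        if param.grad is not None:
            begin = 0 if count == 0 else int(np.sum(grad_dims[:count]))
            end = int(np.sum(grad_dims[:count + 1]))
            grads[begin:end].copy_(param.grad.data.view(-1))
        count += 1


def overwrite_grad(params, newgrad, grad_dims):
    # write the 1-D tensor `newgrad` back into the parameter gradients
    count = 0
    for param in params():
        if param.grad is not None:
            begin = 0 if count == 0 else int(np.sum(grad_dims[:count]))
            end = int(np.sum(grad_dims[:count + 1]))
            param.grad.data.copy_(newgrad[begin:end].contiguous().view(
                param.grad.data.size()))
        count += 1


def project(gxy, ger):
    # A-GEM projection: remove from gxy the component that conflicts with ger
    corr = torch.dot(gxy, ger) / torch.dot(ger, ger)
    return gxy - corr * ger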
class AGemr(ContinualModel):
    NAME = 'agem_r'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(AGemr, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.grad_dims = []
        for param in self.parameters():
            self.grad_dims.append(param.data.numel())
        self.grad_xy = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        self.grad_er = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        self.current_task = 0

    def observe(self, inputs, labels, not_aug_inputs):
        self.zero_grad()
        p = self.net.forward(inputs)
        loss = self.loss(p, labels)
        loss.backward()

        if not self.buffer.is_empty():
            store_grad(self.parameters, self.grad_xy, self.grad_dims)

            buf_inputs, buf_labels = self.buffer.get_data(self.args.minibatch_size)
            self.net.zero_grad()
            buf_outputs = self.net.forward(buf_inputs)
            penalty = self.loss(buf_outputs, buf_labels)
            penalty.backward()
            store_grad(self.parameters, self.grad_er, self.grad_dims)

            dot_prod = torch.dot(self.grad_xy, self.grad_er)
            if dot_prod.item() < 0:
                g_tilde = project(gxy=self.grad_xy, ger=self.grad_er)
                overwrite_grad(self.parameters, g_tilde, self.grad_dims)
            else:
                overwrite_grad(self.parameters, self.grad_xy, self.grad_dims)

        self.opt.step()
        self.buffer.add_data(examples=not_aug_inputs, labels=labels)

        return loss.item()
class Gem(ContinualModel):
    NAME = 'gem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(Gem, self).__init__(backbone, loss, args, transform)
        self.current_task = 0
        self.buffer = Buffer(self.args.buffer_size, self.device)

        # Allocate temporary synaptic memory: one flat gradient per past task
        # (grads_cs) plus one for the current data (grads_da)
        self.grad_dims = []
        for pp in self.parameters():
            self.grad_dims.append(pp.data.numel())

        self.grads_cs = []
        self.grads_da = torch.zeros(np.sum(self.grad_dims)).to(self.device)
        self.transform = transform if self.args.iba else None

    def end_task(self, dataset):
        self.current_task += 1
        self.grads_cs.append(torch.zeros(np.sum(self.grad_dims)).to(self.device))

        # add data to the buffer
        samples_per_task = self.args.buffer_size // dataset.N_TASKS
        loader = dataset.not_aug_dataloader(self.args, samples_per_task)
        cur_x, cur_y = next(iter(loader))[:2]
        self.buffer.add_data(
            examples=cur_x.to(self.device),
            labels=cur_y.to(self.device),
            task_labels=torch.ones(samples_per_task, dtype=torch.long).to(
                self.device) * (self.current_task - 1))

    def observe(self, inputs, labels, not_aug_inputs):
        if not self.buffer.is_empty():
            # fetch the whole buffer and recompute the gradient of each past task
            buf_inputs, buf_labels, buf_task_labels = self.buffer.get_data(
                self.args.buffer_size, transform=self.transform)

            for tt in buf_task_labels.unique():
                # compute gradient on the memory buffer, chunked by batch_size
                # to bound memory usage
                self.opt.zero_grad()
                cur_task_inputs = buf_inputs[buf_task_labels == tt]
                cur_task_labels = buf_labels[buf_task_labels == tt]
                for i in range(math.ceil(len(cur_task_inputs) / self.args.batch_size)):
                    cur_task_outputs = self.forward(
                        cur_task_inputs[i * self.args.batch_size:(i + 1) * self.args.batch_size])
                    penalty = self.loss(
                        cur_task_outputs,
                        cur_task_labels[i * self.args.batch_size:(i + 1) * self.args.batch_size],
                        reduction='sum') / cur_task_inputs.shape[0]
                    penalty.backward()
                store_grad(self.parameters, self.grads_cs[tt], self.grad_dims)

        # now compute the gradient on the current data
        self.opt.zero_grad()
        outputs = self.forward(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()

        # check whether the gradient violates the buffer constraints
        if not self.buffer.is_empty():
            # copy gradient
            store_grad(self.parameters, self.grads_da, self.grad_dims)

            dot_prod = torch.mm(self.grads_da.unsqueeze(0),
                                torch.stack(self.grads_cs).T)
            if (dot_prod < 0).sum() != 0:
                project2cone2(self.grads_da.unsqueeze(1),
                              torch.stack(self.grads_cs).T,
                              margin=self.args.gamma)
                # copy the projected gradients back
                overwrite_grad(self.parameters, self.grads_da, self.grad_dims)

        self.opt.step()
        return loss.item()
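# `project2cone2` solves GEM's dual quadratic program: find the gradient
# closest to grads_da that has non-negative dot product with every past-task
# gradient. A sketch following the reference implementation of Lopez-Paz &
# Ranzato (2017); it assumes the `quadprog` package is available:
import numpy as np
import quadprog
import torch


def project2cone2(gradient, memories, margin=0.5, eps=1e-3):
    """Project `gradient` (p x 1) onto the cone of the task gradients
    in `memories` (p x t), in place."""
    memories_np = memories.cpu().t().double().numpy()
    gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()
    t = memories_np.shape[0]
    P = np.dot(memories_np, memories_np.transpose())
    P = 0.5 * (P + P.transpose()) + np.eye(t) * eps  # symmetrize + regularize
    q = np.dot(memories_np, gradient_np) * -1
    G = np.eye(t)
    h = np.zeros(t) + margin
    # solve the dual QP for the multipliers v of the violated constraints
    v = quadprog.solve_qp(P, q, G, h)[0]
    # recover the primal (projected) gradient and write it back in place
    x = np.dot(v, memories_np) + gradient_np
    gradient.copy_(torch.from_numpy(x).float().view(-1, 1))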
class HAL(ContinualModel):
    NAME = 'hal'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(HAL, self).__init__(backbone, loss, args, transform)
        self.task_number = 0
        self.buffer = Buffer(self.args.buffer_size, self.device,
                             get_dataset(args).N_TASKS, mode='ring')
        self.hal_lambda = args.hal_lambda
        self.beta = args.beta
        self.gamma = args.gamma
        self.anchor_optimization_steps = 100
        self.finetuning_epochs = 1
        self.dataset = get_dataset(args)
        self.spare_model = self.dataset.get_backbone()
        self.spare_model.to(self.device)
        self.spare_opt = SGD(self.spare_model.parameters(), lr=self.args.lr)

    def end_task(self, dataset):
        self.task_number += 1
        # ring buffer management (skipped when resuming from a saved model)
        if self.task_number > self.buffer.task_number:
            self.buffer.num_seen_examples = 0
            self.buffer.task_number = self.task_number
        # get anchors (skipped when resuming from a saved model)
        if len(self.anchors) < self.task_number * dataset.N_CLASSES_PER_TASK:
            self.get_anchors(dataset)
            del self.phi

    def get_anchors(self, dataset):
        theta_t = self.net.get_params().detach().clone()
        self.spare_model.set_params(theta_t)

        # fine-tune the spare model on the memory buffer
        for _ in range(self.finetuning_epochs):
            inputs, labels = self.buffer.get_data(self.args.batch_size,
                                                  transform=self.transform)
            self.spare_opt.zero_grad()
            out = self.spare_model(inputs)
            loss = self.loss(out, labels)
            loss.backward()
            self.spare_opt.step()

        theta_m = self.spare_model.get_params().detach().clone()

        classes_for_this_task = np.unique(dataset.train_loader.dataset.targets)

        for a_class in classes_for_this_task:
            e_t = torch.rand(self.input_shape, requires_grad=True,
                             device=self.device)
            e_t_opt = SGD([e_t], lr=self.args.lr)
            print(file=sys.stderr)
            for i in range(self.anchor_optimization_steps):
                e_t_opt.zero_grad()
                cum_loss = 0

                # maximize the loss of the fine-tuned model on the anchor...
                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_m.detach().clone())
                loss = -torch.sum(self.loss(self.spare_model(e_t.unsqueeze(0)),
                                            torch.tensor([a_class]).to(self.device)))
                loss.backward()
                cum_loss += loss.item()

                # ...while minimizing the loss of the current model on it...
                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_t.detach().clone())
                loss = torch.sum(self.loss(self.spare_model(e_t.unsqueeze(0)),
                                           torch.tensor([a_class]).to(self.device)))
                loss.backward()
                cum_loss += loss.item()

                # ...and keeping the anchor's embedding close to phi
                self.spare_opt.zero_grad()
                loss = torch.sum(self.gamma * (self.spare_model.features(
                    e_t.unsqueeze(0)) - self.phi) ** 2)
                assert not self.phi.requires_grad
                loss.backward()
                cum_loss += loss.item()

                e_t_opt.step()

            e_t = e_t.detach()
            e_t.requires_grad = False
            self.anchors = torch.cat((self.anchors, e_t.unsqueeze(0)))
            del e_t
            print('Total anchors:', len(self.anchors), file=sys.stderr)

        self.spare_model.zero_grad()

    def observe(self, inputs, labels, not_aug_inputs):
        real_batch_size = inputs.shape[0]
        if not hasattr(self, 'input_shape'):
            self.input_shape = inputs.shape[1:]
        if not hasattr(self, 'anchors'):
            self.anchors = torch.zeros(tuple([0] + list(self.input_shape))).to(self.device)
        if not hasattr(self, 'phi'):
            print('Building phi', file=sys.stderr)
            with torch.no_grad():
                self.phi = torch.zeros_like(self.net.features(inputs[0].unsqueeze(0)),
                                            requires_grad=False)
            assert not self.phi.requires_grad

        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        old_weights = self.net.get_params().detach().clone()

        self.opt.zero_grad()
        outputs = self.net(inputs)

        k = self.task_number

        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()
        first_loss = 0

        assert len(self.anchors) == self.dataset.N_CLASSES_PER_TASK * k
        if len(self.anchors) > 0:
            first_loss = loss.item()
            # second step: penalize the drift of the anchor predictions
            # caused by the SGD step above
            with torch.no_grad():
                pred_anchors = self.net(self.anchors)

            self.net.set_params(old_weights)
            pred_anchors -= self.net(self.anchors)
            loss = self.hal_lambda * (pred_anchors ** 2).mean()
            loss.backward()
            self.opt.step()

        with torch.no_grad():
            # exponential moving average of the mean feature of the current batch
            self.phi = self.beta * self.phi + (1 - self.beta) * \
                self.net.features(inputs[:real_batch_size]).mean(0)

        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:real_batch_size])

        return first_loss + loss.item()
    def fill_buffer(self, mem_buffer: Buffer, dataset, t_idx: int) -> None:
        """
        Adds examples from the current task to the memory buffer
        by means of the herding strategy.
        :param mem_buffer: the memory buffer
        :param dataset: the dataset from which to take the examples
        :param t_idx: the task index
        """
        mode = self.net.training
        self.net.eval()
        samples_per_class = mem_buffer.buffer_size // len(self.classes_so_far)

        if t_idx > 0:
            # 1) First, subsample prior classes so every class keeps an equal share
            buf_x, buf_y, buf_l = self.buffer.get_all_data()

            mem_buffer.empty()
            for _y in buf_y.unique():
                idx = (buf_y == _y)
                _y_x, _y_y, _y_l = buf_x[idx], buf_y[idx], buf_l[idx]
                mem_buffer.add_data(examples=_y_x[:samples_per_class],
                                    labels=_y_y[:samples_per_class],
                                    logits=_y_l[:samples_per_class])

        # 2) Then, fill with the current task
        loader = dataset.not_aug_dataloader(self.args, self.args.batch_size)

        # 2.1 Extract all features
        a_x, a_y, a_f, a_l = [], [], [], []
        for x, y, not_norm_x in loader:
            x, y, not_norm_x = (a.to(self.device) for a in [x, y, not_norm_x])
            a_x.append(not_norm_x.to('cpu'))
            a_y.append(y.to('cpu'))
            feats = self.net.features(x)
            a_f.append(feats.cpu())
            a_l.append(torch.sigmoid(self.net.classifier(feats)).cpu())
        a_x, a_y, a_f, a_l = torch.cat(a_x), torch.cat(a_y), \
            torch.cat(a_f), torch.cat(a_l)

        # 2.2 Herding per class: greedily keep the running mean of the chosen
        # features as close as possible to the class mean
        for _y in a_y.unique():
            idx = (a_y == _y)
            _x, _y, _l = a_x[idx], a_y[idx], a_l[idx]
            feats = a_f[idx]
            mean_feat = feats.mean(0, keepdim=True)

            running_sum = torch.zeros_like(mean_feat)
            i = 0
            while i < samples_per_class and i < feats.shape[0]:
                cost = (mean_feat - (feats + running_sum) / (i + 1)).norm(2, 1)
                idx_min = cost.argmin().item()

                mem_buffer.add_data(
                    examples=_x[idx_min:idx_min + 1].to(self.device),
                    labels=_y[idx_min:idx_min + 1].to(self.device),
                    logits=_l[idx_min:idx_min + 1].to(self.device))

                running_sum += feats[idx_min:idx_min + 1]
                # push the chosen feature far away so it cannot be picked twice
                feats[idx_min] = feats[idx_min] + 1e6
                i += 1

        assert len(mem_buffer.examples) <= mem_buffer.buffer_size
        self.net.train(mode)
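# The while-loop above is the iCaRL herding criterion: at each step, pick the
# example whose inclusion keeps the running average of the selected features
# closest to the class mean. A tiny standalone illustration on made-up 2-D
# feature vectors:
import torch

feats = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]])
mean_feat = feats.mean(0, keepdim=True)
running_sum = torch.zeros_like(mean_feat)
picked = []
for i in range(2):  # select 2 exemplars
    cost = (mean_feat - (feats + running_sum) / (i + 1)).norm(2, 1)
    idx_min = cost.argmin().item()
    picked.append(idx_min)
    running_sum += feats[idx_min:idx_min + 1]
    feats[idx_min] = feats[idx_min] + 1e6  # exclude from further selection
print(picked)  # greedy exemplar order, closest-to-mean first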
class Fdr(ContinualModel):
    NAME = 'fdr'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Fdr, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.current_task = 0
        self.i = 0
        self.soft = torch.nn.Softmax(dim=1)
        self.logsoft = torch.nn.LogSoftmax(dim=1)

    def end_task(self, dataset):
        self.current_task += 1
        examples_per_task = self.args.buffer_size // self.current_task

        if self.current_task > 1:
            # rebalance the buffer so every task keeps an equal share
            buf_x, buf_log, buf_tl = self.buffer.get_all_data()
            self.buffer.empty()

            for ttl in buf_tl.unique():
                idx = (buf_tl == ttl)
                ex, log, tasklab = buf_x[idx], buf_log[idx], buf_tl[idx]
                first = min(ex.shape[0], examples_per_task)
                self.buffer.add_data(examples=ex[:first],
                                     logits=log[:first],
                                     task_labels=tasklab[:first])

        # store examples of the task just finished, together with their logits
        counter = 0
        with torch.no_grad():
            for i, data in enumerate(dataset.train_loader):
                inputs, labels, not_aug_inputs = data
                inputs = inputs.to(self.device)
                not_aug_inputs = not_aug_inputs.to(self.device)
                outputs = self.net(inputs)
                if examples_per_task - counter < 0:
                    break
                self.buffer.add_data(
                    examples=not_aug_inputs[:(examples_per_task - counter)],
                    logits=outputs.data[:(examples_per_task - counter)],
                    task_labels=(torch.ones(self.args.batch_size) *
                                 (self.current_task - 1))[:(examples_per_task - counter)])
                counter += self.args.batch_size

    def observe(self, inputs, labels, not_aug_inputs):
        self.i += 1

        # first step: plain SGD on the current batch
        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        if not self.buffer.is_empty():
            # second step: pull the current softmax outputs towards the
            # ones stored at the end of the corresponding task
            self.opt.zero_grad()
            buf_inputs, buf_logits, _ = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            buf_outputs = self.net(buf_inputs)
            loss = torch.norm(self.soft(buf_outputs) -
                              self.soft(buf_logits), 2, 1).mean()
            assert not torch.isnan(loss)
            loss.backward()
            self.opt.step()

        return loss.item()
def train_cl(train_set, test_set, model, loss, optimizer, device, config):
    """
    :param train_set: train set
    :param test_set: test set
    :param model: PyTorch model
    :param loss: loss function
    :param optimizer: optimizer
    :param device: device cuda/cpu
    :param config: configuration
    """
    name = ""
    global_writer = SummaryWriter('./runs/continual/train/global/' + name)
    buffer = Buffer(config['buffer_size'], device)
    accuracy = []
    text = open("result_" + name + ".txt", "w")

    # Evaluate before training, to get a random-initialization baseline
    random_accuracy = evaluate_past(model, len(test_set) - 1, test_set, loss, device)
    text.write("Evaluation before training" + '\n')
    for a in random_accuracy:
        text.write(f"{a:.2f}% ")
    text.write('\n')

    for index, data_set in enumerate(train_set):
        model.train()
        print(f"----- DOMAIN {index} -----")
        print("Training model...")
        train_loader = DataLoader(data_set, batch_size=config['batch_size'],
                                  shuffle=False)

        for epoch in tqdm(range(config['epochs'])):
            epoch_loss = []
            epoch_acc = []
            for i, (x, y) in enumerate(train_loader):
                optimizer.zero_grad()

                inputs = x.to(device)
                labels = y.to(device)

                if not buffer.is_empty():
                    # 50/50 strategy: a batch of 64 from the dataloader
                    # becomes 64 + 64 (dataloader + replay)
                    buf_input, buf_label = buffer.get_data(config['batch_size'])
                    inputs = torch.cat((inputs, torch.stack(buf_input)))
                    labels = torch.cat((labels, torch.stack(buf_label)))

                y_pred = model(inputs)
                s_loss = loss(y_pred.squeeze(1), labels)
                acc = binary_accuracy(y_pred.squeeze(1), labels)

                # per-epoch metrics
                epoch_loss.append(s_loss.item())
                epoch_acc.append(acc.item())

                s_loss.backward()
                optimizer.step()

            # fill the buffer with examples from the first epoch only
            if epoch == 0:
                buffer.add_data(examples=x.to(device), labels=y.to(device))

            global_writer.add_scalar('Train_global/Loss',
                                     statistics.mean(epoch_loss),
                                     epoch + (config['epochs'] * index))
            global_writer.add_scalar('Train_global/Accuracy',
                                     statistics.mean(epoch_acc),
                                     epoch + (config['epochs'] * index))

            # periodic log, plus the last epoch (stats only)
            if epoch % 100 == 0 or epoch == config['epochs'] - 1:
                print(f'\nEpoch {epoch:03}/{config["epochs"]} '
                      f'| Loss: {statistics.mean(epoch_loss):.5f} '
                      f'| Acc: {statistics.mean(epoch_acc):.5f}')

        # Test on the domain just trained, plus all previous domains
        evaluation = evaluate_past(model, index, test_set, loss, device)
        accuracy.append(evaluation)
        text.write(f"Evaluation after domain {index}" + '\n')
        for a in evaluation:
            text.write(f"{a:.2f}% ")
        text.write('\n')

        if index != len(train_set) - 1:
            accuracy[index].append(evaluate_next(model, index, test_set, loss, device))

        # Check buffer distribution
        buffer.check_distribution()

    # Compute transfer metrics
    backward = backward_transfer(accuracy)
    forward = forward_transfer(accuracy, random_accuracy)
    forget = forgetting(accuracy)
    print(f'Backward transfer: {backward}')  # TODO: are these values percentages?
    print(f'Forward transfer: {forward}')
    print(f'Forgetting: {forget}')

    text.write(f"Backward: {backward}\n")
    text.write(f"Forward: {forward}\n")
    text.write(f"Forgetting: {forget}\n")
    text.close()
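# `binary_accuracy` is referenced above but not defined in this listing. A
# plausible minimal sketch, under the assumption that `y_pred` holds raw
# logits of a binary classifier and `labels` are 0/1 floats:
import torch


def binary_accuracy(y_pred, labels):
    preds = torch.round(torch.sigmoid(y_pred))  # logits -> hard 0/1 decisions
    return (preds == labels).float().mean()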
class OCILFAST(ContinualModel):
    NAME = 'OCILFAST'
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, net, loss, args, transform):
        super(OCILFAST, self).__init__(net, loss, args, transform)
        self.nets = []
        self.c = []
        self.threshold = []
        self.nu = self.args.nu
        self.eta = self.args.eta
        self.eps = self.args.eps
        self.embedding_dim = self.args.embedding_dim
        self.weight_decay = self.args.weight_decay
        self.margin = self.args.margin
        self.current_task = 0
        self.cpt = None
        self.nc = None
        self.eye = None
        self.buffer_size = self.args.buffer_size
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.nf = self.args.nf
        # dataset-specific input offset
        if self.args.dataset in ('seq-cifar10', 'seq-mnist'):
            self.input_offset = -0.5
        else:
            self.input_offset = 0

    def begin_task(self, dataset):
        # task initialization
        if self.cpt is None:
            self.cpt = dataset.N_CLASSES_PER_TASK
            self.nc = dataset.N_TASKS * self.cpt
            # lower triangle (diagonal included) is True, upper triangle is
            # False; used as a mask
            self.eye = torch.tril(torch.ones((self.nc, self.nc))).bool().to(self.device)
        if len(self.nets) == 0:
            # one hypersphere network and one center per class
            for i in range(self.nc):
                self.nets.append(get_backbone(self.net, self.embedding_dim,
                                              self.nc, self.nf).to(self.device))
                self.c.append(torch.ones(self.embedding_dim, device=self.device))
        self.current_task += 1

    def train_model(self, dataset, train_loader):
        categories = list(range((self.current_task - 1) * self.cpt,
                                self.current_task * self.cpt))
        print('==========\t task: %d\t categories:' % self.current_task,
              categories, '\t==========')
        if self.args.print_file:
            print('==========\t task: %d\t categories:' % self.current_task,
                  categories, '\t==========', file=self.args.print_file)
        for category in categories:
            losses = []
            if category > 0:
                self.reset_train_loader(train_loader, category)
            for epoch in range(self.args.n_epochs):
                avg_loss, maxloss, posdist, negdist, gloloss = \
                    self.train_category(train_loader, category, epoch)
                losses.append(avg_loss)
                if epoch == 0 or (epoch + 1) % 5 == 0:
                    print("epoch: %d\t task: %d \t category: %d \t loss: %f"
                          % (epoch + 1, self.current_task, category, avg_loss))
                    if self.args.print_file:
                        print("epoch: %d\t task: %d \t category: %d \t loss: %f"
                              % (epoch + 1, self.current_task, category, avg_loss),
                              file=self.args.print_file)

                # plot the distribution of each loss component
                plt.figure(figsize=(20, 12))
                ax = plt.subplot(2, 2, 1)
                ax.set_title('maxloss')
                plt.xlim((0, 2))
                if maxloss is not None:
                    try:
                        sns.distplot(maxloss)
                    except Exception:
                        pass
                ax = plt.subplot(2, 2, 2)
                ax.set_title('posdist')
                plt.xlim((0, 2))
                try:
                    sns.distplot(posdist)
                except Exception:
                    print(posdist)
                ax = plt.subplot(2, 2, 3)
                ax.set_title('negdist')
                plt.xlim((0, 2))
                try:
                    sns.distplot(negdist)
                except Exception:
                    print(negdist)
                ax = plt.subplot(2, 2, 4)
                ax.set_title('gloloss')
                plt.xlim((0, 2))
                try:
                    sns.distplot(gloloss)
                except Exception:
                    print(gloloss)
                plt.savefig("../" + self.args.img_dir +
                            "/loss-cat%d-epoch%d.png" % (category, epoch))
                plt.clf()

            x = list(range(len(losses)))
            plt.plot(x, losses)
            plt.savefig("../" + self.args.img_dir + "/loss-cat%d.png" % category)
            plt.clf()
        self.fill_buffer(train_loader)

    def reset_train_loader(self, train_loader, category):
        # recompute, for every training example, its distance to each
        # previously trained hypersphere
        dataset = train_loader.dataset
        loader = DataLoader(dataset, batch_size=self.args.batch_size, shuffle=False)
        inputs = []
        targets = []
        prev_dists = []
        prev_categories = list(range(category))
        print('prev_categories', prev_categories)
        if self.args.print_file:
            print('prev_categories', prev_categories, file=self.args.print_file)
        for i, data in enumerate(loader):
            input, target, _ = data
            _, prev_dist = self.predict(input, prev_categories)
            inputs.append(input.detach().cpu())
            targets.append(target.detach().cpu())
            prev_dists.append(prev_dist.detach().cpu())
        inputs = torch.cat(inputs, dim=0)
        targets = torch.cat(targets, dim=0)
        prev_dists = torch.cat(prev_dists, dim=0)
        dataset.set_prevdist(prev_dists)

    def train_category(self, data_loader, category: int, epoch_id):
        self.init_center_c(data_loader, category)
        c = self.c[category]
        network = self.nets[category].to(self.device)
        network.train()
        optimizer = SGD(network.parameters(), lr=self.args.lr,
                        weight_decay=self.weight_decay)
        avg_loss = 0.0
        sample_num = 0
        maxloss = []
        posdist = []
        negdist = []
        gloloss = []
        prev_categories = list(range(category))
        for i, data in enumerate(data_loader):
            inputs, semi_targets, prev_dists = data
            inputs = inputs.to(self.device)
            semi_targets = semi_targets.to(self.device)
            prev_dists = prev_dists.to(self.device)

            if (not self.buffer.is_empty()) and self.args.buffer_size > 0:
                buf_inputs, buf_labels = self.buffer.get_data(
                    self.args.minibatch_size, transform=self.transform)
                inputs = torch.cat((inputs, buf_inputs))
                semi_targets = torch.cat((semi_targets, buf_labels))

            # Zero the network parameter gradients
            optimizer.zero_grad()

            # note: the network input must be shifted by the dataset offset
            # (e.g. -0.5 for seq-cifar10 and seq-mnist)
            outputs = network(inputs + self.input_offset)
            dists = torch.sum((outputs - c) ** 2, dim=1)
            pos_dist_loss = torch.relu(dists - self.args.r)
            if category > 0:
                max_scores = torch.relu(dists.view(-1, 1) - prev_dists)
                max_loss = torch.sum(max_scores, dim=1) * self.margin / category
                loss_pos = pos_dist_loss + max_loss
                loss_neg = self.eta * dists ** -1
                pos_max_loss = max_loss[semi_targets == category]
                maxloss.append(pos_max_loss.detach().cpu().data.numpy())
            else:
                loss_pos = pos_dist_loss
                loss_neg = self.eta * dists ** -1
            losses = torch.where(semi_targets == category, loss_pos, loss_neg)
            gloloss.append(losses.detach().cpu().data.numpy())
            loss = torch.mean(losses)
            loss.backward()
            optimizer.step()

            # record the loss components
            pos_dist = pos_dist_loss[semi_targets == category]
            posdist.append(pos_dist.detach().cpu().data.numpy())
            neg_dist = loss_neg[semi_targets != category]
            negdist.append(neg_dist.detach().cpu().data.numpy())
            avg_loss += loss.item()
            sample_num += inputs.shape[0]

            # old categories are only trained on a single batch
            if category < (self.current_task - 1) * self.cpt:
                break

        avg_loss /= sample_num
        if len(maxloss) > 0:
            maxloss = np.hstack(maxloss)
        else:
            maxloss = None
        posdist = np.hstack(posdist)
        negdist = np.hstack(negdist)
        gloloss = np.hstack(gloloss)
        return avg_loss, maxloss, posdist, negdist, gloloss

    def fill_buffer(self, train_loader):
        for data in train_loader:
            # get the inputs of the batch
            inputs, semi_targets, not_aug_inputs = data
            self.buffer.add_data(examples=not_aug_inputs, labels=semi_targets)

    def init_center_c(self, train_loader: DataLoader, category):
        """Initialize hypersphere center c as the mean from an initial
        forward pass on the data."""
        n_samples = 0
        c = 0
        net = self.nets[category].to(self.device)
        net.eval()
        with torch.no_grad():
            for data in train_loader:
                # get the inputs of the batch
                inputs, semi_targets, not_aug_inputs = data
                inputs = inputs.to(self.device)
                semi_targets = semi_targets.to(self.device)
                outputs = net(inputs + self.input_offset)
                # use all positive samples to initialize the center
                outputs = outputs[semi_targets == category]
                n_samples += outputs.shape[0]
                c += torch.sum(outputs, dim=0)
        c /= n_samples

        # If c_i is too close to 0, set it to +-eps: a zero unit could be
        # trivially matched with zero weights.
        c[(abs(c) < self.eps) & (c < 0)] = -self.eps
        c[(abs(c) < self.eps) & (c > 0)] = self.eps
        self.c[category] = c.to(self.device)

    def get_score(self, dist, category):
        score = 1 / (dist + 1e-6)
        return score

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        categories = list(range(self.current_task * self.cpt))
        return self.predict(x, categories)[0]

    def predict(self, inputs: torch.Tensor, categories):
        inputs = inputs.to(self.device)
        outcome, dists = [], []
        with torch.no_grad():
            for i in categories:
                net = self.nets[i]
                net.to(self.device)
                net.eval()
                c = self.c[i].to(self.device)
                pred = net(inputs + self.input_offset)
                dist = torch.sum((pred - c) ** 2, dim=1)
                scores = self.get_score(dist, i)
                outcome.append(scores.view(-1, 1))
                dists.append(dist.view(-1, 1))
        outcome = torch.cat(outcome, dim=1)
        dists = torch.cat(dists, dim=1)
        return outcome, dists
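# Inference sketch (not part of the original listing): `forward` returns one
# inverse-distance hypersphere score per class seen so far, so class-IL
# prediction reduces to an argmax over the scores. `predict_classes_sketch`
# is a hypothetical helper; `model` is an OCILFAST instance and `x` a batch.
def predict_classes_sketch(model, x):
    scores = model(x)            # shape: (batch, classes seen so far)
    return scores.argmax(dim=1)  # higher score = closer to that class center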