import torch
import torch.nn.functional as F
import torch.optim as optim


# Two variants of cache_masks appear in this section. This one caches the
# binarized subnetwork masks of all tasks in a single stacked buffer for
# mask superposition. module_util and pargs come from the surrounding
# codebase.
def cache_masks(self):
    self.register_buffer(
        "stacked",
        torch.stack(
            [
                module_util.get_subnet_fast(self.scores[j])
                for j in range(pargs.num_tasks)
            ]
        ),
    )
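# module_util.get_subnet_fast / GetSubnetFast are assumed throughout this
# section but not shown. Below is a minimal sketch of the usual pattern:
# binarize real-valued scores at zero, with a straight-through estimator so
# gradients still reach the scores. The actual module_util implementation
# may differ; these *_sketch names are hypothetical stand-ins.
import torch.autograd as autograd


class GetSubnetFastSketch(autograd.Function):
    @staticmethod
    def forward(ctx, scores):
        # Hard threshold: 1 where the score is non-negative, else 0.
        return (scores >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass the gradient to the scores unchanged.
        return grad_output


def get_subnet_fast_sketch(scores):
    # Non-differentiable variant for use under torch.no_grad().
    return (scores >= 0).float()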
# Hopfield variant of cache_masks (used by FastHopMaskBN): build the weight
# matrix W that stores each learned task's ±1 mask as an attractor, via a
# Storkey-style learning rule.
def cache_masks(self):
    with torch.no_grad():
        d = self.d
        W = torch.zeros(d, d).to(pargs.device)
        for j in range(self.num_tasks_learned):
            # Map the {0, 1} mask to a {-1, +1} Hopfield pattern.
            x = 2 * module_util.get_subnet_fast(self.scores[j]) - 1
            # Hebbian outer-product term with a zeroed diagonal.
            heb = torch.ger(x, x) - torch.eye(d).to(pargs.device)
            # Local field induced by the patterns already stored in W.
            h = W.mm(x.unsqueeze(1)).squeeze()
            pre = torch.ger(x, h)
            # Storkey update: Hebbian term minus the local-field corrections,
            # which reduces crosstalk between stored patterns.
            W = W + (1.0 / d) * (heb - pre - pre.t())
            # Plain Hebbian alternative:
            # W = W + (1.0 / d) * heb
        self.register_buffer("W", W)
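# Standalone sketch of why the loop above works: write a few random ±1
# patterns into W with the same Storkey-style update, then check that each
# is a fixed point of one synchronous Hopfield step, sign(W x) == x. The
# dimension and pattern count are illustrative, not codebase values.
def _storkey_demo(d=256, num_patterns=5):
    torch.manual_seed(0)
    patterns = 2 * torch.randint(0, 2, (num_patterns, d)).float() - 1
    W = torch.zeros(d, d)
    for x in patterns:
        heb = torch.ger(x, x) - torch.eye(d)
        h = W.mv(x)
        pre = torch.ger(x, h)
        W = W + (1.0 / d) * (heb - pre - pre.t())
    for x in patterns:
        # Expected to print True at this low load (5 patterns, d = 256).
        print(bool(torch.equal(torch.sign(W.mv(x)), x)))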
def forward(self, x):
    if self.task < 0:
        # No task id given: superimpose the masks of all learned tasks,
        # weighted by the (learnable) alpha coefficients.
        stacked = torch.stack(
            [
                module_util.get_subnet_fast(self.scores[j])
                for j in range(min(pargs.num_tasks, self.num_tasks_learned))
            ]
        )
        alpha_weights = self.alphas[: self.num_tasks_learned]
        subnet = (alpha_weights * stacked).sum(dim=0)
    else:
        # Task id known: apply that task's binary mask directly.
        subnet = module_util.GetSubnetFast.apply(self.scores[self.task])
    w = self.weight * subnet
    x = F.conv2d(
        x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
    )
    return x
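# Toy illustration of the self.task < 0 branch above: the effective mask is
# the alpha-weighted sum of the per-task binary masks, and gradients w.r.t.
# the alphas can then be used to infer which task's mask best fits the data.
# All shapes and the stand-in objective below are invented for the example.
def _superposition_demo(num_tasks=3, out_c=4, in_c=2, k=3):
    torch.manual_seed(0)
    masks = torch.randint(0, 2, (num_tasks, out_c, in_c, k, k)).float()
    alphas = torch.full(
        (num_tasks, 1, 1, 1, 1), 1.0 / num_tasks, requires_grad=True
    )
    weight = torch.randn(out_c, in_c, k, k)
    subnet = (alphas * masks).sum(dim=0)  # superimposed mask
    w = weight * subnet                   # effective conv weight
    loss = (w ** 2).mean()                # stand-in for the entropy objective
    loss.backward()
    print(alphas.grad.view(num_tasks))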
def hopfield_recovery(
    model,
    writer,
    test_loader,
    num_tasks_learned,
    task,
):
    model.zero_grad()
    model.train()

    # stopping_time would track how many epochs were required to adapt;
    # it is initialized but not updated in this variant.
    stopping_time = 0
    correct = 0
    taskname = f"{args.set}_{task}"

    # Re-initialize each Hopfield mask layer's recoverable score as the
    # mean of the ±1 patterns of all tasks learned so far, and collect
    # these scores as the only parameters to optimize.
    params = []
    for n, m in model.named_modules():
        if isinstance(m, FastHopMaskBN):
            out = torch.stack(
                [
                    2 * module_util.get_subnet_fast(m.scores[j]) - 1
                    for j in range(m.num_tasks_learned)
                ]
            )
            m.score = torch.nn.Parameter(out.mean(dim=0))
            params.append(m.score)

    optimizer = optim.SGD(
        params,
        lr=500,
        momentum=args.momentum,
        weight_decay=args.wd,
    )

    for batch_idx, (data, target) in enumerate(test_loader):
        data, target = data.to(args.device), target.to(args.device)
        hop_loss = None

        # Log how far each layer's current binarized score is (in Hamming
        # distance) from the true task's mask. Use a separate name for the
        # mask so the classification targets above are not shadowed.
        for n, m in model.named_modules():
            if isinstance(m, FastHopMaskBN):
                s = 2 * module_util.GetSubnetFast.apply(m.score) - 1
                target_mask = 2 * module_util.get_subnet_fast(m.scores[task]) - 1
                distance = (s != target_mask).sum().item()
                writer.add_scalar(
                    f"adapt_{taskname}/distance_{n}",
                    distance,
                    batch_idx + 1,
                )

        optimizer.zero_grad()
        model.zero_grad()

        output = model(data)
        # Entropy of the output distribution: confident predictions under
        # the current superimposed mask give low entropy.
        logit_entropy = (
            -(output.softmax(dim=1) * output.log_softmax(dim=1)).sum(1).mean()
        )

        # Sum the Hopfield energy E(s) = -0.5 * s^T W s over all mask
        # layers; minimizing it drives s toward a stored task mask.
        for n, m in model.named_modules():
            if isinstance(m, FastHopMaskBN):
                s = 2 * module_util.GetSubnetFast.apply(m.score) - 1
                if hop_loss is None:
                    hop_loss = (
                        -0.5 * s.unsqueeze(0).mm(m.W.mm(s.unsqueeze(1))).squeeze()
                    )
                else:
                    hop_loss += (
                        -0.5 * s.unsqueeze(0).mm(m.W.mm(s.unsqueeze(1))).squeeze()
                    )

        # Anneal the two terms in opposite directions over the test stream:
        # the entropy term dominates early, the Hopfield energy late.
        hop_lr = args.gamma * (float(batch_idx + 1) / len(test_loader))
        hop_loss = hop_lr * hop_loss

        ent_lr = 1 - (float(batch_idx + 1) / len(test_loader))
        logit_entropy = logit_entropy * ent_lr

        (logit_entropy + hop_loss).backward()
        optimizer.step()

        writer.add_scalar(
            f"adapt_{taskname}/{num_tasks_learned}/entropy",
            logit_entropy.item(),
            batch_idx + 1,
        )
        writer.add_scalar(
            f"adapt_{taskname}/{num_tasks_learned}/hop_loss",
            hop_loss.item(),
            batch_idx + 1,
        )

    test_acc = adapt_test(
        model,
        test_loader,
        alphas=None,
    )

    model.apply(lambda m: setattr(m, "alphas", None))
    return test_acc
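# Condensed, self-contained sketch of what the adaptation loop above does to
# each layer: a relaxed score vector descends the Hopfield energy
# E(s) = -0.5 * s^T W s, which pulls it back to a stored ±1 pattern even
# from a corrupted start. Here W stores a single Hebbian pattern and tanh
# stands in for GetSubnetFast, so this illustrates the mechanism rather
# than reproducing hopfield_recovery verbatim.
def _hopfield_descent_demo(d=64, num_flips=8, steps=50):
    torch.manual_seed(0)
    x = 2 * torch.randint(0, 2, (d,)).float() - 1  # the stored pattern
    W = (torch.ger(x, x) - torch.eye(d)) / d       # Hebbian storage of x

    noisy = x.clone()
    noisy[torch.randperm(d)[:num_flips]] *= -1     # corrupt a few bits
    score = torch.nn.Parameter(0.5 * noisy)
    opt = optim.SGD([score], lr=1.0)

    for _ in range(steps):
        s = torch.tanh(score)            # smooth stand-in for the ±1 mask
        hop_loss = -0.5 * s @ (W @ s)    # Hopfield energy
        opt.zero_grad()
        hop_loss.backward()
        opt.step()

    # Fraction of bits matching the stored pattern; expected to print 1.0.
    print((torch.sign(score.detach()) == x).float().mean().item())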