Example 1
def cache_masks(self):
    # Precompute the binary mask for every task and store the stack as a
    # buffer, so it moves with the module and is saved in its state dict.
    self.register_buffer(
        "stacked",
        torch.stack([
            module_util.get_subnet_fast(self.scores[j])
            for j in range(pargs.num_tasks)
        ]),
    )
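These snippets rely on module_util.get_subnet_fast / GetSubnetFast to turn a real-valued score tensor into a binary mask. That helper is not shown on this page; the sketch below is only an assumption of how it might look (threshold at zero in the forward pass, straight-through gradient in the backward pass), which is consistent with how the masks are used in the later examples (e.g. 2 * mask - 1 producing ±1 patterns).

import torch
import torch.autograd as autograd

class GetSubnetFast(autograd.Function):
    @staticmethod
    def forward(ctx, scores):
        # Hypothetical: binary mask, 1 wherever the score is non-negative.
        return (scores >= 0).float()

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: gradients flow to the scores unchanged.
        return g

def get_subnet_fast(scores):
    # Non-differentiable variant for when no gradient is needed.
    return (scores >= 0).float()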
Example 2
def cache_masks(self):
    # Store every learned task mask in a Hopfield weight matrix W, so the
    # masks can later be recovered as attractors of the network.
    with torch.no_grad():
        d = self.d
        W = torch.zeros(d, d).to(pargs.device)
        for j in range(self.num_tasks_learned):
            # Map the binary mask {0, 1} to a bipolar pattern {-1, +1}.
            x = 2 * module_util.get_subnet_fast(self.scores[j]) - 1
            # Hebbian outer product with the diagonal removed.
            heb = torch.ger(x, x) - torch.eye(d).to(pargs.device)
            # Field induced on x by the patterns already stored in W.
            h = W.mm(x.unsqueeze(1)).squeeze()
            # Storkey-style cross-talk correction term.
            pre = torch.ger(x, h)
            W = W + (1.0 / d) * (heb - pre - pre.t())
            # Plain Hebbian alternative:
            # W = W + (1.0 / d) * heb

        self.register_buffer("W", W)
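As a quick, self-contained sanity check of the update rule above (all sizes and pattern counts here are made up for illustration), the script below stores a few random ±1 patterns with the same rule and prints how well each stored pattern survives one step of s -> sign(W s).

import torch

d, num_patterns = 64, 3
patterns = torch.randint(0, 2, (num_patterns, d)).float() * 2 - 1

W = torch.zeros(d, d)
for x in patterns:
    heb = torch.ger(x, x) - torch.eye(d)   # Hebbian term, zero diagonal
    h = W.mv(x)                            # field from already-stored patterns
    pre = torch.ger(x, h)                  # cross-talk correction
    W = W + (1.0 / d) * (heb - pre - pre.t())

for x in patterns:
    recovered = torch.sign(W.mv(x))
    # Agreement should be close to 1.0 when only a few patterns are stored.
    print((recovered == x).float().mean().item())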
Example 3
def forward(self, x):
    if self.task < 0:
        # No task id given: blend the learned masks, weighting each task's
        # binary mask by its alpha coefficient.
        stacked = torch.stack([
            module_util.get_subnet_fast(self.scores[j])
            for j in range(min(pargs.num_tasks, self.num_tasks_learned))
        ])
        alpha_weights = self.alphas[:self.num_tasks_learned]
        subnet = (alpha_weights * stacked).sum(dim=0)
    else:
        # Known task: use its mask directly (with straight-through gradients).
        subnet = module_util.GetSubnetFast.apply(self.scores[self.task])
    # Apply the mask to the frozen weights and run a standard convolution.
    w = self.weight * subnet
    return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation,
                    self.groups)
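The task < 0 branch assumes some outside routine has already placed alphas on the module. The sketch below is one hedged way such a routine could look (all names are placeholders, not code from this repository): set uniform alphas, take a single entropy gradient, and pick the task whose alpha would most reduce the entropy.

import torch

def infer_task(model, data, num_tasks_learned):
    # One alpha per learned task, broadcastable over the conv weight dims.
    alphas = torch.full((num_tasks_learned, 1, 1, 1, 1),
                        1.0 / num_tasks_learned, requires_grad=True)
    model.apply(lambda m: setattr(m, "alphas", alphas))
    model.apply(lambda m: setattr(m, "task", -1))

    output = model(data)
    entropy = -(output.softmax(dim=1) * output.log_softmax(dim=1)).sum(1).mean()
    entropy.backward()

    # The alpha with the most negative entropy gradient marks the likely task.
    return alphas.grad.view(-1).argmin().item()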
Example 4
def hopfield_recovery(
    model, writer, test_loader, num_tasks_learned, task,
):
    # Recover the mask of an unidentified task by descending a mix of output
    # entropy and Hopfield energy over each layer's adaptation score.
    model.zero_grad()
    model.train()
    # stopping_time / correct: bookkeeping counters (not updated in this excerpt).
    stopping_time = 0
    correct = 0
    taskname = f"{args.set}_{task}"

    # Initialize each layer's adaptation score from the stored masks and
    # collect these scores as the only trainable parameters.
    params = []
    for n, m in model.named_modules():
        if isinstance(m, FastHopMaskBN):
            # Bipolar (+/-1) versions of all masks learned so far.
            out = torch.stack(
                [
                    2 * module_util.get_subnet_fast(m.scores[j]) - 1
                    for j in range(m.num_tasks_learned)
                ]
            )

            # Start the recovery from the average of the stored patterns.
            m.score = torch.nn.Parameter(out.mean(dim=0))
            params.append(m.score)

    optimizer = optim.SGD(
        params, lr=500, momentum=args.momentum, weight_decay=args.wd,
    )

    for batch_idx, (data_, target) in enumerate(test_loader):
        data, target = data_.to(args.device), target.to(args.device)
        hop_loss = None

        # Log, per layer, how far the current binarized score is from the
        # ground-truth task's mask (monitoring only; not part of the loss).
        for n, m in model.named_modules():
            if isinstance(m, FastHopMaskBN):
                s = 2 * module_util.GetSubnetFast.apply(m.score) - 1
                mask_target = 2 * module_util.get_subnet_fast(m.scores[task]) - 1
                distance = (s != mask_target).sum().item()
                writer.add_scalar(
                    f"adapt_{taskname}/distance_{n}",
                    distance,
                    batch_idx + 1,
                )
        optimizer.zero_grad()
        model.zero_grad()
        output = model(data)
        # Entropy of the predictions; minimizing it pushes the output toward a
        # confident (low-entropy) distribution.
        logit_entropy = (
            -(output.softmax(dim=1) * output.log_softmax(dim=1)).sum(1).mean()
        )
        # Hopfield energy -0.5 * s^T W s summed over layers; it is lowest when
        # the binarized score s coincides with one of the stored masks.
        for n, m in model.named_modules():
            if isinstance(m, FastHopMaskBN):
                s = 2 * module_util.GetSubnetFast.apply(m.score) - 1
                if hop_loss is None:
                    hop_loss = (
                        -0.5 * s.unsqueeze(0).mm(m.W.mm(s.unsqueeze(1))).squeeze()
                    )
                else:
                    hop_loss += (
                        -0.5 * s.unsqueeze(0).mm(m.W.mm(s.unsqueeze(1))).squeeze()
                    )

        # Anneal the two terms over the pass: the Hopfield term ramps up while
        # the entropy term ramps down.
        hop_lr = args.gamma * (float(batch_idx + 1) / len(test_loader))
        hop_loss = hop_lr * hop_loss
        ent_lr = 1 - (float(batch_idx + 1) / len(test_loader))
        logit_entropy = logit_entropy * ent_lr
        (logit_entropy + hop_loss).backward()
        optimizer.step()

        writer.add_scalar(
            f"adapt_{taskname}/{num_tasks_learned}/entropy",
            logit_entropy.item(),
            batch_idx + 1,
        )

        writer.add_scalar(
            f"adapt_{taskname}/{num_tasks_learned}/hop_loss",
            hop_loss.item(),
            batch_idx + 1,
        )

    test_acc = adapt_test(
        model,
        test_loader,
        alphas=None,
    )

    model.apply(lambda m: setattr(m, "alphas", None))
    return test_acc
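To make the role of the Hopfield energy term concrete, the toy script below (sizes, learning rate, and the tanh relaxation are illustrative choices, not taken from the function above) stores a single ±1 pattern Hebbian-style, corrupts a quarter of its bits, and descends -0.5 * s^T W s until the pattern is recovered.

import torch
import torch.optim as optim

d = 64
x = torch.randint(0, 2, (d,)).float() * 2 - 1     # stored +/-1 pattern
W = (torch.ger(x, x) - torch.eye(d)) / d          # Hebbian storage of one pattern

score = x.clone()
score[torch.randperm(d)[: d // 4]] *= -1          # flip 25% of the bits
score = torch.nn.Parameter(score)
opt = optim.SGD([score], lr=1.0)

for _ in range(100):
    s = torch.tanh(score)                         # smooth stand-in for the
                                                  # straight-through sign
    energy = -0.5 * s @ (W @ s)
    opt.zero_grad()
    energy.backward()
    opt.step()

# Fraction of bits recovered; should be close to 1.0.
print((torch.sign(score.detach()) == x).float().mean().item())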