Example #1
# Imports assume the mammoth-style continual-learning package layout this
# snippet appears to come from.
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F

from datasets import get_dataset
from models.utils.continual_model import ContinualModel
from utils.buffer import Buffer
class ICarl(ContinualModel):
    NAME = 'icarl'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(ICarl, self).__init__(backbone, loss, args, transform)
        self.dataset = get_dataset(args)

        # Instantiate buffers
        self.buffer = Buffer(self.args.buffer_size, self.device)

        if isinstance(self.dataset.N_CLASSES_PER_TASK, list):
            nc = int(np.sum(self.dataset.N_CLASSES_PER_TASK))
        else:
            nc = self.dataset.N_CLASSES_PER_TASK * self.dataset.N_TASKS
        self.eye = torch.eye(nc).to(self.device)

        self.class_means = None
        self.old_net = None
        self.current_task = 0

    def forward(self, x):
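        # Nearest-mean-of-exemplars prediction: score each class by the
        # (negated) squared distance between its cached feature mean and the
        # features of x. Means are recomputed lazily whenever invalidated.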
        if self.class_means is None:
            with torch.no_grad():
                self.compute_class_means()

        feats = self.net.features(x)
        feats = feats.unsqueeze(1)

        pred = (self.class_means.unsqueeze(0) - feats).pow(2).sum(2)
        return -pred

    def observe(self, inputs, labels, not_aug_inputs, logits=None):
        labels = labels.long()
        if not hasattr(self, 'classes_so_far'):
            self.register_buffer('classes_so_far', labels.unique().to('cpu'))
        else:
            self.register_buffer(
                'classes_so_far',
                torch.cat((self.classes_so_far, labels.to('cpu'))).unique())

        self.class_means = None
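        # From the second task on, the frozen snapshot of the old network
        # provides sigmoid targets for distillation on past classes.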
        if self.current_task > 0:
            with torch.no_grad():
                logits = torch.sigmoid(self.old_net(inputs))
        self.opt.zero_grad()
        loss = self.get_loss(inputs, labels, self.current_task, logits)
        loss.backward()

        self.opt.step()

        return loss.item()

    @staticmethod
    def binary_cross_entropy(pred, y):
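        # Manual BCE over probabilities, kept as a reference implementation;
        # get_loss below uses the numerically safer
        # F.binary_cross_entropy_with_logits instead.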
        return -(pred.log() * y + (1 - y) * (1 - pred).log()).mean()

    def get_loss(self, inputs: torch.Tensor, labels: torch.Tensor,
                 task_idx: int, logits: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss tensor.
        :param inputs: the images to be fed to the network
        :param labels: the ground-truth labels
        :param task_idx: the task index
        :return: the differentiable loss value
        """
        if isinstance(self.dataset.N_CLASSES_PER_TASK, list):
            pc = int(np.sum(self.dataset.N_CLASSES_PER_TASK[:task_idx]))
            ac = int(np.sum(self.dataset.N_CLASSES_PER_TASK[:task_idx + 1]))
        else:
            pc = task_idx * self.dataset.N_CLASSES_PER_TASK
            ac = (task_idx + 1) * self.dataset.N_CLASSES_PER_TASK

        outputs = self.net(inputs)[:, :ac]
        if task_idx == 0:
            # Compute loss on the current task
            targets = self.eye[labels][:, :ac]
            loss = F.binary_cross_entropy_with_logits(outputs, targets)
            assert loss >= 0
        else:
            targets = self.eye[labels][:, pc:ac]
            comb_targets = torch.cat((logits[:, :pc], targets), dim=1)
            loss = F.binary_cross_entropy_with_logits(outputs, comb_targets)
            assert loss >= 0

        if self.args.wd_reg:
            loss += self.args.wd_reg * torch.sum(self.net.get_params()**2)

        return loss

    def begin_task(self, dataset):
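        # Rehearsal: splice the buffered exemplars back into the incoming
        # task's training set, undoing normalization so they match the raw
        # data stored by the dataset.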
        denorm = dataset.get_denormalization_transform()
        if denorm is None:
            denorm = (lambda x: x)
        if self.current_task > 0:
            if self.args.dataset != 'seq-core50':
                dataset.train_loader.dataset.targets = np.concatenate([
                    dataset.train_loader.dataset.targets,
                    self.buffer.labels.cpu().numpy()
                    [:self.buffer.num_seen_examples]
                ])
                if isinstance(dataset.train_loader.dataset.data, torch.Tensor):
                    dataset.train_loader.dataset.data = torch.cat([
                        dataset.train_loader.dataset.data,
                        torch.stack([
                            denorm(self.buffer.examples[i].type(
                                torch.uint8).cpu())
                            for i in range(self.buffer.num_seen_examples)
                        ]).squeeze(1)
                    ])
                else:
                    dataset.train_loader.dataset.data = np.concatenate([
                        dataset.train_loader.dataset.data,
                        torch.stack([
                            (denorm(self.buffer.examples[i] * 255).type(
                                torch.uint8).cpu())
                            for i in range(self.buffer.num_seen_examples)
                        ]).numpy().swapaxes(1, 3)
                    ])
            else:
                # seq-core50 datasets expose add_more_data for appending the
                # exemplars directly.
                dataset.train_loader.dataset.add_more_data(
                    more_targets=self.buffer.labels.cpu().numpy()
                    [:self.buffer.num_seen_examples],
                    more_data=torch.stack([
                        (denorm(self.buffer.examples[i]).cpu())
                        for i in range(self.buffer.num_seen_examples)
                    ]).numpy().swapaxes(1, 3))

    def end_task(self, dataset) -> None:
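        # Freeze a copy of the network as the distillation teacher, then
        # rebuild the memory buffer with herding-selected exemplars.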
        self.old_net = deepcopy(self.net.eval())
        self.net.train()
        with torch.no_grad():
            self.fill_buffer(self.buffer, dataset, self.current_task)
        self.current_task += 1
        self.class_means = None

    def compute_class_means(self) -> None:
        """
        Computes a vector representing mean features for each class.
        """
        # This function caches class means
        transform = self.dataset.get_normalization_transform()
        class_means = []
        examples, labels, _ = self.buffer.get_all_data(transform)
        for _y in self.classes_so_far:
            x_buf = torch.stack([
                examples[i] for i in range(0, len(examples))
                if labels[i].cpu() == _y
            ]).to(self.device)

            class_means.append(self.net.features(x_buf).mean(0))
        self.class_means = torch.stack(class_means)

    def fill_buffer(self, mem_buffer: Buffer, dataset, t_idx: int) -> None:
        """
        Adds examples from the current task to the memory buffer
        by means of the herding strategy.
        :param mem_buffer: the memory buffer
        :param dataset: the dataset from which take the examples
        :param t_idx: the task index
        """

        mode = self.net.training
        self.net.eval()
        samples_per_class = mem_buffer.buffer_size // len(self.classes_so_far)

        if t_idx > 0:
            # 1) First, subsample prior classes
            buf_x, buf_y, buf_l = self.buffer.get_all_data()

            mem_buffer.empty()
            for _y in buf_y.unique():
                idx = (buf_y == _y)
                _y_x, _y_y, _y_l = buf_x[idx], buf_y[idx], buf_l[idx]
                mem_buffer.add_data(examples=_y_x[:samples_per_class],
                                    labels=_y_y[:samples_per_class],
                                    logits=_y_l[:samples_per_class])

        # 2) Then, fill the remaining slots with the current task
        loader = dataset.not_aug_dataloader(self.args, self.args.batch_size)

        # 2.1 Extract all features
        a_x, a_y, a_f, a_l = [], [], [], []
        for x, y, not_norm_x in loader:
            x, y, not_norm_x = (a.to(self.device) for a in [x, y, not_norm_x])
            a_x.append(not_norm_x.to('cpu'))
            a_y.append(y.to('cpu'))

            feats = self.net.features(x)
            a_f.append(feats.cpu())
            a_l.append(torch.sigmoid(self.net.classifier(feats)).cpu())
        a_x, a_y, a_f, a_l = torch.cat(a_x), torch.cat(a_y), torch.cat(
            a_f), torch.cat(a_l)

        # 2.2 For each new class, compute its feature mean and herd exemplars
        for _y in a_y.unique():
            idx = (a_y == _y)
            _y_x, _y_y, _y_l = a_x[idx], a_y[idx], a_l[idx]
            feats = a_f[idx]
            mean_feat = feats.mean(0, keepdim=True)

            running_sum = torch.zeros_like(mean_feat)
            i = 0
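            # Herding: greedily add the example that keeps the mean of the
            # selected features closest to the true class mean.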
            while i < samples_per_class and i < feats.shape[0]:
                cost = (mean_feat - (feats + running_sum) / (i + 1)).norm(2, 1)

                idx_min = cost.argmin().item()

                mem_buffer.add_data(
                    examples=_y_x[idx_min:idx_min + 1].to(self.device),
                    labels=_y_y[idx_min:idx_min + 1].to(self.device),
                    logits=_y_l[idx_min:idx_min + 1].to(self.device))

                running_sum += feats[idx_min:idx_min + 1]
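                # Push the chosen feature far away so it cannot win again.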
                feats[idx_min] = feats[idx_min] + 1e6
                i += 1

        assert len(mem_buffer.examples) <= mem_buffer.buffer_size

        self.net.train(mode)
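
A minimal sketch of how a model like this is typically driven, assuming
mammoth-style backbone, args and transform objects; get_data_loaders and
N_TASKS follow that codebase's conventions and are assumptions here, not part
of the snippet above:

model = ICarl(backbone, F.binary_cross_entropy_with_logits, args, transform)
dataset = get_dataset(args)
for _ in range(dataset.N_TASKS):
    train_loader, _ = dataset.get_data_loaders()  # assumed mammoth API
    model.begin_task(dataset)
    for inputs, labels, not_aug_inputs in train_loader:
        model.observe(inputs.to(model.device), labels.to(model.device),
                      not_aug_inputs.to(model.device))
    model.end_task(dataset)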
Example #2
# Same assumed mammoth-style layout as in Example #1.
import torch

from models.utils.continual_model import ContinualModel
from utils.buffer import Buffer
class Fdr(ContinualModel):
    NAME = 'fdr'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Fdr, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.current_task = 0
        self.i = 0
        self.soft = torch.nn.Softmax(dim=1)
        self.logsoft = torch.nn.LogSoftmax(dim=1)

    def end_task(self, dataset):
        self.current_task += 1
        examples_per_task = self.args.buffer_size // self.current_task

        if self.current_task > 1:
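            # Rebalance: shrink each stored task's share so that all seen
            # tasks keep an equal number of buffer slots.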
            buf_x, buf_log, buf_tl = self.buffer.get_all_data()
            self.buffer.empty()

            for ttl in buf_tl.unique():
                idx = (buf_tl == ttl)
                ex, log, tasklab = buf_x[idx], buf_log[idx], buf_tl[idx]
                first = min(ex.shape[0], examples_per_task)
                self.buffer.add_data(examples=ex[:first],
                                     logits=log[:first],
                                     task_labels=tasklab[:first])
        counter = 0
        with torch.no_grad():
            for i, data in enumerate(dataset.train_loader):
                inputs, labels, not_aug_inputs = data
                inputs = inputs.to(self.device)
                not_aug_inputs = not_aug_inputs.to(self.device)
                if examples_per_task - counter <= 0:
                    break
                outputs = self.net(inputs)
                # Use the actual batch length: the final batch may be smaller
                # than args.batch_size.
                self.buffer.add_data(
                    examples=not_aug_inputs[:(examples_per_task - counter)],
                    logits=outputs.data[:(examples_per_task - counter)],
                    task_labels=(torch.ones(not_aug_inputs.shape[0]) *
                                 (self.current_task - 1))[:(examples_per_task -
                                                            counter)])
                counter += not_aug_inputs.shape[0]

    def observe(self, inputs, labels, not_aug_inputs):
        self.i += 1

        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()
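        # Function-distance regularization: after the plain SGD step, pull the
        # network's outputs on buffered examples back toward the logits stored
        # for them, compared in softmax space.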
        if not self.buffer.is_empty():
            self.opt.zero_grad()
            buf_inputs, buf_logits, _ = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            buf_outputs = self.net(buf_inputs)
            loss = torch.norm(
                self.soft(buf_outputs) - self.soft(buf_logits), 2, 1).mean()
            assert not torch.isnan(loss)
            loss.backward()
            self.opt.step()

        return loss.item()
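
For intuition, the replay term above is just an L2 distance between softmax
outputs; a self-contained sketch with dummy tensors (shapes are arbitrary):

import torch

soft = torch.nn.Softmax(dim=1)
buf_outputs = torch.randn(4, 10)  # current network outputs on buffered inputs
buf_logits = torch.randn(4, 10)   # logits stored when the examples were buffered
reg = torch.norm(soft(buf_outputs) - soft(buf_logits), 2, 1).mean()
print(reg)  # scalar penalty applied after the plain cross-entropy step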