Example #1
class Er(ContinualModel):
    NAME = 'er'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Er, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def observe(self, inputs, labels, not_aug_inputs):

        real_batch_size = inputs.shape[0]

        self.opt.zero_grad()
        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:real_batch_size])

        return loss.item()
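
The Buffer object used in Example #1 (and in most of the examples below) is not shown on this page. As a rough orientation, a minimal reservoir-sampling buffer exposing the interface these snippets call (is_empty, add_data, get_data) could look like the sketch below; the actual class behind these examples also stores logits and task labels, so treat the names and details here as assumptions rather than the real implementation.

import numpy as np
import torch


class Buffer:
    """Hypothetical reservoir-sampling replay buffer (sketch only)."""

    def __init__(self, buffer_size, device):
        self.buffer_size = buffer_size
        self.device = device
        self.num_seen_examples = 0
        self.examples = None
        self.labels = None

    def is_empty(self):
        return self.num_seen_examples == 0

    def add_data(self, examples, labels=None):
        if self.examples is None:
            # lazily allocate storage with the shape of the first batch
            self.examples = torch.zeros(self.buffer_size, *examples.shape[1:],
                                        dtype=examples.dtype, device=self.device)
            self.labels = torch.zeros(self.buffer_size, dtype=torch.long,
                                      device=self.device)
        for i in range(examples.shape[0]):
            if self.num_seen_examples < self.buffer_size:
                index = self.num_seen_examples
            else:
                # reservoir sampling: every example seen so far is kept with equal probability
                index = np.random.randint(0, self.num_seen_examples + 1)
            if index < self.buffer_size:
                self.examples[index] = examples[i].to(self.device)
                if labels is not None:
                    self.labels[index] = labels[i].to(self.device)
            self.num_seen_examples += 1

    def get_data(self, size, transform=None):
        n_stored = min(self.num_seen_examples, self.buffer_size)
        choice = np.random.choice(n_stored, size=min(size, n_stored), replace=False)
        examples = self.examples[choice]
        if transform is not None:
            # the real buffer re-applies the train-time augmentation here
            examples = torch.stack([transform(x.cpu()) for x in examples]).to(self.device)
        return examples, self.labels[choice]
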
Example #2
class Der(ContinualModel):
    NAME = 'der'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Der, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def observe(self, inputs, labels, not_aug_inputs):

        self.opt.zero_grad()

        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)

        if not self.buffer.is_empty():
            buf_inputs, buf_logits = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            buf_outputs = self.net(buf_inputs)
            loss += self.args.alpha * F.mse_loss(buf_outputs, buf_logits)

        loss.backward()
        self.opt.step()
        self.buffer.add_data(examples=not_aug_inputs, logits=outputs.data)

        return loss.item()
Example #3
class Mer(ContinualModel):
    NAME = 'mer'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Mer, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def draw_batches(self, inp, lab):
        batches = []
        for i in range(self.args.batch_num):
            if not self.buffer.is_empty():
                buf_inputs, buf_labels = self.buffer.get_data(
                    self.args.minibatch_size, transform=self.transform)
                inputs = torch.cat((buf_inputs, inp.unsqueeze(0)))
                labels = torch.cat(
                    (buf_labels, torch.tensor([lab]).to(self.device)))
                batches.append((inputs, labels))
            else:
                batches.append(
                    (inp.unsqueeze(0),
                     torch.tensor([lab]).unsqueeze(0).to(self.device)))
        return batches

    def observe(self, inputs, labels, not_aug_inputs):

        batches = self.draw_batches(inputs, labels)
        theta_A0 = self.net.get_params().data.clone()

        for i in range(self.args.batch_num):
            theta_Wi0 = self.net.get_params().data.clone()

            batch_inputs, batch_labels = batches[i]

            # within-batch step
            self.opt.zero_grad()
            outputs = self.net(batch_inputs)
            loss = self.loss(outputs, batch_labels.squeeze(-1))
            loss.backward()
            self.opt.step()

            # within batch reptile meta-update
            new_params = theta_Wi0 + self.args.beta * (self.net.get_params() -
                                                       theta_Wi0)
            self.net.set_params(new_params)

        self.buffer.add_data(examples=not_aug_inputs.unsqueeze(0),
                             labels=labels)

        # across batch reptile meta-update
        new_new_params = theta_A0 + self.args.gamma * (self.net.get_params() -
                                                       theta_A0)
        self.net.set_params(new_new_params)

        return loss.item()
Example #4
class AGem(ContinualModel):
    NAME = 'agem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(AGem, self).__init__(backbone, loss, args, transform)

        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.grad_dims = []
        for param in self.parameters():
            self.grad_dims.append(param.data.numel())
        self.grad_xy = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        self.grad_er = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        self.transform = transform if self.args.iba else None

    def end_task(self, dataset):
        samples_per_task = self.args.buffer_size // dataset.N_TASKS
        loader = dataset.not_aug_dataloader(self.args, samples_per_task)
        cur_x, cur_y = next(iter(loader))[:2]
        self.buffer.add_data(examples=cur_x.to(self.device),
                             labels=cur_y.to(self.device))

    def observe(self, inputs, labels, not_aug_inputs):

        self.zero_grad()
        p = self.net.forward(inputs)
        loss = self.loss(p, labels)
        loss.backward()

        if not self.buffer.is_empty():
            store_grad(self.parameters, self.grad_xy, self.grad_dims)

            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            self.net.zero_grad()
            buf_outputs = self.net.forward(buf_inputs)
            penalty = self.loss(buf_outputs, buf_labels)
            penalty.backward()
            store_grad(self.parameters, self.grad_er, self.grad_dims)

            dot_prod = torch.dot(self.grad_xy, self.grad_er)
            if dot_prod.item() < 0:
                g_tilde = project(gxy=self.grad_xy, ger=self.grad_er)
                overwrite_grad(self.parameters, g_tilde, self.grad_dims)
            else:
                overwrite_grad(self.parameters, self.grad_xy, self.grad_dims)

        self.opt.step()

        return loss.item()
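
Examples #4, #5 and #6 rely on the helpers store_grad, overwrite_grad and project, which are not reproduced on this page. The sketch below is consistent with how they are called above, but it is an assumption, not the original code: store_grad flattens the current gradients into a pre-allocated vector, overwrite_grad writes such a vector back into the parameters' .grad fields, and project implements the A-GEM correction g_tilde = g - (g . g_ref / g_ref . g_ref) * g_ref.

import torch


def store_grad(params, grads, grad_dims):
    """Flatten the .grad of every parameter returned by params() into `grads`."""
    grads.fill_(0.0)
    count = 0
    for param in params():
        if param.grad is not None:
            begin = 0 if count == 0 else sum(grad_dims[:count])
            end = sum(grad_dims[:count + 1])
            grads[begin:end].copy_(param.grad.data.view(-1))
        count += 1


def overwrite_grad(params, new_grad, grad_dims):
    """Write the flat vector `new_grad` back into each parameter's .grad field."""
    count = 0
    for param in params():
        if param.grad is not None:
            begin = 0 if count == 0 else sum(grad_dims[:count])
            end = sum(grad_dims[:count + 1])
            param.grad.data.copy_(new_grad[begin:end].view_as(param.grad.data))
        count += 1


def project(gxy: torch.Tensor, ger: torch.Tensor) -> torch.Tensor:
    """A-GEM projection: remove from gxy its component along the replay gradient ger."""
    corr = torch.dot(gxy, ger) / torch.dot(ger, ger)
    return gxy - corr * ger
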
Example #5
class AGemr(ContinualModel):
    NAME = 'agem_r'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(AGemr, self).__init__(backbone, loss, args, transform)

        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.grad_dims = []
        for param in self.parameters():
            self.grad_dims.append(param.data.numel())
        self.grad_xy = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        self.grad_er = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        self.current_task = 0

    def observe(self, inputs, labels, not_aug_inputs):
        self.zero_grad()
        p = self.net.forward(inputs)
        loss = self.loss(p, labels)
        loss.backward()

        if not self.buffer.is_empty():
            store_grad(self.parameters, self.grad_xy, self.grad_dims)

            buf_inputs, buf_labels = self.buffer.get_data(self.args.minibatch_size)
            self.net.zero_grad()
            buf_outputs = self.net.forward(buf_inputs)
            penalty = self.loss(buf_outputs, buf_labels)
            penalty.backward()
            store_grad(self.parameters, self.grad_er, self.grad_dims)

            dot_prod = torch.dot(self.grad_xy, self.grad_er)
            if dot_prod.item() < 0:
                g_tilde = project(gxy=self.grad_xy, ger=self.grad_er)
                overwrite_grad(self.parameters, g_tilde, self.grad_dims)
            else:
                overwrite_grad(self.parameters, self.grad_xy, self.grad_dims)

        self.opt.step()

        self.buffer.add_data(examples=not_aug_inputs, labels=labels)

        return loss.item()
Example #6
class Gem(ContinualModel):
    NAME = 'gem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(Gem, self).__init__(backbone, loss, args, transform)
        self.current_task = 0
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.transform = transform

        # Allocate temporary synaptic memory
        self.grad_dims = []
        for pp in self.parameters():
            self.grad_dims.append(pp.data.numel())

        self.grads_cs = []
        self.grads_da = torch.zeros(np.sum(self.grad_dims)).to(self.device)
        self.transform = transform if self.args.iba else None

    def end_task(self, dataset):
        self.current_task += 1
        self.grads_cs.append(
            torch.zeros(np.sum(self.grad_dims)).to(self.device))

        # add data to the buffer
        samples_per_task = self.args.buffer_size // dataset.N_TASKS

        loader = dataset.not_aug_dataloader(self.args, samples_per_task)
        cur_x, cur_y = next(iter(loader))[:2]
        self.buffer.add_data(
            examples=cur_x.to(self.device),
            labels=cur_y.to(self.device),
            task_labels=torch.ones(samples_per_task, dtype=torch.long).to(
                self.device) * (self.current_task - 1))

    def observe(self, inputs, labels, not_aug_inputs):

        if not self.buffer.is_empty():
            buf_inputs, buf_labels, buf_task_labels = self.buffer.get_data(
                self.args.buffer_size, transform=self.transform)

            for tt in buf_task_labels.unique():
                # compute gradient on the memory buffer
                self.opt.zero_grad()
                cur_task_inputs = buf_inputs[buf_task_labels == tt]
                cur_task_labels = buf_labels[buf_task_labels == tt]

                for i in range(
                        math.ceil(len(cur_task_inputs) /
                                  self.args.batch_size)):
                    cur_task_outputs = self.forward(
                        cur_task_inputs[i * self.args.batch_size:(i + 1) *
                                        self.args.batch_size])
                    penalty = self.loss(
                        cur_task_outputs,
                        cur_task_labels[i * self.args.batch_size:(i + 1) *
                                        self.args.batch_size],
                        reduction='sum') / cur_task_inputs.shape[0]
                    penalty.backward()
                store_grad(self.parameters, self.grads_cs[tt], self.grad_dims)

                # cur_task_outputs = self.forward(cur_task_inputs)
                # penalty = self.loss(cur_task_outputs, cur_task_labels)
                # penalty.backward()
                # store_grad(self.parameters, self.grads_cs[tt], self.grad_dims)

        # now compute the grad on the current data
        self.opt.zero_grad()
        outputs = self.forward(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()

        # check if gradient violates buffer constraints
        if not self.buffer.is_empty():
            # copy gradient
            store_grad(self.parameters, self.grads_da, self.grad_dims)

            dot_prod = torch.mm(self.grads_da.unsqueeze(0),
                                torch.stack(self.grads_cs).T)
            if (dot_prod < 0).sum() != 0:
                project2cone2(self.grads_da.unsqueeze(1),
                              torch.stack(self.grads_cs).T,
                              margin=self.args.gamma)
                # copy gradients back
                overwrite_grad(self.parameters, self.grads_da, self.grad_dims)

        self.opt.step()

        return loss.item()
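
Example #6 additionally calls project2cone2, the quadratic-program projection from GEM that pushes the current gradient inside the cone of past-task gradients. A hedged sketch, close to the commonly used reference implementation and requiring the quadprog package, is shown below; it modifies the flat gradient in place so that overwrite_grad can copy it back into the network.

import numpy as np
import quadprog
import torch


def project2cone2(gradient, memories, margin=0.5, eps=1e-3):
    """Solve the GEM dual QP and overwrite `gradient` with the projected gradient."""
    memories_np = memories.cpu().t().double().numpy()
    gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()
    n_tasks = memories_np.shape[0]
    P = np.dot(memories_np, memories_np.transpose())
    P = 0.5 * (P + P.transpose()) + np.eye(n_tasks) * eps  # symmetrize and regularize
    q = -np.dot(memories_np, gradient_np)
    G = np.eye(n_tasks)
    h = np.zeros(n_tasks) + margin
    v = quadprog.solve_qp(P, q, G, h)[0]          # dual variables, one per past task
    x = np.dot(v, memories_np) + gradient_np      # projected gradient
    gradient.copy_(torch.from_numpy(x).float().view(-1, 1))
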
Example #7
class HAL(ContinualModel):
    NAME = 'hal'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(HAL, self).__init__(backbone, loss, args, transform)
        self.task_number = 0
        self.buffer = Buffer(self.args.buffer_size,
                             self.device,
                             get_dataset(args).N_TASKS,
                             mode='ring')
        self.hal_lambda = args.hal_lambda
        self.beta = args.beta
        self.gamma = args.gamma
        self.anchor_optimization_steps = 100
        self.finetuning_epochs = 1
        self.dataset = get_dataset(args)
        self.spare_model = self.dataset.get_backbone()
        self.spare_model.to(self.device)
        self.spare_opt = SGD(self.spare_model.parameters(), lr=self.args.lr)

    def end_task(self, dataset):
        self.task_number += 1
        # ring buffer mgmt (if we are not loading the model)
        if self.task_number > self.buffer.task_number:
            self.buffer.num_seen_examples = 0
            self.buffer.task_number = self.task_number
        # get anchors (provided that we are not loading the model)
        if len(self.anchors) < self.task_number * dataset.N_CLASSES_PER_TASK:
            self.get_anchors(dataset)
            del self.phi

    def get_anchors(self, dataset):
        theta_t = self.net.get_params().detach().clone()
        self.spare_model.set_params(theta_t)

        # fine tune on memory buffer
        for _ in range(self.finetuning_epochs):
            inputs, labels = self.buffer.get_data(self.args.batch_size,
                                                  transform=self.transform)
            self.spare_opt.zero_grad()
            out = self.spare_model(inputs)
            loss = self.loss(out, labels)
            loss.backward()
            self.spare_opt.step()

        theta_m = self.spare_model.get_params().detach().clone()

        classes_for_this_task = np.unique(dataset.train_loader.dataset.targets)

        for a_class in classes_for_this_task:
            e_t = torch.rand(self.input_shape,
                             requires_grad=True,
                             device=self.device)
            e_t_opt = SGD([e_t], lr=self.args.lr)
            print(file=sys.stderr)
            for i in range(self.anchor_optimization_steps):
                e_t_opt.zero_grad()
                cum_loss = 0

                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_m.detach().clone())
                loss = -torch.sum(
                    self.loss(self.spare_model(e_t.unsqueeze(0)),
                              torch.tensor([a_class]).to(self.device)))
                loss.backward()
                cum_loss += loss.item()

                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_t.detach().clone())
                loss = torch.sum(
                    self.loss(self.spare_model(e_t.unsqueeze(0)),
                              torch.tensor([a_class]).to(self.device)))
                loss.backward()
                cum_loss += loss.item()

                self.spare_opt.zero_grad()
                loss = torch.sum(self.gamma *
                                 (self.spare_model.features(e_t.unsqueeze(0)) -
                                  self.phi)**2)
                assert not self.phi.requires_grad
                loss.backward()
                cum_loss += loss.item()

                e_t_opt.step()

            e_t = e_t.detach()
            e_t.requires_grad = False
            self.anchors = torch.cat((self.anchors, e_t.unsqueeze(0)))
            del e_t
            print('Total anchors:', len(self.anchors), file=sys.stderr)

        self.spare_model.zero_grad()

    def observe(self, inputs, labels, not_aug_inputs):
        real_batch_size = inputs.shape[0]
        if not hasattr(self, 'input_shape'):
            self.input_shape = inputs.shape[1:]
        if not hasattr(self, 'anchors'):
            self.anchors = torch.zeros(tuple([0] + list(self.input_shape))).to(
                self.device)
        if not hasattr(self, 'phi'):
            print('Building phi', file=sys.stderr)
            with torch.no_grad():
                self.phi = torch.zeros_like(self.net.features(
                    inputs[0].unsqueeze(0)),
                                            requires_grad=False)
            assert not self.phi.requires_grad

        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        old_weights = self.net.get_params().detach().clone()

        self.opt.zero_grad()
        outputs = self.net(inputs)

        k = self.task_number

        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        first_loss = 0

        assert len(self.anchors) == self.dataset.N_CLASSES_PER_TASK * k

        if len(self.anchors) > 0:
            first_loss = loss.item()
            with torch.no_grad():
                pred_anchors = self.net(self.anchors)

            self.net.set_params(old_weights)
            pred_anchors -= self.net(self.anchors)
            loss = self.hal_lambda * (pred_anchors**2).mean()
            loss.backward()
            self.opt.step()

        with torch.no_grad():
            self.phi = self.beta * self.phi + (
                1 - self.beta) * self.net.features(
                    inputs[:real_batch_size]).mean(0)

        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:real_batch_size])

        return first_loss + loss.item()
Example #8
    def fill_buffer(self, mem_buffer: Buffer, dataset, t_idx: int) -> None:
        """
        Adds examples from the current task to the memory buffer
        by means of the herding strategy.
        :param mem_buffer: the memory buffer
        :param dataset: the dataset from which to take the examples
        :param t_idx: the task index
        """

        mode = self.net.training
        self.net.eval()
        samples_per_class = mem_buffer.buffer_size // len(self.classes_so_far)

        if t_idx > 0:
            # 1) First, subsample prior classes
            buf_x, buf_y, buf_l = self.buffer.get_all_data()

            mem_buffer.empty()
            for _y in buf_y.unique():
                idx = (buf_y == _y)
                _y_x, _y_y, _y_l = buf_x[idx], buf_y[idx], buf_l[idx]
                mem_buffer.add_data(examples=_y_x[:samples_per_class],
                                    labels=_y_y[:samples_per_class],
                                    logits=_y_l[:samples_per_class])

        # 2) Then, fill with current tasks
        loader = dataset.not_aug_dataloader(self.args, self.args.batch_size)

        # 2.1 Extract all features
        a_x, a_y, a_f, a_l = [], [], [], []
        for x, y, not_norm_x in loader:
            x, y, not_norm_x = (a.to(self.device) for a in [x, y, not_norm_x])
            a_x.append(not_norm_x.to('cpu'))
            a_y.append(y.to('cpu'))

            feats = self.net.features(x)
            a_f.append(feats.cpu())
            a_l.append(torch.sigmoid(self.net.classifier(feats)).cpu())
        a_x, a_y, a_f, a_l = torch.cat(a_x), torch.cat(a_y), torch.cat(
            a_f), torch.cat(a_l)

        # 2.2 Compute class means
        for _y in a_y.unique():
            idx = (a_y == _y)
            _x, _y, _l = a_x[idx], a_y[idx], a_l[idx]
            feats = a_f[idx]
            mean_feat = feats.mean(0, keepdim=True)

            running_sum = torch.zeros_like(mean_feat)
            i = 0
            while i < samples_per_class and i < feats.shape[0]:
                cost = (mean_feat - (feats + running_sum) / (i + 1)).norm(2, 1)

                idx_min = cost.argmin().item()

                mem_buffer.add_data(
                    examples=_x[idx_min:idx_min + 1].to(self.device),
                    labels=_y[idx_min:idx_min + 1].to(self.device),
                    logits=_l[idx_min:idx_min + 1].to(self.device))

                running_sum += feats[idx_min:idx_min + 1]
                feats[idx_min] = feats[idx_min] + 1e6
                i += 1

        assert len(mem_buffer.examples) <= mem_buffer.buffer_size

        self.net.train(mode)
Example #9
class Fdr(ContinualModel):
    NAME = 'fdr'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Fdr, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.current_task = 0
        self.i = 0
        self.soft = torch.nn.Softmax(dim=1)
        self.logsoft = torch.nn.LogSoftmax(dim=1)

    def end_task(self, dataset):
        self.current_task += 1
        examples_per_task = self.args.buffer_size // self.current_task

        if self.current_task > 1:
            buf_x, buf_log, buf_tl = self.buffer.get_all_data()
            self.buffer.empty()

            for ttl in buf_tl.unique():
                idx = (buf_tl == ttl)
                ex, log, tasklab = buf_x[idx], buf_log[idx], buf_tl[idx]
                first = min(ex.shape[0], examples_per_task)
                self.buffer.add_data(examples=ex[:first],
                                     logits=log[:first],
                                     task_labels=tasklab[:first])
        counter = 0
        with torch.no_grad():
            for i, data in enumerate(dataset.train_loader):
                inputs, labels, not_aug_inputs = data
                inputs = inputs.to(self.device)
                not_aug_inputs = not_aug_inputs.to(self.device)
                outputs = self.net(inputs)
                if examples_per_task - counter < 0:
                    break
                self.buffer.add_data(
                    examples=not_aug_inputs[:(examples_per_task - counter)],
                    logits=outputs.data[:(examples_per_task - counter)],
                    task_labels=(torch.ones(self.args.batch_size) *
                                 (self.current_task - 1))[:(examples_per_task -
                                                            counter)])
                counter += self.args.batch_size

    def observe(self, inputs, labels, not_aug_inputs):
        self.i += 1

        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()
        if not self.buffer.is_empty():
            self.opt.zero_grad()
            buf_inputs, buf_logits, _ = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            buf_outputs = self.net(buf_inputs)
            loss = torch.norm(
                self.soft(buf_outputs) - self.soft(buf_logits), 2, 1).mean()
            assert not torch.isnan(loss)
            loss.backward()
            self.opt.step()

        return loss.item()
Example #10
def train_cl(train_set, test_set, model, loss, optimizer, device, config):
    """
    :param train_set: Train set
    :param test_set: Test set
    :param model: PyTorch model
    :param loss: loss function
    :param optimizer: optimizer
    :param device: device (cuda or cpu)
    :param config: configuration dictionary (buffer_size, batch_size, epochs)
    """

    name = ""
    # global_writer = SummaryWriter('./runs/continual/train/global/' + datetime.datetime.now().strftime('%m_%d_%H_%M'))
    global_writer = SummaryWriter('./runs/continual/train/global/' + name)
    buffer = Buffer(config['buffer_size'], device)
    accuracy = []
    text = open("result_" + name + ".txt",
                "w")  # TODO save results in a .txt file

    # Eval without training
    random_accuracy = evaluate_past(model,
                                    len(test_set) - 1, test_set, loss, device)
    text.write("Evaluation before training" + '\n')
    for a in random_accuracy:
        text.write(f"{a:.2f}% ")
    text.write('\n')

    for index, data_set in enumerate(train_set):
        model.train()
        print(f"----- DOMAIN {index} -----")
        print("Training model...")
        train_loader = DataLoader(data_set,
                                  batch_size=config['batch_size'],
                                  shuffle=False)

        for epoch in tqdm(range(config['epochs'])):
            epoch_loss = []
            epoch_acc = []
            for i, (x, y) in enumerate(train_loader):
                optimizer.zero_grad()

                inputs = x.to(device)
                labels = y.to(device)
                if not buffer.is_empty():
                    # Strategy 50/50
                    # From batch of 64 (dataloader) to 64 + 64 (dataloader + replay)
                    buf_input, buf_label = buffer.get_data(
                        config['batch_size'])
                    inputs = torch.cat((inputs, torch.stack(buf_input)))
                    labels = torch.cat((labels, torch.stack(buf_label)))

                y_pred = model(inputs)
                s_loss = loss(y_pred.squeeze(1), labels)
                acc = binary_accuracy(y_pred.squeeze(1), labels)
                # per-epoch metrics
                epoch_loss.append(s_loss.item())
                epoch_acc.append(acc.item())

                s_loss.backward()
                optimizer.step()

                if epoch == 0:
                    buffer.add_data(examples=x.to(device), labels=y.to(device))

            global_writer.add_scalar('Train_global/Loss',
                                     statistics.mean(epoch_loss),
                                     epoch + (config['epochs'] * index))
            global_writer.add_scalar('Train_global/Accuracy',
                                     statistics.mean(epoch_acc),
                                     epoch + (config['epochs'] * index))

            # domain_writer.add_scalar(f'Train_D{index}/Loss', statistics.mean(epoch_loss), epoch)
            # domain_writer.add_scalar(f'Train_D{index}/Accuracy', statistics.mean(epoch_acc), epoch)

            if epoch % 100 == 0:
                print(
                    f'\nEpoch {epoch:03}/{config["epochs"]} | Loss: {statistics.mean(epoch_loss):.5f} '
                    f'| Acc: {statistics.mean(epoch_acc):.5f}')

            # Last epoch (only for stats)
            if epoch == 499:
                print(
                    f'\nEpoch {epoch:03}/{config["epochs"]} | Loss: {statistics.mean(epoch_loss):.5f} '
                    f'| Acc: {statistics.mean(epoch_acc):.5f}')

        # Test on domain just trained + old domains
        evaluation = evaluate_past(model, index, test_set, loss, device)
        accuracy.append(evaluation)
        text.write(f"Evaluation after domain {index}" + '\n')
        for a in evaluation:
            text.write(f"{a:.2f}% ")
        text.write('\n')

        if index != len(train_set) - 1:
            accuracy[index].append(
                evaluate_next(model, index, test_set, loss, device))

    # Check buffer distribution
    buffer.check_distribution()

    # Compute transfer metrics
    backward = backward_transfer(accuracy)
    forward = forward_transfer(accuracy, random_accuracy)
    forget = forgetting(accuracy)
    print(f'Backward transfer: {backward}')  # TODO: are these percentages?
    print(f'Forward transfer: {forward}')
    print(f'Forgetting: {forget}')

    text.write(f"Backward: {backward}\n")
    text.write(f"Forward: {forward}\n")
    text.write(f"Forgetting: {forget}\n")
    text.close()
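
train_cl in Example #10 reads only three keys from config (buffer_size, batch_size and epochs). A hypothetical call could therefore be set up as follows; the values are illustrative, and model, loss, optimizer and the dataset lists come from the surrounding project, not from this page.

config = {
    'buffer_size': 500,   # capacity of the replay Buffer
    'batch_size': 64,     # dataloader batch size (doubled by replay once the buffer is non-empty)
    'epochs': 500,        # epochs per domain (the hardcoded `epoch == 499` print assumes this value)
}
# train_cl(train_set, test_set, model, loss, optimizer, device, config)
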
Example #11
class OCILFAST(ContinualModel):
    NAME = 'OCILFAST'
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, net, loss, args, transform):
        super(OCILFAST, self).__init__(net, loss, args, transform)

        self.nets = []
        self.c = []
        self.threshold = []

        self.nu = self.args.nu
        self.eta = self.args.eta
        self.eps = self.args.eps
        self.embedding_dim = self.args.embedding_dim
        self.weight_decay = self.args.weight_decay
        self.margin = self.args.margin

        self.current_task = 0
        self.cpt = None
        self.nc = None
        self.eye = None
        self.buffer_size = self.args.buffer_size
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.nf = self.args.nf

        if self.args.dataset == 'seq-cifar10' or self.args.dataset == 'seq-mnist':
            self.input_offset = -0.5
        elif self.args.dataset == 'seq-tinyimg':
            self.input_offset = 0
        else:
            self.input_offset = 0

    # task initialization
    def begin_task(self, dataset):

        if self.cpt is None:
            self.cpt = dataset.N_CLASSES_PER_TASK
            self.nc = dataset.N_TASKS * self.cpt
            self.eye = torch.tril(torch.ones((self.nc, self.nc))).bool().to(
                self.device)  # lower triangle (incl. diagonal) is True, upper triangle is False; used as a mask

        if len(self.nets) == 0:
            for i in range(self.nc):
                self.nets.append(
                    get_backbone(self.net, self.embedding_dim, self.nc,
                                 self.nf).to(self.device))
                self.c.append(
                    torch.ones(self.embedding_dim, device=self.device))

        self.current_task += 1

    def train_model(self, dataset, train_loader):

        categories = list(
            range((self.current_task - 1) * self.cpt,
                  (self.current_task) * self.cpt))
        print('==========\t task: %d\t categories:' % self.current_task,
              categories, '\t==========')
        if self.args.print_file:
            print('==========\t task: %d\t categories:' % self.current_task,
                  categories,
                  '\t==========',
                  file=self.args.print_file)

        for category in categories:
            losses = []

            if category > 0:
                self.reset_train_loader(train_loader, category)

            for epoch in range(self.args.n_epochs):

                avg_loss, maxloss, posdist, negdist, gloloss = self.train_category(
                    train_loader, category, epoch)

                losses.append(avg_loss)
                if epoch == 0 or (epoch + 1) % 5 == 0:
                    print("epoch: %d\t task: %d \t category: %d \t loss: %f" %
                          (epoch + 1, self.current_task, category, avg_loss))

                    if self.args.print_file:
                        print(
                            "epoch: %d\t task: %d \t category: %d \t loss: %f"
                            %
                            (epoch + 1, self.current_task, category, avg_loss),
                            file=self.args.print_file)

                    plt.figure(figsize=(20, 12))

                    ax = plt.subplot(2, 2, 1)
                    ax.set_title('maxloss')
                    plt.xlim((0, 2))
                    if maxloss is not None:
                        try:
                            sns.distplot(maxloss)
                        except:
                            pass

                    ax = plt.subplot(2, 2, 2)
                    ax.set_title('posdist')
                    plt.xlim((0, 2))
                    try:
                        sns.distplot(posdist)
                    except:
                        print(posdist)

                    ax = plt.subplot(2, 2, 3)
                    ax.set_title('negdist')
                    plt.xlim((0, 2))
                    try:
                        sns.distplot(negdist)
                    except:
                        print(negdist)

                    ax = plt.subplot(2, 2, 4)
                    ax.set_title('gloloss')
                    plt.xlim((0, 2))
                    try:
                        sns.distplot(gloloss)
                    except:
                        print(gloloss)

                    plt.savefig("../" + self.args.img_dir +
                                "/loss-cat%d-epoch%d.png" % (category, epoch))
                    plt.clf()

            x = list(range(len(losses)))
            plt.plot(x, losses)
            plt.savefig("../" + self.args.img_dir + "/loss-cat%d.png" %
                        (category))
            plt.clf()

        self.fill_buffer(train_loader)

    def reset_train_loader(self, train_loader, category):

        dataset = train_loader.dataset
        input = dataset.data
        loader = DataLoader(dataset,
                            batch_size=self.args.batch_size,
                            shuffle=False)

        inputs = []
        targets = []
        prev_dists = []

        prev_categories = list(range(category))
        print('prev_categories', prev_categories)
        if self.args.print_file:
            print('prev_categories',
                  prev_categories,
                  file=self.args.print_file)
        for i, data in enumerate(loader):
            input, target, _ = data
            _, prev_dist = self.predict(input, prev_categories)

            inputs.append(input.detach().cpu())
            targets.append(target.detach().cpu())
            prev_dists.append(prev_dist.detach().cpu())

        inputs = torch.cat(inputs, dim=0)
        targets = torch.cat(targets, dim=0)
        prev_dists = torch.cat(prev_dists, dim=0)
        dataset.set_prevdist(prev_dists)

    def train_category(self, data_loader, category: int, epoch_id):

        self.init_center_c(data_loader, category)
        c = self.c[category]

        network = self.nets[category].to(self.device)
        network.train()

        optimizer = SGD(network.parameters(),
                        lr=self.args.lr,
                        weight_decay=self.weight_decay)
        avg_loss = 0.0
        sample_num = 0

        maxloss = []
        posdist = []
        negdist = []
        gloloss = []

        prev_categories = list(range(category))
        for i, data in enumerate(data_loader):
            inputs, semi_targets, prev_dists = data
            inputs = inputs.to(self.device)
            semi_targets = semi_targets.to(self.device)
            prev_dists = prev_dists.to(self.device)

            if (not self.buffer.is_empty()) and self.args.buffer_size > 0:
                buf_inputs, buf_labels = self.buffer.get_data(
                    self.args.minibatch_size, transform=self.transform)
                # print(buf_inputs[0])
                inputs = torch.cat((inputs, buf_inputs))
                semi_targets = torch.cat((semi_targets, buf_labels))

            # Zero the network parameter gradients
            optimizer.zero_grad()

            # note: the network input is shifted by self.input_offset (e.g. -0.5)
            outputs = network(inputs + self.input_offset)

            dists = torch.sum((outputs - c)**2, dim=1)
            pos_dist_loss = torch.relu(dists - self.args.r)

            if category > 0:
                max_scores = torch.relu(dists.view(-1, 1) - prev_dists)
                max_loss = torch.sum(max_scores,
                                     dim=1) * self.margin / category

                loss_pos = pos_dist_loss + max_loss
                loss_neg = self.eta * dists**-1

                pos_max_loss = max_loss[semi_targets == category]
                maxloss.append(pos_max_loss.detach().cpu().data.numpy())

            else:
                loss_pos = pos_dist_loss
                loss_neg = self.eta * dists**-1

            losses = torch.where(semi_targets == category, loss_pos, loss_neg)
            gloloss.append(losses.detach().cpu().data.numpy())
            loss = torch.mean(losses)

            loss.backward()
            optimizer.step()

            # record the individual loss terms
            pos_dist = pos_dist_loss[semi_targets == category]
            posdist.append(pos_dist.detach().cpu().data.numpy())

            neg_dist = loss_neg[semi_targets != category]
            negdist.append(neg_dist.detach().cpu().data.numpy())

            avg_loss += loss.item()
            sample_num += inputs.shape[0]

            # old categories are trained on a single batch only
            if category < (self.current_task - 1) * self.cpt:
                break

        avg_loss /= sample_num
        if len(maxloss) > 0:
            maxloss = np.hstack(maxloss)
        else:
            maxloss = None
        posdist = np.hstack(posdist)
        negdist = np.hstack(negdist)
        gloloss = np.hstack(gloloss)
        return avg_loss, maxloss, posdist, negdist, gloloss

    def fill_buffer(self, train_loader):
        for data in train_loader:
            # get the inputs of the batch
            inputs, semi_targets, not_aug_inputs = data
            self.buffer.add_data(examples=not_aug_inputs, labels=semi_targets)

    def init_center_c(self, train_loader: DataLoader, category):
        """Initialize hypersphere center c as the mean from an initial forward pass on the data."""
        n_samples = 0
        c = 0

        net = self.nets[category].to(self.device)

        net.eval()
        with torch.no_grad():
            for data in train_loader:
                # get the inputs of the batch
                inputs, semi_targets, not_aug_inputs = data
                inputs = inputs.to(self.device)
                semi_targets = semi_targets.to(self.device)
                outputs = net(inputs + self.input_offset)
                outputs = outputs[semi_targets == category]  # use all positive samples to initialize the center
                # print(outputs)
                n_samples += outputs.shape[0]
                c += torch.sum(outputs, dim=0)

        c /= n_samples

        # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights.
        c[(abs(c) < self.eps) & (c < 0)] = -self.eps
        c[(abs(c) < self.eps) & (c > 0)] = self.eps
        self.c[category] = c.to(self.device)

    def get_score(self, dist, category):
        score = 1 / (dist + 1e-6)

        return score

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        categories = list(range(self.current_task * self.cpt))
        return self.predict(x, categories)[0]

    def predict(self, inputs: torch.Tensor, categories):
        inputs = inputs.to(self.device)
        outcome, dists = [], []
        with torch.no_grad():
            for i in categories:
                net = self.nets[i]
                net.to(self.device)
                net.eval()

                c = self.c[i].to(self.device)

                pred = net(inputs + self.input_offset)
                dist = torch.sum((pred - c)**2, dim=1)

                scores = self.get_score(dist, i)

                outcome.append(scores.view(-1, 1))
                dists.append(dist.view(-1, 1))

        outcome = torch.cat(outcome, dim=1)
        dists = torch.cat(dists, dim=1)
        return outcome, dists