def _run(self, train_loader, test_loader, optimizer, scheduler, epk):
        for epoch in range(1, epk + 1):
            self._network.train()
            lsc_losses = 0.  # CE loss
            spatial_losses = 0.  # width + height
            flat_losses = 0.  # embedding
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(
                    self._device)
                outputs = self._network(inputs)
                logits = outputs['logits']
                features = outputs['features']
                fmaps = outputs['fmaps']
                # lsc_loss = F.cross_entropy(logits, targets)
                lsc_loss = nca(logits, targets)

                spatial_loss = 0.
                flat_loss = 0.
                if self._old_network is not None:
                    with torch.no_grad():
                        old_outputs = self._old_network(inputs)
                    old_features = old_outputs['features']
                    old_fmaps = old_outputs['fmaps']
                    flat_loss = F.cosine_embedding_loss(
                        features, old_features.detach(),
                        torch.ones(inputs.shape[0]).to(
                            self._device)) * self.factor * lambda_f_base
                    spatial_loss = pod_spatial_loss(
                        fmaps, old_fmaps) * self.factor * lambda_c_base

                loss = lsc_loss + flat_loss + spatial_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # record
                lsc_losses += lsc_loss.item()
                spatial_losses += spatial_loss.item(
                ) if self._cur_task != 0 else spatial_loss
                flat_losses += flat_loss.item(
                ) if self._cur_task != 0 else flat_loss

                # acc
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            if scheduler is not None:
                scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total,
                                  decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info1 = 'Task {}, Epoch {}/{} (LR {:.5f}) => '.format(
                self._cur_task, epoch, epk, optimizer.param_groups[0]['lr'])
            info2 = 'LSC_loss {:.2f}, Spatial_loss {:.2f}, Flat_loss {:.2f}, Train_acc {:.2f}, Test_acc {:.2f}'.format(
                lsc_losses / (i + 1), spatial_losses / (i + 1),
                flat_losses / (i + 1), train_acc, test_acc)
            logging.info(info1 + info2)
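
# `pod_spatial_loss` is referenced above but not shown here. Below is a minimal
# sketch of such a helper, assuming it follows the PODNet pooled-outputs
# distillation: each feature map is squared, pooled over width and over height,
# the two poolings are concatenated, L2-normalized, and compared to the old
# model's maps with a per-sample L2 (Frobenius) distance. Treat it as an
# illustration under that assumption, not the repo's actual implementation.
import torch
import torch.nn.functional as F


def pod_spatial_loss(fmaps, old_fmaps, normalize=True):
    loss = torch.tensor(0.).to(fmaps[0].device)
    for a, b in zip(fmaps, old_fmaps):
        assert a.shape == b.shape, 'Shape mismatch between new and old feature maps'
        a, b = torch.pow(a, 2), torch.pow(b, 2)
        a_h = a.sum(dim=3).view(a.shape[0], -1)  # pool over width  -> (bs, c*h)
        b_h = b.sum(dim=3).view(b.shape[0], -1)
        a_w = a.sum(dim=2).view(a.shape[0], -1)  # pool over height -> (bs, c*w)
        b_w = b.sum(dim=2).view(b.shape[0], -1)
        a = torch.cat([a_h, a_w], dim=-1)
        b = torch.cat([b_h, b_w], dim=-1)
        if normalize:
            a = F.normalize(a, dim=1, p=2)
            b = F.normalize(b, dim=1, p=2)
        loss += torch.mean(torch.norm(a - b, p=2, dim=-1))
    return loss / len(fmaps)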
    def _extract_vectors(self, loader):
        self._network.eval()
        vectors, targets = [], []
        for _, _inputs, _targets in loader:
            _targets = _targets.numpy()
            if isinstance(self._network, nn.DataParallel):
                _vectors = tensor2numpy(
                    self._network.module.extract_vector(
                        _inputs.to(self._device)))
            else:
                _vectors = tensor2numpy(
                    self._network.extract_vector(_inputs.to(self._device)))

            vectors.append(_vectors)
            targets.append(_targets)

        return np.concatenate(vectors), np.concatenate(targets)
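
# `tensor2numpy` is used throughout these snippets but not defined here. A
# minimal sketch of the assumed helper: move a (possibly CUDA) tensor to the
# CPU and return its numpy data.
def tensor2numpy(x):
    return x.cpu().data.numpy() if x.is_cuda else x.data.numpy()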
    def _run(self, train_loader, test_loader, epochs_, optimizer, scheduler,
             process):
        prog_bar = tqdm(range(epochs_))
        for epoch in prog_bar:
            self._network.train()
            losses = 0.
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(
                    self._device)
                logits = self._network(inputs)

                # CELoss
                clf_loss = F.cross_entropy(logits, targets)

                if self._cur_task == 0:
                    distill_loss = torch.zeros(1, device=self._device)
                else:
                    finetuning_task = (
                        self._cur_task +
                        1) if self._is_finetuning else self._cur_task
                    distill_loss = 0.
                    old_logits = self._old_network(inputs)
                    for task_i in range(1, finetuning_task + 1):
                        lo = sum(self._seen_classes[:task_i - 1])
                        hi = sum(self._seen_classes[:task_i])
                        distill_loss += F.binary_cross_entropy(
                            F.softmax(logits[:, lo:hi] / T, dim=1),
                            F.softmax(old_logits[:, lo:hi] / T, dim=1))

                loss = clf_loss + distill_loss
                losses += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # acc
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            # train_acc = self._compute_accuracy(self._network, train_loader)
            train_acc = np.around(tensor2numpy(correct) * 100 / total,
                                  decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info1 = '{} => '.format(process)
            info2 = 'Task {}, Epoch {}/{}, Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                self._cur_task, epoch + 1, epochs_, losses / len(train_loader),
                train_acc, test_acc)
            prog_bar.set_description(info1 + info2)

        logging.info(info1 + info2)
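
# A tiny worked illustration (assumed class counts, not from the repo) of the
# per-task logit slicing in the distillation term above: with
# self._seen_classes = [10, 10, 10] and finetuning_task = 3,
#   task_i = 1 -> lo, hi =  0, 10
#   task_i = 2 -> lo, hi = 10, 20
#   task_i = 3 -> lo, hi = 20, 30
# so each old task's logit block is softened with temperature T and distilled
# against the old model's corresponding block separately.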
    def _extract_vectors(self, loader):
        self._network.eval()
        vectors, targets = [], []
        for _, _inputs, _targets in loader:
            _targets = _targets.numpy()
            _vectors = tensor2numpy(self._network.extract_vector(_inputs.to(self._device)))

            vectors.append(_vectors)
            targets.append(_targets)

        return np.concatenate(vectors), np.concatenate(targets)
    def _train(self, train_loader, test_loader):
        self._network.to(self._device)
        if self._old_network is not None:
            self._old_network.to(self._device)
        optimizer = optim.SGD(self._network.parameters(),
                              lr=lrate,
                              momentum=0.9,
                              weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=milestones,
                                                   gamma=lrate_decay)

        prog_bar = tqdm(range(epochs))
        for _, epoch in enumerate(prog_bar):
            self._network.train()
            losses = 0.
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(
                    self._device)
                logits = self._network(inputs)['logits']
                exp_logits = self.expert(inputs)['logits']
                old_logits = self._old_network(inputs)['logits']

                # Distillation
                dist_term = _KD_loss(logits[:, self._known_classes:],
                                     exp_logits, T1)
                # Retrospection
                retr_term = _KD_loss(logits[:, :self._known_classes],
                                     old_logits, T2)

                loss = dist_term + retr_term
                losses += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # acc
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            # train_acc = self._compute_accuracy(self._network, train_loader)
            train_acc = np.around(tensor2numpy(correct) * 100 / total,
                                  decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info = 'Updated CNN => Epoch {}/{}, Loss {:.3f}, Train accy {:.2f}, Test accy {:.2f}'.format(
                epoch + 1, epochs, losses / len(train_loader), train_acc,
                test_acc)
            prog_bar.set_description(info)

        logging.info(info)
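
# `_KD_loss` above is the usual temperature-scaled knowledge-distillation loss.
# A minimal sketch of the assumed helper: soften both predictions with
# temperature T and take the cross entropy of the student against the teacher,
# averaged over the batch.
import torch


def _KD_loss(pred, soft, T):
    pred = torch.log_softmax(pred / T, dim=1)
    soft = torch.softmax(soft / T, dim=1)
    return -1 * torch.mul(soft, pred).sum() / pred.shape[0]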
    def _compute_accuracy(self, model, loader):
        model.eval()
        correct, total = 0, 0
        for i, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = model(inputs)['logits']
            predicts = torch.max(outputs, dim=1)[1]
            correct += (predicts.cpu() == targets).sum()
            total += len(targets)

        return np.around(tensor2numpy(correct) * 100 / total, decimals=2)
    def _update_representation(self, train_loader, test_loader, optimizer,
                               scheduler):
        prog_bar = tqdm(range(epochs))
        for _, epoch in enumerate(prog_bar):
            self._network.train()
            losses = 0.
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(
                    self._device)
                logits = self._network(inputs)['logits']
                onehots = target2onehot(targets, self._total_classes)

                if self._old_network is None:
                    loss = F.binary_cross_entropy_with_logits(logits, onehots)
                else:
                    old_onehots = torch.sigmoid(
                        self._old_network(inputs)['logits'].detach())
                    new_onehots = onehots.clone()
                    new_onehots[:, :self._known_classes] = old_onehots
                    loss = F.binary_cross_entropy_with_logits(
                        logits, new_onehots)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                # acc
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            # train_acc = self._compute_accuracy(self._network, train_loader)
            train_acc = np.around(tensor2numpy(correct) * 100 / total,
                                  decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                self._cur_task, epoch + 1, epochs, losses / len(train_loader),
                train_acc, test_acc)
            prog_bar.set_description(info)

        logging.info(info)
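
# `target2onehot` is referenced in several snippets but not defined here. A
# minimal sketch of the assumed helper: scatter 1s at the target indices of a
# zero matrix of shape (batch_size, n_classes).
import torch


def target2onehot(targets, n_classes):
    onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device)
    onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.)
    return onehot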
    def _run(self, train_loader, test_loader, optimizer, scheduler):
        for epoch in range(1, epochs + 1):
            self._network.train()
            ce_losses = 0.
            lf_losses = 0.
            is_losses = 0.
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(
                    self._device)
                outputs = self._network(inputs)
                logits = outputs[
                    'logits']  # Final outputs after scaling  (bs, nb_classes)
                features = outputs[
                    'features']  # Features before fc layer  (bs, 64)
                ce_loss = F.cross_entropy(logits,
                                          targets)  # Cross entropy loss

                lf_loss = 0.  # Less forgetting loss
                is_loss = 0.  # Inter-class separation loss
                if self._old_network is not None:
                    old_outputs = self._old_network(inputs)
                    old_features = old_outputs[
                        'features']  # Features before fc layer
                    lf_loss = F.cosine_embedding_loss(
                        features, old_features.detach(),
                        torch.ones(inputs.shape[0]).to(
                            self._device)) * self.lamda

                    scores = outputs[
                        'new_scores']  # Scores before scaling  (bs, nb_new)
                    old_scores = outputs[
                        'old_scores']  # Scores before scaling  (bs, nb_old)
                    old_classes_mask = np.where(
                        tensor2numpy(targets) < self._known_classes)[0]
                    if len(old_classes_mask) != 0:
                        scores = scores[old_classes_mask]  # (n, nb_new)
                        old_scores = old_scores[
                            old_classes_mask]  # (n, nb_old)

                        # Ground truth targets
                        gt_targets = targets[old_classes_mask]  # (n)
                        old_bool_onehot = target2onehot(
                            gt_targets, self._known_classes).type(torch.bool)
                        anchor_positive = torch.masked_select(
                            old_scores, old_bool_onehot)  # (n)
                        anchor_positive = anchor_positive.view(-1, 1).repeat(
                            1, K)  # (n, K)

                        # Top K hard
                        anchor_hard_negative = scores.topk(K,
                                                           dim=1)[0]  # (n, K)

                        is_loss = F.margin_ranking_loss(anchor_positive,
                                                        anchor_hard_negative,
                                                        torch.ones(K).to(
                                                            self._device),
                                                        margin=margin)

                loss = ce_loss + lf_loss + is_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                ce_losses += ce_loss.item()
                lf_losses += lf_loss.item() if self._cur_task != 0 else lf_loss
                is_losses += is_loss.item() if self._cur_task != 0 and len(
                    old_classes_mask) != 0 else is_loss

                # acc
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            # train_acc = self._compute_accuracy(self._network, train_loader)
            train_acc = np.around(tensor2numpy(correct) * 100 / total,
                                  decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info1 = 'Task {}, Epoch {}/{} => '.format(self._cur_task, epoch,
                                                      epochs)
            info2 = 'CE_loss {:.3f}, LF_loss {:.3f}, IS_loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                ce_losses / (i + 1), lf_losses / (i + 1), is_losses / (i + 1),
                train_acc, test_acc)
            logging.info(info1 + info2)
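
# For reference, F.margin_ranking_loss(x1, x2, y, margin=m) computes
# mean(max(0, -y * (x1 - x2) + m)); with y = 1 (as above), the inter-class
# separation term pushes the ground-truth (positive) cosine score at least
# `margin` above each of the K hardest new-class scores. A tiny check with
# made-up numbers:
#   x1 = torch.tensor([0.9, 0.2]); x2 = torch.tensor([0.5, 0.4]); m = 0.5
#   F.margin_ranking_loss(x1, x2, torch.ones(2), margin=m)
#   -> mean(max(0, 0.5 - 0.4), max(0, 0.5 + 0.2)) = mean(0.1, 0.7) = 0.4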
    def _run(self, train_loader, test_loader, optimizer, scheduler):
        for epoch in range(1, epochs + 1):
            self._network.train()
            clf_losses = 0.  # cross entropy
            distill_losses = 0.  # distillation
            attention_losses = 0.  # attention distillation
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(
                    self._device)
                outputs = self._network(inputs)
                logits = outputs['logits']
                optimizer.zero_grad()  # Same effect as nn.Module.zero_grad()
                if self._old_network is None:
                    clf_loss = F.cross_entropy(logits, targets)
                    clf_losses += clf_loss.item()
                    loss = clf_loss
                else:
                    self._old_network.zero_grad()
                    old_outputs = self._old_network(inputs)
                    old_logits = old_outputs['logits']

                    # Classification loss
                    # if no old samples saved, only calculate loss for new logits
                    clf_loss = F.cross_entropy(logits[:, self._known_classes:],
                                               targets - self._known_classes)
                    clf_losses += clf_loss.item()

                    # Distillation loss
                    # if no old samples saved, only calculate distillation loss for old logits
                    '''
                    distill_loss = F.binary_cross_entropy_with_logits(
                        logits[:, :self._known_classes], torch.sigmoid(old_logits.detach())
                    ) * distill_ratio
                    '''
                    distill_loss = _KD_loss(logits[:, :self._known_classes],
                                            old_logits.detach(),
                                            T=2) * distill_ratio
                    distill_losses += distill_loss.item()

                    # Attention distillation loss
                    top_base_indices = logits[:, :self._known_classes].argmax(
                        dim=1)
                    onehot_top_base = target2onehot(
                        top_base_indices, self._known_classes).to(self._device)

                    logits[:, :self._known_classes].backward(
                        gradient=onehot_top_base, retain_graph=True)
                    old_logits.backward(gradient=onehot_top_base)

                    attention_loss = gradcam_distillation(
                        outputs['gradcam_gradients'][0],
                        old_outputs['gradcam_gradients'][0].detach(),
                        outputs['gradcam_activations'][0],
                        old_outputs['gradcam_activations']
                        [0].detach()) * attention_ratio
                    attention_losses += attention_loss.item()

                    # Integration
                    loss = clf_loss + distill_loss + attention_loss

                    self._old_network.zero_grad()
                    self._network.zero_grad()

                optimizer.zero_grad()  # Same effect as nn.Module.zero_grad()
                loss.backward()
                optimizer.step()

                # acc
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            # train_acc = self._compute_accuracy(self._network, train_loader)
            train_acc = np.around(tensor2numpy(correct) * 100 / total,
                                  decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info1 = 'Task {}, Epoch {}/{} => clf_loss {:.2f}, '.format(
                self._cur_task, epoch, epochs, clf_losses / (i + 1))
            info2 = 'distill_loss {:.2f}, attention_loss {:.2f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                distill_losses / (i + 1), attention_losses / (i + 1),
                train_acc, test_acc)
            logging.info(info1 + info2)
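
# `gradcam_distillation` above is not defined in this snippet. Below is a
# minimal sketch under the assumption that it follows the Grad-CAM-style
# attention distillation: channel weights come from global-average-pooling the
# gradients, the weighted activations are passed through ReLU, flattened,
# L2-normalized, and the old/new attention maps are compared with an L1
# distance. The helper names are illustrative, not the repo's actual code.
import torch
import torch.nn.functional as F


def _gradcam_attention(gradients, activations):
    alpha = F.adaptive_avg_pool2d(gradients, (1, 1))  # per-channel weights
    return F.relu(alpha * activations)


def gradcam_distillation(gradients_new, gradients_old, activations_new,
                         activations_old, factor=1.):
    attn_new = _gradcam_attention(gradients_new, activations_new)
    attn_old = _gradcam_attention(gradients_old, activations_old)
    bs = attn_new.shape[0]
    flat_new = F.normalize(attn_new.view(bs, -1), p=2, dim=-1)
    flat_old = F.normalize(attn_old.view(bs, -1), p=2, dim=-1)
    return factor * torch.mean(torch.abs(flat_new - flat_old).sum(-1))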
    def _run(self, train_loader, test_loader, optimizer, scheduler):
        for epoch in range(1, epochs + 1):
            self._network.train()  # set train mode
            ce_losses = 0.
            lf_losses = 0.
            is_losses = 0.
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(
                    self._device)
                outputs = self._network(inputs)
                logits = outputs[
                    'logits']  # Final outputs after scaling  (bs, nb_classes), |* i.e., before probs=softmax(logits)
                features = outputs[
                    'features']  # Features before fc layer  (bs, 64) |* i.e., feature vector from the feature extractor (backbone)
                ce_loss = F.cross_entropy(
                    logits, targets
                )  # Cross entropy loss |* cross_entropy implicitly applies softmax, so its input is logits.

                lf_loss = 0.  # Less forgetting loss
                is_loss = 0.  # Inter-class separation loss, i.e., margin ranking loss (Eq. 8).
                if self._old_network is not None:
                    old_outputs = self._old_network(inputs)
                    old_features = old_outputs[
                        'features']  # Features before fc layer
                    lf_loss = F.cosine_embedding_loss(
                        features, old_features.detach(),
                        torch.ones(inputs.shape[0]).to(
                            self._device)) * self.lamda  # Eq 6.

                    scores = outputs[
                        'new_scores']  # Scores before scaling  (bs, nb_new)
                    old_scores = outputs[
                        'old_scores']  # Scores before scaling  (bs, nb_old)
                    '''@Author:defeng
                        24 May 2021 (Monday)
                        From Line 45 here, we know ucir uses CosineincNet, and CosineincNet uses the (Split)CosineLinearLayer.
                        In Line 93, the forward function of SplitCosineLinearLayer, "out" is multiplied by the scaling factor eta while out1/out2 are not.
                        (CosineLinearLayer does not have the new/old_scores.)
                    '''
                    old_classes_mask = np.where(
                        tensor2numpy(targets) < self._known_classes)[0]
                    if len(old_classes_mask) != 0:
                        scores = scores[old_classes_mask]  # (n, nb_new)
                        old_scores = old_scores[
                            old_classes_mask]  # (n, nb_old)

                        # Ground truth targets
                        gt_targets = targets[old_classes_mask]  # (n)
                        old_bool_onehot = target2onehot(
                            gt_targets, self._known_classes).type(torch.bool)
                        anchor_positive = torch.masked_select(
                            old_scores, old_bool_onehot
                        )  # *(n)*   |* i.e., select the scores corresponding to the GT class.
                        anchor_positive = anchor_positive.view(-1, 1).repeat(
                            1, K
                        )  # *(n, K)*   |* i.e., <\bar{\theta}, \bar{f}(x)>
                        '''@Author:defeng
                            torch.repeat is different from numpy.repeat.
                            see for details: https://pytorch.org/docs/stable/tensors.html?highlight=repeat#torch.Tensor.repeat
                        '''

                        # Top K hard
                        anchor_hard_negative = scores.topk(
                            K, dim=1
                        )[0]  # *(n, K)* |* i.e., <\bar{\theta_{k}}, \bar{f}(x)>

                        is_loss = F.margin_ranking_loss(anchor_positive,
                                                        anchor_hard_negative,
                                                        torch.ones(K).to(
                                                            self._device),
                                                        margin=margin)
                        '''@Author:defeng
                            here, the argument "torch.ones(K).to(self._device)" passed to margin_ranking_loss follows the parameter \
                            requirements in the PyTorch documentation (specifically, ones(K) is the variable y).
                            see for details: https://pytorch.org/docs/stable/generated/torch.nn.MarginRankingLoss.html#torch.nn.MarginRankingLoss
                        '''

                loss = ce_loss + lf_loss + is_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                ce_losses += ce_loss.item()
                lf_losses += lf_loss.item() if self._cur_task != 0 else lf_loss
                is_losses += is_loss.item() if self._cur_task != 0 and len(
                    old_classes_mask) != 0 else is_loss

                # acc(classification)
                _, preds = torch.max(
                    logits, dim=1
                )  # preds holds the indices of the max values along dim 1.
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            # train_acc = self._compute_accuracy(self._network, train_loader)
            train_acc = np.around(tensor2numpy(correct) * 100 / total,
                                  decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info1 = 'Task {}, Epoch {}/{} => '.format(self._cur_task, epoch,
                                                      epochs)
            info2 = 'CE_loss {:.3f}, LF_loss {:.3f}, IS_loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                ce_losses / (i + 1), lf_losses / (i + 1), is_losses / (i + 1),
                train_acc, test_acc)
            logging.info(info1 + info2)
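
# The docstring above refers to the split cosine classifier that produces
# 'old_scores' / 'new_scores' and a scaled 'logits' output. Below is a minimal,
# assumed sketch of such a layer (names and initialization are illustrative,
# not the repo's actual SplitCosineLinearLayer): the scores are cosine
# similarities between normalized features and normalized class weights, and
# only the concatenated output is multiplied by the learnable scale eta (sigma).
import torch
import torch.nn as nn
import torch.nn.functional as F


class SplitCosineLinearSketch(nn.Module):
    def __init__(self, in_features, nb_old, nb_new):
        super().__init__()
        self.weight_old = nn.Parameter(torch.randn(nb_old, in_features) * 0.01)
        self.weight_new = nn.Parameter(torch.randn(nb_new, in_features) * 0.01)
        self.sigma = nn.Parameter(torch.ones(1))  # the scaling factor eta

    def forward(self, features):
        features = F.normalize(features, p=2, dim=1)
        old_scores = F.linear(features, F.normalize(self.weight_old, p=2, dim=1))
        new_scores = F.linear(features, F.normalize(self.weight_new, p=2, dim=1))
        logits = self.sigma * torch.cat([old_scores, new_scores], dim=1)
        return {'logits': logits, 'old_scores': old_scores, 'new_scores': new_scores}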