コード例 #1
0
class PODNet(BaseLearner):
    """PODNet class-incremental learner.

    Trains a ``CosineIncrementalNet`` with the NCA loss. When a frozen network
    from the previous task exists, two distillation terms are added: a flat
    (cosine embedding) loss on the features and a spatial POD loss on the
    feature maps. From the second task on, every task ends with a finetuning
    stage on an undersampled (class-balanced) exemplar set.
    """

    def __init__(self, args):
        super().__init__(args)
        self._network = CosineIncrementalNet(args['convnet_type'],
                                             pretrained=False,
                                             nb_proxy=nb_proxy)
        self._class_means = None

    def after_task(self):
        """Freeze a copy of the network as the next task's teacher and mark
        the current classes as known."""
        # self.save_checkpoint('podnet')
        self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes
        logging.info('Exemplar size: {}'.format(self.exemplar_size))

    def incremental_train(self, data_manager):
        """Learn the next task from ``data_manager`` and rebuild the exemplar memory.

        Extends the classifier to the new classes, trains on the new-class data
        plus the current exemplar memory, and evaluates on every class seen so
        far.
        """
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task)
        self.task_size = self._total_classes - self._known_classes
        self._network.update_fc(self._total_classes, self._cur_task)
        logging.info('Learning on {}-{}'.format(self._known_classes,
                                                self._total_classes))

        # Loaders: train on new classes + exemplars, test on all seen classes.
        train_dset = data_manager.get_dataset(np.arange(
            self._known_classes, self._total_classes),
                                              source='train',
                                              mode='train',
                                              appendent=self._get_memory())
        test_dset = data_manager.get_dataset(np.arange(0, self._total_classes),
                                             source='test',
                                             mode='test')
        self.train_loader = DataLoader(train_dset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=num_workers)
        self.test_loader = DataLoader(test_dset,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      num_workers=num_workers)

        # Procedure
        self._train(data_manager, self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)

    def _frozen_fc1_param_groups(self, lr):
        """Return SGD parameter groups that keep ``fc.fc1`` fixed.

        ``fc.fc1`` holds the classifier weights of the old classes. It gets
        ``lr=0`` and ``weight_decay=0``, so SGD never changes it. Every other
        parameter uses ``lr`` and the module-level ``weight_decay``.
        """
        frozen = list(self._network.fc.fc1.parameters())
        frozen_ids = {id(p) for p in frozen}
        trainable = [
            p for p in self._network.parameters() if id(p) not in frozen_ids
        ]
        return [{
            'params': trainable,
            'lr': lr,
            'weight_decay': weight_decay
        }, {
            'params': frozen,
            'lr': 0,
            'weight_decay': 0
        }]

    def _train(self, data_manager, train_loader, test_loader):
        """Run the main training stage, then (for tasks > 0) the finetuning stage.

        Args:
            data_manager: dataset provider; also used to build the undersampled
                finetuning set.
            train_loader: new-class data plus exemplar memory.
            test_loader: test data for every class seen so far.
        """
        # Adaptive distillation weight: lambda = base * factor.
        # Following the official PODNet code, factor = sqrt(total_classes / task_size).
        # This differs slightly from UCIR's adaptive lambda, but the effect is negligible.
        if self._cur_task == 0:
            self.factor = 0
        else:
            self.factor = math.sqrt(
                self._total_classes /
                (self._total_classes - self._known_classes))
        logging.info('Adaptive factor: {}'.format(self.factor))

        self._network.to(self._device)
        if self._old_network is not None:
            self._old_network.to(self._device)

        # Stage 1: new data + exemplars. From task 1 on, the old-class
        # classifier weights (fc.fc1) are kept fixed.
        if self._cur_task == 0:
            network_params = self._network.parameters()
        else:
            network_params = self._frozen_fc1_param_groups(lrate)
        optimizer = optim.SGD(network_params,
                              lr=lrate,
                              momentum=0.9,
                              weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                         T_max=epochs)
        self._run(train_loader, test_loader, optimizer, scheduler, epochs)

        # Stage 2 (tasks > 0 only): finetune on an undersampled exemplar set.
        if self._cur_task == 0:
            return
        logging.info(
            'Finetune the network (classifier part) with the undersampled dataset!'
        )
        if self._fixed_memory:
            finetune_samples_per_class = self._memory_per_class
            self._construct_exemplar_unified(data_manager,
                                             finetune_samples_per_class)
        else:
            finetune_samples_per_class = self._memory_size // self._known_classes
            self._reduce_exemplar(data_manager, finetune_samples_per_class)
            self._construct_exemplar(data_manager, finetune_samples_per_class)

        finetune_train_dataset = data_manager.get_dataset(
            [], source='train', mode='train', appendent=self._get_memory())
        finetune_train_loader = DataLoader(finetune_train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=num_workers)
        logging.info('The size of finetune dataset: {}'.format(
            len(finetune_train_dataset)))
        # The official repo finetunes only the classifier. It seemed inconsistent
        # to update only the new classifier weights during training (as UCIR does)
        # and then all of them during finetuning. Classifier-only finetuning also
        # did not improve results here. So every parameter except the old-class
        # classifier weights is finetuned.
        # Alternatives:
        #   self._network.fc.parameters()      # all fc weights
        #   self._network.fc.fc2.parameters()  # only the new fc weights
        optimizer = optim.SGD(self._frozen_fc1_param_groups(ft_lrate),
                              lr=ft_lrate,
                              momentum=0.9,
                              weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                         T_max=ft_epochs)
        self._run(finetune_train_loader, test_loader, optimizer, scheduler,
                  ft_epochs)

        # Remove the temporary new-class exemplars appended for finetuning
        # (fixed-memory mode only). The end index is computed explicitly,
        # because ``x[:-n]`` would return an empty slice when n == 0.
        if self._fixed_memory:
            n_temporary = self._memory_per_class * self.task_size
            self._data_memory = self._data_memory[:max(
                len(self._data_memory) - n_temporary, 0)]
            self._targets_memory = self._targets_memory[:max(
                len(self._targets_memory) - n_temporary, 0)]
            # Check: only old-class exemplars may remain.
            assert len(
                np.setdiff1d(self._targets_memory,
                             np.arange(
                                 0,
                                 self._known_classes))) == 0, 'Exemplar error!'

    def _run(self, train_loader, test_loader, optimizer, scheduler, epk):
        """Train for ``epk`` epochs and log losses and train/test accuracy per epoch.

        The loss is NCA on the logits. When an old network exists, it also
        includes the flat (cosine embedding) and spatial POD distillation terms.
        Each term is scaled by ``self.factor`` and its base weight.
        """
        for epoch in range(1, epk + 1):
            self._network.train()
            lsc_losses = 0.  # classification (NCA) loss
            spatial_losses = 0.  # POD spatial loss (width + height)
            flat_losses = 0.  # flat embedding distillation loss
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(
                    self._device)
                outputs = self._network(inputs)
                logits = outputs['logits']
                features = outputs['features']
                fmaps = outputs['fmaps']
                lsc_loss = nca(logits, targets)

                spatial_loss = 0.
                flat_loss = 0.
                if self._old_network is not None:
                    with torch.no_grad():
                        old_outputs = self._old_network(inputs)
                    old_features = old_outputs['features']
                    old_fmaps = old_outputs['fmaps']
                    flat_loss = F.cosine_embedding_loss(
                        features, old_features.detach(),
                        torch.ones(inputs.shape[0]).to(
                            self._device)) * self.factor * lambda_f_base
                    spatial_loss = pod_spatial_loss(
                        fmaps, old_fmaps) * self.factor * lambda_c_base

                loss = lsc_loss + flat_loss + spatial_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Record. float() accepts both the 0. placeholder and a
                # single-element tensor, so no task-index check is needed.
                lsc_losses += lsc_loss.item()
                spatial_losses += float(spatial_loss)
                flat_losses += float(flat_loss)

                # Accuracy
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            if scheduler is not None:
                scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total,
                                  decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info1 = 'Task {}, Epoch {}/{} (LR {:.5f}) => '.format(
                self._cur_task, epoch, epk, optimizer.param_groups[0]['lr'])
            info2 = 'LSC_loss {:.2f}, Spatial_loss {:.2f}, Flat_loss {:.2f}, Train_acc {:.2f}, Test_acc {:.2f}'.format(
                lsc_losses / (i + 1), spatial_losses / (i + 1),
                flat_losses / (i + 1), train_acc, test_acc)
            logging.info(info1 + info2)
コード例 #2
0
class UCIR(BaseLearner):
    """UCIR class-incremental learner.

    Trains a ``CosineIncrementalNet`` with cross-entropy. From the second task
    on, two more losses are added:
      * a less-forget loss: cosine distillation of the backbone features
        against the frozen previous network (Eq. 6);
      * an inter-class separation loss: margin ranking between each old-class
        sample's ground-truth score and its hardest new-class scores (Eq. 8).
    """

    def __init__(self, args):
        super().__init__(args)
        self._network = CosineIncrementalNet(args['convnet_type'],
                                             pretrained=False)
        self._class_means = None

    def after_task(self):
        """Freeze a copy of the network as the next task's teacher and mark
        the current classes as known."""
        # self.save_checkpoint()
        self._old_network = self._network.copy().freeze(
        )  # from ucir paper, easy to understand.
        # _known_classes: the (old) classes the model has already been trained on.
        self._known_classes = self._total_classes
        logging.info('Exemplar size: {}'.format(self.exemplar_size))

    def incremental_train(self, data_manager):
        """Learn the next task (called externally, e.g. from trainer.py).

        Extends the classifier to the new classes, trains on the new-class data
        plus the exemplar memory, then rebuilds the rehearsal memory.
        """
        # _cur_task starts at -1, so the first call trains task 0.
        self._cur_task += 1
        # total - known = size of the current task.
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task)
        self._network.update_fc(self._total_classes,
                                self._cur_task)  # update based on new class nb.
        logging.info('Learning on {}-{}'.format(self._known_classes,
                                                self._total_classes))

        # Loaders: train on new classes + exemplars, test on all seen classes.
        train_dset = data_manager.get_dataset(np.arange(
            self._known_classes, self._total_classes),
                                              source='train',
                                              mode='train',
                                              appendent=self._get_memory())
        test_dset = data_manager.get_dataset(np.arange(0, self._total_classes),
                                             source='test',
                                             mode='test')
        self.train_loader = DataLoader(train_dset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=num_workers)
        self.test_loader = DataLoader(test_dset,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      num_workers=num_workers)

        # Procedure
        self._train(self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager,
                                    self.samples_per_class)  # see base.py
        if len(self._multiple_gpus) > 1:
            # _train wrapped the network in nn.DataParallel; unwrap it via
            # .module so later code (e.g. after_task's copy()) sees the plain
            # network. Background (defeng):
            # https://www.zhihu.com/question/67726969/answer/511220696
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Build the optimizer/scheduler, place the networks, then call ``_run``."""
        # Adaptive lambda. The paper and the official code repository define it
        # differently; the official-code definition is used here:
        #   lamda = lamda_base * sqrt(known_classes / task_size)
        if self._cur_task == 0:
            self.lamda = 0
        else:
            self.lamda = lamda_base * math.sqrt(
                self._known_classes /
                (self._total_classes - self._known_classes))
        logging.info('Adaptive lambda: {}'.format(self.lamda))

        # From task 1 on, keep the old-class classifier weights (fc.fc1) fixed:
        # lr=0 and weight_decay=0 mean SGD never changes them.
        if self._cur_task == 0:
            network_params = self._network.parameters()
        else:
            ignored_params = set(map(id, self._network.fc.fc1.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self._network.parameters())
            network_params = [{
                'params': base_params,
                'lr': lrate,
                'weight_decay': weight_decay
            }, {
                'params': self._network.fc.fc1.parameters(),
                'lr': 0,
                'weight_decay': 0
            }]
        optimizer = optim.SGD(network_params,
                              lr=lrate,
                              momentum=0.9,
                              weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                   milestones=milestones,
                                                   gamma=lrate_decay)

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._network.to(self._device)
        if self._old_network is not None:
            # Note (defeng): base.py sets self._device = args['device'][0], so
            # the old model lives on the first listed device.
            self._old_network.to(self._device)

        self._run(train_loader, test_loader, optimizer, scheduler)

    def _run(self, train_loader, test_loader, optimizer, scheduler):
        """Train for the module-level ``epochs`` and log losses/accuracy per epoch."""
        for epoch in range(1, epochs + 1):
            self._network.train()  # set train mode
            ce_losses = 0.
            lf_losses = 0.
            is_losses = 0.
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(
                    self._device)
                outputs = self._network(inputs)
                # Final outputs after scaling (bs, nb_classes), i.e. before softmax.
                logits = outputs['logits']
                # Backbone features, the input of the fc layer.
                features = outputs['features']
                # F.cross_entropy applies log-softmax itself, so it takes logits.
                ce_loss = F.cross_entropy(logits, targets)

                lf_loss = 0.  # Less-forget loss (Eq. 6)
                is_loss = 0.  # Inter-class separation / margin ranking loss (Eq. 8)
                if self._old_network is not None:
                    # The old network is a frozen teacher; no graph is needed.
                    with torch.no_grad():
                        old_outputs = self._old_network(inputs)
                    old_features = old_outputs['features']
                    lf_loss = F.cosine_embedding_loss(
                        features, old_features.detach(),
                        torch.ones(inputs.shape[0]).to(
                            self._device)) * self.lamda  # Eq 6.

                    # Scores before scaling. Note (defeng, 2021-05-24):
                    # SplitCosineLinear multiplies the combined output by the
                    # scale factor eta but returns old/new scores unscaled;
                    # plain CosineLinear has no old/new scores.
                    scores = outputs['new_scores']  # (bs, nb_new)
                    old_scores = outputs['old_scores']  # (bs, nb_old)
                    old_classes_mask = np.where(
                        tensor2numpy(targets) < self._known_classes)[0]
                    if len(old_classes_mask) != 0:
                        scores = scores[old_classes_mask]  # (n, nb_new)
                        old_scores = old_scores[old_classes_mask]  # (n, nb_old)

                        # Never ask topk for more new classes than exist.
                        k = min(K, scores.shape[1])

                        # Ground-truth old-class score of each sample, (n,),
                        # repeated to (n, k) to pair with every hard negative.
                        # (torch.Tensor.repeat tiles, unlike numpy.repeat.)
                        gt_targets = targets[old_classes_mask]  # (n)
                        old_bool_onehot = target2onehot(
                            gt_targets, self._known_classes).type(torch.bool)
                        anchor_positive = torch.masked_select(
                            old_scores, old_bool_onehot)  # (n)
                        anchor_positive = anchor_positive.view(-1, 1).repeat(
                            1, k)  # (n, k)

                        # Top-k hard negatives among the new-class scores.
                        anchor_hard_negative = scores.topk(k, dim=1)[0]  # (n, k)

                        # margin_ranking_loss requires the target to have the
                        # same number of dimensions as the inputs. A (k,) target
                        # with (n, k) inputs raises, so use an all-ones (n, k)
                        # target: y = 1 means the positive should rank higher.
                        is_loss = F.margin_ranking_loss(
                            anchor_positive,
                            anchor_hard_negative,
                            torch.ones_like(anchor_positive),
                            margin=margin)

                loss = ce_loss + lf_loss + is_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # float() accepts both the 0. placeholders and scalar tensors.
                ce_losses += ce_loss.item()
                lf_losses += float(lf_loss)
                is_losses += float(is_loss)

                # Classification accuracy; preds are the argmax class indices.
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total,
                                  decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info1 = 'Task {}, Epoch {}/{} => '.format(self._cur_task, epoch,
                                                      epochs)
            info2 = 'CE_loss {:.3f}, LF_loss {:.3f}, IS_loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                ce_losses / (i + 1), lf_losses / (i + 1), is_losses / (i + 1),
                train_acc, test_acc)
            logging.info(info1 + info2)
コード例 #3
0
class UCIR(BaseLearner):
    """UCIR class-incremental learner.

    Trains a ``CosineIncrementalNet`` with cross-entropy. From the second task
    on, two more losses are added: a less-forget loss (cosine distillation of
    the backbone features against the frozen previous network) and an
    inter-class separation loss (margin ranking between each old-class
    sample's ground-truth score and its hardest new-class scores).
    """

    def __init__(self, args):
        super().__init__(args)
        self._network = CosineIncrementalNet(args['convnet_type'],
                                             pretrained=False)
        self._class_means = None

    def after_task(self):
        """Freeze a copy of the network as the next task's teacher and mark
        the current classes as known."""
        # self.save_checkpoint()
        self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes
        logging.info('Exemplar size: {}'.format(self.exemplar_size))

    def incremental_train(self, data_manager):
        """Learn the next task from ``data_manager`` and rebuild the exemplar memory."""
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task)
        self._network.update_fc(self._total_classes, self._cur_task)
        logging.info('Learning on {}-{}'.format(self._known_classes,
                                                self._total_classes))

        # Loaders: train on new classes + exemplars, test on all seen classes.
        train_dset = data_manager.get_dataset(np.arange(
            self._known_classes, self._total_classes),
                                              source='train',
                                              mode='train',
                                              appendent=self._get_memory())
        test_dset = data_manager.get_dataset(np.arange(0, self._total_classes),
                                             source='test',
                                             mode='test')
        self.train_loader = DataLoader(train_dset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=num_workers)
        self.test_loader = DataLoader(test_dset,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      num_workers=num_workers)

        # Procedure
        self._train(self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            # Undo the nn.DataParallel wrapping applied in _train.
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Build the optimizer/scheduler, place the networks, then call ``_run``."""
        # Adaptive lambda. The paper and the official code repository define it
        # differently; the official-code definition is used here:
        #   lamda = lamda_base * sqrt(known_classes / task_size)
        if self._cur_task == 0:
            self.lamda = 0
        else:
            self.lamda = lamda_base * math.sqrt(
                self._known_classes /
                (self._total_classes - self._known_classes))
        logging.info('Adaptive lambda: {}'.format(self.lamda))

        # From task 1 on, keep the old-class classifier weights (fc.fc1) fixed:
        # lr=0 and weight_decay=0 mean SGD never changes them.
        if self._cur_task == 0:
            network_params = self._network.parameters()
        else:
            ignored_params = set(map(id, self._network.fc.fc1.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self._network.parameters())
            network_params = [{
                'params': base_params,
                'lr': lrate,
                'weight_decay': weight_decay
            }, {
                'params': self._network.fc.fc1.parameters(),
                'lr': 0,
                'weight_decay': 0
            }]
        optimizer = optim.SGD(network_params,
                              lr=lrate,
                              momentum=0.9,
                              weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                   milestones=milestones,
                                                   gamma=lrate_decay)

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._network.to(self._device)
        if self._old_network is not None:
            self._old_network.to(self._device)

        self._run(train_loader, test_loader, optimizer, scheduler)

    def _run(self, train_loader, test_loader, optimizer, scheduler):
        """Train for the module-level ``epochs`` and log losses/accuracy per epoch."""
        for epoch in range(1, epochs + 1):
            self._network.train()
            ce_losses = 0.
            lf_losses = 0.
            is_losses = 0.
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(
                    self._device)
                outputs = self._network(inputs)
                logits = outputs[
                    'logits']  # Final outputs after scaling  (bs, nb_classes)
                features = outputs['features']  # Features before fc layer
                ce_loss = F.cross_entropy(logits,
                                          targets)  # Cross entropy loss

                lf_loss = 0.  # Less forgetting loss
                is_loss = 0.  # Inter-class separation loss
                if self._old_network is not None:
                    # The old network is a frozen teacher; no graph is needed.
                    with torch.no_grad():
                        old_outputs = self._old_network(inputs)
                    old_features = old_outputs[
                        'features']  # Features before fc layer
                    lf_loss = F.cosine_embedding_loss(
                        features, old_features.detach(),
                        torch.ones(inputs.shape[0]).to(
                            self._device)) * self.lamda

                    scores = outputs[
                        'new_scores']  # Scores before scaling  (bs, nb_new)
                    old_scores = outputs[
                        'old_scores']  # Scores before scaling  (bs, nb_old)
                    old_classes_mask = np.where(
                        tensor2numpy(targets) < self._known_classes)[0]
                    if len(old_classes_mask) != 0:
                        scores = scores[old_classes_mask]  # (n, nb_new)
                        old_scores = old_scores[
                            old_classes_mask]  # (n, nb_old)

                        # Never ask topk for more new classes than exist.
                        k = min(K, scores.shape[1])

                        # Ground truth targets
                        gt_targets = targets[old_classes_mask]  # (n)
                        old_bool_onehot = target2onehot(
                            gt_targets, self._known_classes).type(torch.bool)
                        anchor_positive = torch.masked_select(
                            old_scores, old_bool_onehot)  # (n)
                        anchor_positive = anchor_positive.view(-1, 1).repeat(
                            1, k)  # (n, k)

                        # Top k hard negatives among the new-class scores.
                        anchor_hard_negative = scores.topk(k,
                                                           dim=1)[0]  # (n, k)

                        # The target must have as many dims as the inputs; a
                        # (k,) target with (n, k) inputs raises in PyTorch.
                        # All ones: the positive should rank higher.
                        is_loss = F.margin_ranking_loss(
                            anchor_positive,
                            anchor_hard_negative,
                            torch.ones_like(anchor_positive),
                            margin=margin)

                loss = ce_loss + lf_loss + is_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # float() accepts both the 0. placeholders and scalar tensors.
                ce_losses += ce_loss.item()
                lf_losses += float(lf_loss)
                is_losses += float(is_loss)

                # acc
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total,
                                  decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info1 = 'Task {}, Epoch {}/{} => '.format(self._cur_task, epoch,
                                                      epochs)
            info2 = 'CE_loss {:.3f}, LF_loss {:.3f}, IS_loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                ce_losses / (i + 1), lf_losses / (i + 1), is_losses / (i + 1),
                train_acc, test_acc)
            logging.info(info1 + info2)
コード例 #4
0
class PODNet(BaseLearner):
    def __init__(self, args):
        """Build the PODNet learner around a cosine-classifier network.

        Args:
            args: configuration mapping; ``args['convnet_type']`` selects the
                backbone handed to ``CosineIncrementalNet`` (constructed with
                ``pretrained=False`` and ``nb_proxy=10``).
        """
        super().__init__(args)
        backbone_name = args['convnet_type']
        self._network = CosineIncrementalNet(
            backbone_name, pretrained=False, nb_proxy=10)
        # Populated elsewhere; starts unset.
        self._class_means = None

    def after_task(self):
        """Wrap up the task that was just trained.

        Keeps a frozen copy (``copy().freeze()``) of the current network in
        ``self._old_network`` for the distillation losses of the next task,
        marks every class seen so far as known, and logs the exemplar size.
        """
        self._old_network, self._known_classes = (
            self._network.copy().freeze(), self._total_classes)
        logging.info('Exemplar size: {}'.format(self.exemplar_size))

    def incremental_train(self, data_manager):
        """Learn the next task supplied by ``data_manager``.

        Advances ``self._cur_task``, grows the classifier to
        ``self._total_classes`` outputs, and builds two loaders:
        - train: the new classes plus whatever ``self._get_memory()``
          returns (passed as ``appendent``);
        - test: every class seen so far.
        It then runs ``self._train`` and rebuilds the rehearsal memory with
        ``self.samples_per_class`` samples per class.
        """
        self._cur_task += 1
        incoming = data_manager.get_task_size(self._cur_task)
        self._total_classes = self._known_classes + incoming
        self._network.update_fc(self._total_classes, self._cur_task)
        logging.info('Learning on {}-{}'.format(self._known_classes,
                                                self._total_classes))

        def make_loader(dataset, shuffle):
            # Same batch size / worker count for both splits; only shuffling differs.
            return DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=shuffle,
                              num_workers=num_workers)

        new_class_ids = np.arange(self._known_classes, self._total_classes)
        train_dset = data_manager.get_dataset(new_class_ids,
                                              source='train',
                                              mode='train',
                                              appendent=self._get_memory())
        seen_class_ids = np.arange(0, self._total_classes)
        test_dset = data_manager.get_dataset(seen_class_ids,
                                             source='test',
                                             mode='test')
        self.train_loader = make_loader(train_dset, shuffle=True)
        self.test_loader = make_loader(test_dset, shuffle=False)

        self._train(self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)

    def _train(self, train_loader, test_loader):
        """Set up optimisation for the current task and run the epoch loop.

        Sets ``self.factor``, the adaptive multiplier applied to the
        distillation losses in ``_run``:
        - 0 on the first task;
        - otherwise ``sqrt(total_classes / task_size)``, where
          ``task_size = total_classes - known_classes``.

        On every task after the first, the parameters of
        ``self._network.fc.fc1`` are placed in their own parameter group with
        learning rate 0 and weight decay 0, so SGD leaves them unchanged. All
        other parameters use ``lrate`` / ``weight_decay``. The learning rate
        follows a cosine-annealing schedule over ``epochs`` epochs.

        Args:
            train_loader: loader for the current task's training data.
            test_loader: loader used for per-epoch test accuracy.
        """
        # Adaptive lambda = base * factor, with
        # factor = sqrt(total_classes / task_size).
        # Slightly different from the implementation in UCIR, but the effect
        # is negligible.
        if self._cur_task == 0:
            self.factor = 0
        else:
            self.factor = math.sqrt(
                self._total_classes /
                (self._total_classes - self._known_classes))
        logging.info('Adaptive factor: {}'.format(self.factor))

        self._network.to(self._device)
        if self._old_network is not None:
            self._old_network.to(self._device)

        # Fix the embedding of old classes (fc1) after the first task.
        if self._cur_task == 0:
            network_params = self._network.parameters()
        else:
            frozen_params = list(self._network.fc.fc1.parameters())
            # Set of ids gives O(1) membership instead of scanning a list.
            frozen_ids = {id(p) for p in frozen_params}
            base_params = [
                p for p in self._network.parameters()
                if id(p) not in frozen_ids
            ]
            network_params = [{
                'params': base_params,
                'lr': lrate,
                'weight_decay': weight_decay
            }, {
                'params': frozen_params,
                'lr': 0,
                'weight_decay': 0
            }]
        optimizer = optim.SGD(network_params,
                              lr=lrate,
                              momentum=0.9,
                              weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                         T_max=epochs)

        self._run(train_loader, test_loader, optimizer, scheduler)

    def _run(self, train_loader, test_loader, optimizer, scheduler):
        """Run the training epochs for the current task.

        In each of the ``epochs`` epochs, every batch of ``train_loader`` is
        trained with the LSC (cross-entropy) loss. When a previous-task
        network exists, two distillation terms are added, both scaled by
        ``self.factor``:
        - POD-flat: a cosine-embedding loss on the feature vectors, times
          ``lambda_f_base``;
        - POD-spatial: ``pod_spatial_loss`` on the feature maps, times
          ``lambda_c_base``.

        After each epoch, ``scheduler`` is stepped once and a line is logged
        with the mean per-batch losses, the training accuracy accumulated
        during the epoch, and the accuracy of ``self._network`` on
        ``test_loader``.

        Args:
            train_loader: iterable of ``(_, inputs, targets)`` batches.
            test_loader: passed to ``self._compute_accuracy``.
            optimizer: stepped once per batch.
            scheduler: stepped once per epoch.
        """
        # The distillation terms are tensors only when an old network exists.
        # Use this single condition both to compute them and to record them,
        # so the bookkeeping never calls ``.item()`` on the float placeholders.
        has_old = self._old_network is not None
        for epoch in range(1, epochs + 1):
            self._network.train()
            lsc_losses = 0.  # CE loss
            spatial_losses = 0.  # width + height
            flat_losses = 0.  # embedding
            correct, total = 0, 0
            num_batches = 0
            for _, inputs, targets in train_loader:
                num_batches += 1
                inputs, targets = inputs.to(self._device), targets.to(
                    self._device)
                outputs = self._network(inputs)
                logits = outputs['logits']
                features = outputs['features']
                fmaps = outputs['fmaps']
                lsc_loss = F.cross_entropy(logits, targets)

                spatial_loss = 0.
                flat_loss = 0.
                if has_old:
                    with torch.no_grad():
                        old_outputs = self._old_network(inputs)
                    old_features = old_outputs['features']
                    old_fmaps = old_outputs['fmaps']
                    # Target +1 for every sample: penalise 1 - cos(new, old).
                    flat_loss = F.cosine_embedding_loss(
                        features, old_features.detach(),
                        torch.ones(inputs.shape[0]).to(
                            self._device)) * self.factor * lambda_f_base
                    spatial_loss = pod_spatial_loss(
                        fmaps, old_fmaps) * self.factor * lambda_c_base

                loss = lsc_loss + flat_loss + spatial_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # record
                lsc_losses += lsc_loss.item()
                if has_old:
                    spatial_losses += spatial_loss.item()
                    flat_losses += flat_loss.item()

                # acc
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            # Guard against an empty loader: previously this hit an unbound
            # loop index and a division by zero.
            denom = max(num_batches, 1)
            if total:
                train_acc = np.around(tensor2numpy(correct) * 100 / total,
                                      decimals=2)
            else:
                train_acc = 0.
            test_acc = self._compute_accuracy(self._network, test_loader)
            info1 = 'Task {}, Epoch {}/{} => '.format(self._cur_task, epoch,
                                                      epochs)
            info2 = 'LSC_loss {:.2f}, Spatial_loss {:.2f}, Flat_loss {:.2f}, Train_acc {:.2f}, Test_acc {:.2f}'.format(
                lsc_losses / denom, spatial_losses / denom,
                flat_losses / denom, train_acc, test_acc)
            logging.info(info1 + info2)