Example #1
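# The imports below are a best-effort reconstruction of what these examples rely on;
# project-local helpers (ArgumentManager, DataManager, Memory, AverageMeter, USE_GPU,
# create_encoder, create_decoder, create_optim, adjust_lr, save_checkpoint) are assumed
# to come from the surrounding repository and are not reproduced here.
import copy
import datetime
import math
import os
import os.path as osp
import time
from collections import Counter
from itertools import chain

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
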
class IncrementalLearning:

    def __init__(self):
        self.args = ArgumentManager().get_args(parser_type='incremental')

        self.dsae = True
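        # With dsae enabled, the decoder reconstructs an intermediate encoder feature map
        # (fms[2] in _train_epoch) instead of the input image; the decoder output sizes
        # below are chosen accordingly.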

        # === Data
        self.dataset = self.args.dataset
        if self.dataset in ('miniImageNet', 'cnbc-face', 'cub-200'):
            self.decoder_output_size = 6 if self.dsae else 84
            self.decoder_scale_factor = 3
        elif self.dataset == 'cifar100':
            self.decoder_output_size = 4 if self.dsae else 32
            self.decoder_scale_factor = 2
        elif self.dataset == 'casia-face':
            self.decoder_output_size = 8 if self.dsae else 128
            self.decoder_scale_factor = 3
        else:
            raise NotImplementedError
        torch.manual_seed(self.args.seed)
        print('==============\nArgs:{}\n=============='.format(self.args))

        if USE_GPU:
            print('Currently using GPU: {}-{}'.format(
                torch.cuda.current_device(), torch.cuda.get_device_name()))
            cudnn.benchmark = True
            torch.cuda.manual_seed_all(self.args.seed)
        else:
            print('Currently using CPU (GPU is highly recommended)')

        # === Encoder & Decoder
        self.encoder = create_encoder(self.args, use_avgpool=True, is_snail=False)
        print('Encoder:', self.encoder)

        if self.dsae:
            self.decoder = create_decoder(self.args, out_dim=256, fm_level=3)
        else:
            self.decoder = create_decoder(self.args, out_dim=3, fm_level=-1)
        print('Decoder:', self.decoder)

        if self.args.load:  # Loading pre-trained checkpoint
            ckp_path = osp.join(self.args.load_dir, self.dataset, 'pretrain', self.args.encoder, 'best_model.ckp')
            print('Loading checkpoint from {}'.format(ckp_path))
            ckp = torch.load(ckp_path)
            encoder_state_dict = ckp['encoder_state_dict']
            decoder_state_dict = ckp['decoder_state_dict']
            self.encoder.load_state_dict(encoder_state_dict, strict=False)
            self.decoder.load_state_dict(decoder_state_dict, strict=True)
        # === MemoryK
        self.m_sz = self.args.memory_size  # 1024
        self.m_key_dim = self.args.memory_key_dim  # 128
        self.m_K = self.args.memory_K
        self.need_norm = self.args.need_norm
        self.memory = Memory(mem_size=self.m_sz, key_dim=self.m_key_dim, tau=0.95, need_norm=self.need_norm)

        self.trainable_params = chain(
            self.encoder.parameters(), self.memory.parameters(), self.decoder.parameters())
        self.optimizer_all = create_optim(
            optim_name='sgd', lr=1e-3, params=self.trainable_params)
        self.mse_loss = nn.MSELoss(reduction='mean')
        # self.scheduler_encoder = MultiStepLR(
        #     self.optimizer_encoder, milestones=self.args.lr_decay_steps, gamma=0.1)
        self.gamma = 0.2
        self.param_frozen = False  # False default
        self.cur_session = None

        if USE_GPU:
            self.encoder = self.encoder.cuda()
            self.memory = self.memory.cuda()
            self.decoder = self.decoder.cuda()

        self.eval_dataloaders = []
        current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
        self.writer = SummaryWriter(log_dir=osp.join('runs', 'incremental', current_time))
        self.data_manager = DataManager(self.args, use_gpu=USE_GPU)

    def get_dataloader(self, session):
        if session == 0:
            train_loader, eval_loader = self.data_manager.get_dataloaders()
        else:  # Incremental sessions
            if self.dataset == 'miniImageNet':
                train_loader = self.data_manager.get_dataloaders(session=session, is_fewshot=True)
                eval_loader = self.data_manager.get_dataloaders(session=session, is_fewshot=False)
            elif self.dataset in ('cifar100', 'cub-200', 'casia-face', 'cnbc-face'):
                train_loader, eval_loader = self.data_manager.get_dataloaders(
                    session=session, is_fewshot=True)
            else:
                raise NotImplementedError
        return train_loader, eval_loader

    def run(self, start_session=0, end_session=8):
        if start_session > 0:  # Load
            load_dir = osp.join(self.args.save_dir, self.dataset, 'incr', self.args.encoder,
                                'session' + str(start_session - 1), 'best_model.ckp')
            print('Start session > 0, loading checkpoint from:', load_dir)
            ckp = torch.load(load_dir)
            self.encoder.load_state_dict(ckp['encoder_state_dict'])
            self.memory.load_state_dict(ckp['memory_state_dict'])
            self.decoder.load_state_dict(ckp['decoder_state_dict'])
            # Evaluate seen classes
            for passed_session in range(start_session):
                _, eval_loader = self.get_dataloader(passed_session)
                self.eval_dataloaders.append(eval_loader)
            # self._eval_session(start_session - 1)
        for sess in range(start_session, end_session + 1):
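            # From the first incremental session on, freeze the encoder except for its
            # last stage (encoder.layer4) and rebuild the optimizer over the parameters
            # that still require gradients.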
            if sess > 0 and not self.param_frozen:
                for param in self.encoder.parameters():
                    param.requires_grad = False
                for param in self.encoder.layer4.parameters():
                    param.requires_grad = True
                self.trainable_params = chain(
                    self.encoder.parameters(), self.memory.parameters(), self.decoder.parameters())
                self.optimizer_all = create_optim(
                    optim_name='sgd', lr=1e-3,
                    params=filter(lambda p: p.requires_grad, self.trainable_params))
                self.param_frozen = True
                print('Encoder frozen.')
            self._train_session(sess, use_centroid=True)
            self._eval_session(sess, use_centroid=True)

    def _train_session(self, session, use_centroid=False):
        # assert session in range(0, 9)
        print('Training session {}'.format(session))
        self.cur_session = session
        if session > 0:
            self.memory.del_noisy_slots()
            memory_vals = self.memory.m_vals.cpu().numpy()
            memory_vals_counter = Counter(memory_vals)
            print('memory val:', len(memory_vals_counter), memory_vals_counter.most_common())

        # === Data
        print('Preparing data {} with session {}...'.format(self.dataset, session))
        train_loader, eval_loader = self.get_dataloader(session)
        if session == 0:
            max_epoch = self.args.max_epoch_sess0
            m_replays = None
        else:
            # Memory replay data
            m_keys, m_targets = self._get_nonempty_memory_slots()
            # m_keys = m_keys.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 3, 3)  # [nonempty, key_dim, 3, 3]
            m_replays = (m_keys, m_targets)

            max_epoch = 20  # Todo
        print('Num of batches of train loader:', len(train_loader))
        self.eval_dataloaders.append(eval_loader)

        start_time = time.time()
        train_time = 0
        best_epoch = 0
        # best_acc = self._eval_epoch(epoch=0, eval_loader=eval_loader)
        best_acc = -1
        best_state = None
        # === Train
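        # Manual learning-rate schedule: the base session (session 0) starts at 1e-3 and
        # drops to 1e-4 after epoch 10; incremental sessions start at 1e-4 and drop to
        # 1e-5 after epoch 15.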
        for epoch in range(1, max_epoch + 1):
            if session == 0:
                if epoch == 1:
                    adjust_lr(self.optimizer_all, 1e-3)
                if epoch > 10:
                    adjust_lr(self.optimizer_all, 1e-4)
            else:
                if epoch == 1:
                    adjust_lr(self.optimizer_all, 1e-4)
                if epoch > 15:
                    adjust_lr(self.optimizer_all, 1e-5)
            cur_lr = self.optimizer_all.param_groups[0]['lr']
            self.writer.add_scalar('Learning rate', cur_lr, global_step=epoch)

            epoch_time_start = time.time()
            self._train_epoch(epoch, cur_lr, train_loader, m_replays)
            train_time += round(time.time() - epoch_time_start)

            # === Eval on current session's dataloader only
            if epoch == self.args.max_epoch_sess0 or epoch % self.args.eval_per_epoch == 0:
                if use_centroid:
                    self.memory.upd_centroids()
                acc = self._eval_epoch(epoch, eval_loader, use_centroid=use_centroid)

                # === Save checkpoint
                is_best = acc > best_acc
                if is_best or epoch == 2 or epoch % self.args.save_per_epoch == 0:
                    state = {
                        'encoder_state_dict': self.encoder.state_dict(),
                        'memory_state_dict': self.memory.state_dict(),
                        'decoder_state_dict': self.decoder.state_dict(),
                        'acc': acc,
                        'session': session,
                        'epoch': epoch,
                    }
                    file_path = osp.join(
                        self.args.save_dir, self.dataset, 'incr', self.args.encoder,
                        'session' + str(session), 'ckp_ep' + str(epoch) + '.ckp')
                    if epoch == 2:
                        pass
                        # save_checkpoint(state, False, file_path)
                    else:
                        save_checkpoint(state, is_best, file_path)
                    if is_best:
                        best_acc = acc
                        best_epoch = epoch
                        best_state = copy.deepcopy(state)
                print('==> Test best accuracy {:.2%}, achieved at epoch {}'.format(
                    best_acc, best_epoch))
            torch.cuda.empty_cache()
        # Load best checkpoint
        self.encoder.load_state_dict(best_state['encoder_state_dict'])
        self.memory.load_state_dict(best_state['memory_state_dict'])
        self.decoder.load_state_dict(best_state['decoder_state_dict'])

        elapsed = str(datetime.timedelta(seconds=round(time.time() - start_time)))
        train_time = str(datetime.timedelta(seconds=train_time))
        print('Session {} finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}'.format(
            session, elapsed, train_time))
        # print("==========\nArgs:{}\n==========".format(self.args))

        memory_vals = self.memory.m_vals.cpu().numpy()
        memory_vals_counter = Counter(memory_vals)
        print('memory val:', len(memory_vals_counter), memory_vals_counter.most_common())

    def _train_epoch(self, epoch, lr, train_loader, memory_replay=None, use_reparam=False):
        self.encoder.train()
        self.memory.train()
        self.decoder.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses_clsf = AverageMeter()
        losses_recons = AverageMeter()
        accs = AverageMeter()

        end = time.time()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            # print('targets:', len(targets), targets[targets > 54])
            data_time.update(time.time() - end)
            if USE_GPU:
                inputs = inputs.cuda()
                targets = targets.cuda()

            bs = inputs.shape[0]
            if len(inputs.shape) == 5:  # episode, batch size = 1
                inputs = inputs.squeeze(0)  # [25, 3, 84, 84]
                targets = targets.squeeze(0)
                bs = inputs.shape[0]

            # === Encoder & Decoder forward
            outputs, fms = self.encoder(inputs, return_fm=True)
            img_recons = self.decoder(fms[3], scale_factor=self.decoder_scale_factor,
                                      out_size=self.decoder_output_size)
            if self.dsae:
                loss_recons = self.mse_loss(img_recons, fms[2])  # reconstruction loss
            else:
                loss_recons = self.mse_loss(img_recons, inputs)

            # === MemoryK forward
            # loss_e = torch.tensor(0.)
            preds, loss_memory = self.memory(outputs, targets)
            acc = preds.eq(targets).sum().float() / bs
            accs.update(acc.item(), bs)

            # Joint objective: weighted sum of the memory (classification) loss and the
            # reconstruction loss, balanced by self.gamma.
            loss_all = self.gamma * loss_memory + (1 - self.gamma) * loss_recons
            self.optimizer_all.zero_grad()
            loss_all.backward()
            self.optimizer_all.step()

            if batch_idx % 90 == 0:
                print(batch_idx, '; memory loss:', loss_memory.item(),
                      '; decoder loss:', loss_recons.item())
            losses_clsf.update(loss_memory.item(), bs)
            losses_recons.update(loss_recons.item(), bs)
            batch_time.update(time.time() - end)
            end = time.time()

        # === Memory replay: for a random subset of previously seen classes, synthesize
        # feature-level samples from the stored memory keys (Gaussian re-parameterization
        # when use_reparam is set, otherwise random convex combinations of three stored
        # keys), decode them, and train on them without updating the memory slots.
        if memory_replay is not None:
            with torch.no_grad():
                m_inputs, m_targets = memory_replay  # [nonempty, key_dim]
                # m_inputs = m_inputs.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 3, 3)  # [nonempty, key_dim, 3, 3]
                m_inputs_aug = []
                m_targets_aug = []
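                # Classes seen so far: 60 base classes plus n_novel for each completed
                # incremental session.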
                n_classes = 60 + self.args.n_novel * (self.cur_session - 1)

                rand_C_classes = np.random.choice(n_classes, self.args.n_novel, replace=False)
                for v in rand_C_classes:  # use range(n_classes) instead to replay all classes
                    m_inputs_v = m_inputs[torch.eq(m_targets, v)]
                    # print('m_inputs_v:', m_inputs_v.shape)

                    if use_reparam:  # re-parameterize
                        m_mean_v = m_inputs_v.mean(dim=0)
                        m_std_v = m_inputs_v.std(dim=0)
                        for i in range(self.args.n_shot * 2):
                            v_aug = torch.normal(mean=m_mean_v, std=m_std_v)
                            # print('v_aug:', v_aug.shape)
                            m_inputs_aug.append(v_aug)
                            m_targets_aug.append(v)
                    else:  # random sample
                        n_v = m_inputs_v.size(0)
                        if n_v == 0:
                            continue
                        for i in range(self.args.n_shot):
                            rand_idxs = np.random.choice(n_v, 3, replace=True)
                            rand_w = F.normalize(torch.rand([3]), p=1, dim=0)
                            v_aug = (rand_w[0] * m_inputs_v[rand_idxs[0]] +
                                     rand_w[1] * m_inputs_v[rand_idxs[1]] + rand_w[2] * m_inputs_v[rand_idxs[2]])
                            m_inputs_aug.append(v_aug)
                            m_targets_aug.append(v)
                m_inputs = torch.stack(m_inputs_aug, dim=0)
                m_inputs = m_inputs.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 3, 3)
                if self.need_norm:
                    m_inputs = F.normalize(m_inputs, p=2, dim=1)
                m_targets = torch.tensor(m_targets_aug, dtype=torch.long)
                # Shuffle
                sfl_idxs = torch.randperm(m_inputs.size(0))
                m_inputs = m_inputs[sfl_idxs]
                m_targets = m_targets[sfl_idxs]
                print('Memory replay size:', m_inputs.size(0))

                m_inputs = self.decoder(m_inputs, scale_factor=self.decoder_scale_factor,
                                        out_size=self.decoder_output_size)
                batch_size = 128
                n_sample = m_targets.size(0)
                n_batch = math.ceil(n_sample / batch_size)
                inputs = m_inputs.chunk(chunks=n_batch, dim=0)
                targets = m_targets.chunk(chunks=n_batch, dim=0)
                print('After chunk, inputs:', inputs[0].shape, '; targets:', targets[0].shape)
                m_train_loader = list(zip(inputs, targets))
            for batch_idx, (inputs, targets) in enumerate(m_train_loader):
                data_time.update(time.time() - end)
                if USE_GPU:
                    inputs = inputs.cuda()
                    targets = targets.cuda()

                # === Encoder & Decoder forward
                outputs = self.encoder(inputs, return_fm=False, feed_fm=self.dsae)
                img_recons = self.decoder(outputs, scale_factor=self.decoder_scale_factor,
                                          out_size=self.decoder_output_size)
                loss_recons = self.mse_loss(img_recons, inputs)

                # === MemoryK forward
                preds, loss_memory = self.memory(outputs, targets, upd_memory=False)
                loss_all = self.gamma * loss_memory + (1 - self.gamma) * loss_recons
                self.optimizer_all.zero_grad()
                loss_all.backward()
                self.optimizer_all.step()
        acc_avg = accs.avg
        loss_c_avg = losses_clsf.avg
        loss_r_avg = losses_recons.avg
        self.writer.add_scalar('Loss/train/Classification', loss_c_avg, global_step=epoch)
        self.writer.add_scalar('Loss/train/Reconstruction', loss_r_avg, global_step=epoch)
        print(
            '-Train- Epoch: {}, Lr: {:.6f}, Time: {:.1f}s, Data: {:.1f}s, '
            'Loss(C|R): {:.4f} | {:.4f}, Acc: {:.2%}'.format(
                epoch, lr, batch_time.sum, data_time.sum, loss_c_avg, loss_r_avg, acc_avg))

    @torch.no_grad()
    def _eval_session(self, session, use_centroid=False):
        assert len(self.eval_dataloaders) == session + 1

        if use_centroid:
            self.memory.upd_centroids()
        accuracies = []
        for sess in range(session + 1):
            eval_loader_sess = self.eval_dataloaders[sess]
            acc_sess = self._eval_epoch(epoch=None, eval_loader=eval_loader_sess, use_centroid=use_centroid)
            accuracies.append(acc_sess)
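        # Overall accuracy is a class-count weighted average: the base session is weighted
        # by its 60 classes, each incremental session by n_novel classes.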
        acc_sum = AverageMeter()
        for sess in range(session + 1):
            acc = accuracies[sess]
            if sess == 0:
                n_cls = 60  # self.args.n_class
            else:
                n_cls = self.args.n_novel
            acc_sum.update(acc, n_cls)
        print('Session {} Evaluation. Overall Acc.: {}'.format(session, acc_sum.avg))

    @torch.no_grad()
    def _eval_epoch(self, epoch, eval_loader, use_centroid=False):
        self.encoder.eval()
        self.memory.eval()
        self.decoder.eval()

        accs = AverageMeter()
        losses_clsf = AverageMeter()
        losses_recons = AverageMeter()
        for batch_idx, (inputs, targets) in enumerate(eval_loader):
            if USE_GPU:
                inputs = inputs.cuda()
                targets = targets.cuda()
            outputs, fms = self.encoder(inputs, return_fm=True)
            # print('outputs:', outputs.shape)

            # Decoder
            img_recons = self.decoder(fms[3], scale_factor=self.decoder_scale_factor,
                                      out_size=self.decoder_output_size)
            if self.dsae:
                loss_recons = self.mse_loss(img_recons, fms[2])  # reconstruction loss
            else:
                loss_recons = self.mse_loss(img_recons, inputs)

            preds, loss_memory = self.memory(outputs, targets, use_centroid=use_centroid)
            losses_clsf.update(loss_memory.item(), targets.size(0))
            losses_recons.update(loss_recons.item(), targets.size(0))
            acc = preds.eq(targets).sum().float() / targets.size(0)
            accs.update(acc.item(), targets.size(0))

        acc_avg = accs.avg
        loss_c_avg = losses_clsf.avg
        loss_r_avg = losses_recons.avg
        if epoch is not None:
            self.writer.add_scalar('Loss/eval/Classification', loss_c_avg, global_step=epoch)
            self.writer.add_scalar('Loss/eval/Reconstruction', loss_r_avg, global_step=epoch)
            self.writer.add_scalar('Accuracy/eval', acc_avg, global_step=epoch)
            print('-Eval- Epoch: {}, Loss(C|R): {:.4f} | {:.4f}, Accuracy: {:.2%}'.format(
                epoch, loss_c_avg, loss_r_avg, acc_avg))

        return acc_avg

    @torch.no_grad()
    def _get_nonempty_memory_slots(self):
        nonempty_idxs = torch.where(self.memory.m_vals != -1)
        m_keys = self.memory.m_keys[nonempty_idxs]  # [nonempty, key_dim]
        # m_keys = m_keys.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 3, 3)  # [nonempty, key_dim, 3, 3]
        m_vals = self.memory.m_vals[nonempty_idxs]  # [nonempty]

        return m_keys, m_vals
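
# Minimal usage sketch (not from the original listing): assuming the ArgumentManager
# defaults point at a prepared dataset, this runs the base session followed by all
# incremental sessions.
if __name__ == '__main__':
    learner = IncrementalLearning()
    learner.run(start_session=0, end_session=8)
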
Example #2
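# This example assumes the same imports and project-local helpers as Example #1. The
# enclosing class is not shown in the original excerpt; the name below is a placeholder.
class PreTraining:
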
    def __init__(self):
        self.args = ArgumentManager().get_args(parser_type='pre-train')
        torch.manual_seed(self.args.seed)
        print('==============\nArgs:{}\n=============='.format(self.args))

        if USE_GPU:
            print('Currently using GPU: {}-{}'.format(
                torch.cuda.current_device(), torch.cuda.get_device_name()))
            cudnn.benchmark = True
            torch.cuda.manual_seed_all(self.args.seed)
        else:
            print('Currently using CPU (GPU is highly recommended)')

        self.dsae = True  # True default

        # === Data
        self.dataset = self.args.dataset
        if self.dataset in ('miniImageNet', 'cnbc-face', 'cub-200'):
            self.decoder_output_size = 6 if self.dsae else 84
            self.decoder_scale_factor = 3
        elif self.dataset == 'cifar100':
            self.decoder_output_size = 4 if self.dsae else 32
            self.decoder_scale_factor = 2
        elif self.dataset == 'casia-face':
            self.decoder_output_size = 8 if self.dsae else 128
            self.decoder_scale_factor = 3
        else:
            raise NotImplementedError
        print('Preparing data {}...'.format(self.dataset))
        data_manager = DataManager(self.args, use_gpu=USE_GPU)
        self.train_loader, self.eval_loader = data_manager.get_dataloaders()
        print('len of train loader:', len(self.train_loader))  # 1673

        # === Models
        self.encoder = create_encoder(self.args, is_snail=False)
        print('Encoder:', self.encoder)

        if self.dsae:
            self.decoder = create_decoder(self.args, out_dim=256, fm_level=3)
        else:
            self.decoder = create_decoder(self.args, out_dim=3, fm_level=-1)
        print('Decoder:', self.decoder)

        # Load checkpoint
        if self.args.dataset == 'cnbc-face':
            print('Loading checkpoint...')
            ckp = torch.load(r'')
            del ckp['encoder_state_dict']['downsample.0.weight'], ckp['encoder_state_dict']['classifier.weight']
            self.encoder.load_state_dict(ckp['encoder_state_dict'], strict=False)
            self.decoder.load_state_dict(ckp['decoder_state_dict'], strict=False)

        self.trainable_params = chain(self.encoder.parameters(), self.decoder.parameters())
        self.lr = self.args.lr
        self.optimizer = create_optim(args=self.args, params=self.trainable_params)
        self.criterion_ce = nn.CrossEntropyLoss()
        self.criterion_mse = nn.MSELoss()
        self.scheduler = MultiStepLR(self.optimizer, milestones=self.args.lr_decay_steps, gamma=0.1)
        if USE_GPU:
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()

        current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
        self.writer = SummaryWriter(log_dir=os.path.join('runs', 'pretrain', current_time))
        self.gamma = 0.5