コード例 #1
0
ファイル: algorithm.py プロジェクト: phymucs/sib_meta_learn
    def validate(self, valLoader, lr=None, mode='val'):
        """
        Evaluate on a sequence of episodes and log the mean top-1 accuracy.

        :param valLoader: episode source. In 'test' mode ``getEpisode()`` is
            called on it once per episode; in 'val' mode it is iterated.
        :param lr: value passed as the first argument of ``self.netSIB``;
            when None, the lr of the optimizer's first param group is used
        :param str mode: 'test' runs ``self.nEpisode`` episodes, 'val' runs
            ``len(valLoader)`` episodes
        :return: the ``(mean, ci95)`` pair produced by ``getCi`` from the
            per-episode accuracies
        :raises ValueError: if ``mode`` is neither 'test' nor 'val'
        """
        randomEpisodes = mode == 'test'
        if randomEpisodes:
            nbEpisode = self.nEpisode
            self.logger.info(
                '\n\nTest mode: randomly sample {:d} episodes...'.format(
                    nbEpisode))
            episodeIter = None
        elif mode == 'val':
            nbEpisode = len(valLoader)
            self.logger.info(
                '\n\nValidation mode: pre-defined {:d} episodes...'.format(
                    nbEpisode))
            episodeIter = iter(valLoader)
        else:
            raise ValueError('mode is wrong!')

        accPerEpisode = []
        meter = AverageMeter()

        # Only the feature extractor is switched to eval mode. netSIB is
        # deliberately left as-is (original note: updating bn helps to
        # estimate better gradient).
        self.netFeat.eval()

        if lr is None:
            lr = self.optimizer.param_groups[0]['lr']

        for episodeIdx in range(nbEpisode):
            if randomEpisodes:
                episode = valLoader.getEpisode()
            else:
                episode = next(episodeIter)
            episode = to_device(episode, self.device)

            # Drop the leading dimension of every episode entry.
            supportX = episode['SupportTensor'].squeeze(0)
            supportY = episode['SupportLabel'].squeeze(0)
            queryX = episode['QueryTensor'].squeeze(0)
            queryY = episode['QueryLabel'].squeeze(0)

            # Feature extraction needs no gradient; the netSIB call below is
            # intentionally outside the no_grad block.
            with torch.no_grad():
                supportFeat = self.netFeat(supportX)
                queryFeat = self.netFeat(queryX)
                supportFeat = supportFeat.unsqueeze(0)
                queryFeat = queryFeat.unsqueeze(0)
                supportY = supportY.unsqueeze(0)

            scores = self.netSIB(lr, supportFeat, supportY, queryFeat)
            scores = scores.view(queryFeat.size(0) * queryFeat.size(1), -1)
            queryY = queryY.view(-1)

            episodeAcc = accuracy(scores, queryY, topk=(1, ))[0].item()
            meter.update(episodeAcc, scores.size(0))

            progress_bar(episodeIdx, nbEpisode,
                         'Top1: {:.3f}%'.format(meter.avg))
            accPerEpisode.append(episodeAcc)

        mean, ci95 = getCi(accPerEpisode)
        self.logger.info(
            'Final Perf with 95% confidence intervals: {:.3f}%, {:.3f}%'.
            format(mean, ci95))
        return mean, ci95
コード例 #2
0
    def test(self, epoch):
        """
        Evaluate on ``self.valLoader`` and write checkpoints to ``self.outDir``.

        The "Last" checkpoints are always written; the "Best" checkpoints are
        written (and ``self.bestAcc`` updated) only when the running top-1
        average exceeds ``self.bestAcc``.

        :param int epoch: epoch number, used only in the printed banner
        :return: running top-1 average over all validation episodes
        """
        print('\nTest at Epoch: {:d}'.format(epoch))

        self.netFeat.eval()
        self.netClassifierVal.eval()

        meter = AverageMeter()

        for episodeIdx, episode in enumerate(self.valLoader):
            episode = to_device(episode, self.device)

            # Drop the leading dimension of every episode entry.
            supportX = episode['SupportTensor'].squeeze(0)
            supportY = episode['SupportLabel'].squeeze(0)
            queryX = episode['QueryTensor'].squeeze(0)
            queryY = episode['QueryLabel'].squeeze(0)

            supportFeat = self.netFeat(supportX)
            queryFeat = self.netFeat(queryX)
            supportFeat = supportFeat.unsqueeze(0)
            queryFeat = queryFeat.unsqueeze(0)

            scores = self.netClassifierVal(supportFeat, queryFeat)
            scores = scores.view(queryFeat.size(1), -1)

            episodeAcc = accuracy(scores, queryY, topk=(1, ))
            meter.update(episodeAcc[0].item(), scores.size(0))
            progress_bar(episodeIdx, len(self.valLoader),
                         'Top1: {:.3f}%'.format(meter.avg))

        ## Save checkpoint.
        currentAcc = meter.avg
        if currentAcc > self.bestAcc:
            print('Saving Best')
            # Note: the saved classifier is self.netClassifier, not the
            # netClassifierVal module used for scoring above.
            for net, fileName in ((self.netFeat, 'netFeatBest.pth'),
                                  (self.netClassifier, 'netClsBest.pth')):
                torch.save(net.state_dict(),
                           os.path.join(self.outDir, fileName))
            self.bestAcc = currentAcc

        print('Saving Last')
        for net, fileName in ((self.netFeat, 'netFeatLast.pth'),
                              (self.netClassifier, 'netClsLast.pth')):
            torch.save(net.state_dict(), os.path.join(self.outDir, fileName))

        print('Best Performance: {:.3f}'.format(self.bestAcc))
        return meter.avg
コード例 #3
0
ファイル: runner_sib.py プロジェクト: qianrusun1015/E3BM-1
    def validate(self, valLoader, lr=None):
        """
        Evaluate on ``self.nEpisode`` episodes drawn via ``valLoader.getEpisode()``.

        :param valLoader: object providing ``getEpisode()``
        :param lr: value passed as the last argument of ``self.netSIB``;
            when None, the lr of the optimizer's first param group is used
        :return: the ``(mean, ci95)`` pair produced by ``getCi`` from the
            per-episode top-1 accuracies
        """
        nbEpisode = self.nEpisode
        self.logger.info(
            '\n\nTest mode: randomly sample {:d} episodes...'.format(nbEpisode))

        accPerEpisode = []
        meter = AverageMeter()

        # Only the feature extractor is put in eval mode here.
        self.netFeat.eval()

        if lr is None:
            lr = self.optimizer.param_groups[0]['lr']

        for episodeIdx in range(nbEpisode):
            episode = valLoader.getEpisode()
            episode = to_device(episode, self.device)

            # Drop the leading dimension of every episode entry.
            supportX = episode['SupportTensor'].squeeze(0)
            supportY = episode['SupportLabel'].squeeze(0)
            queryX = episode['QueryTensor'].squeeze(0)
            queryY = episode['QueryLabel'].squeeze(0)

            # Feature extraction needs no gradient; the netSIB call below is
            # kept outside the no_grad block.
            with torch.no_grad():
                supportFeat = self.netFeat(supportX)
                queryFeat = self.netFeat(queryX)
                supportFeat = supportFeat.unsqueeze(0)
                queryFeat = queryFeat.unsqueeze(0)
                supportY = supportY.unsqueeze(0)

            scores = self.netSIB(supportFeat, supportY, queryFeat, lr)
            scores = scores.view(queryFeat.size(0) * queryFeat.size(1), -1)
            queryY = queryY.view(-1)

            episodeAcc = accuracy(scores, queryY, topk=(1, ))[0].item()
            meter.update(episodeAcc, scores.size(0))

            progress_bar(episodeIdx, nbEpisode,
                         'Top1: {:.3f}%'.format(meter.avg))
            accPerEpisode.append(episodeAcc)

        mean, ci95 = getCi(accPerEpisode)
        self.logger.info(
            'Final Perf with 95% confidence intervals: {:.3f}%, {:.3f}%'.
            format(mean, ci95))
        return mean, ci95
コード例 #4
0
    def train(self, epoch):
        """
        Run one pass over ``self.trainLoader``, updating both networks.

        :param int epoch: epoch number, used only in the printed banner
        :return: running averages ``(loss, top1, top5)`` over the pass
        """
        print('\nTrain at Epoch: {:d}'.format(epoch))

        self.netFeat.train()
        self.netClassifier.train()
        lossMeter = AverageMeter()
        top1Meter = AverageMeter()
        top5Meter = AverageMeter()

        for step, (images, labels) in enumerate(self.trainLoader):
            images = to_device(images, self.device)
            labels = to_device(labels, self.device)

            # Standard supervised step: features -> classifier -> loss.
            self.optimizer.zero_grad()
            logits = self.netClassifier(self.netFeat(images))
            loss = self.criterion(logits, labels)
            loss.backward()
            self.optimizer.step()

            # All three meters are weighted by the batch size.
            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
            nbSample = images.size(0)
            lossMeter.update(loss.item(), nbSample)
            top1Meter.update(acc1[0].item(), nbSample)
            top5Meter.update(acc5[0].item(), nbSample)

            progress_bar(
                step, len(self.trainLoader),
                'Loss: {:.3f} | Top1: {:.3f}% | Top5: {:.3f}%'.format(
                    lossMeter.avg, top1Meter.avg, top5Meter.avg))

        return lossMeter.avg, top1Meter.avg, top5Meter.avg
コード例 #5
0
    def LrWarmUp(self, totalIter, lr):
        """
        Linearly ramp the learning rate while training, then rebuild the
        optimizer and the lr scheduler.

        The ramp sets the lr to ``step / totalIter * lr`` before each update.
        The loop stops as soon as ``step`` reaches ``totalIter`` without
        running that step, so the last ramp value applied is slightly below
        ``lr``. The optimizer is then rebuilt with that last ramp value (or
        with ``lr`` itself if no update ran).

        :param int totalIter: number of warm-up iterations
        :param float lr: target learning rate of the ramp
        :return: the value returned by ``self.test(0)`` after the warm-up
        """
        print('\nLearning rate warming up')

        def buildOptimizer(initialLr):
            # SGD over the parameters of both trained networks.
            return torch.optim.SGD(itertools.chain(
                self.netFeat.parameters(), self.netClassifier.parameters()),
                                   initialLr,
                                   momentum=0.9,
                                   weight_decay=5e-4,
                                   nesterov=True)

        # Start from a tiny lr; the real value is set per step below.
        self.optimizer = buildOptimizer(1e-7)

        step = 0
        currentLr = lr

        while step < totalIter:
            self.netFeat.train()
            self.netClassifier.train()
            # Meters are reset on every pass over the train loader.
            lossMeter = AverageMeter()
            top1Meter = AverageMeter()
            top5Meter = AverageMeter()

            for batchIdx, (images, labels) in enumerate(self.trainLoader):
                step += 1
                if step == totalIter:
                    break

                currentLr = step / float(totalIter) * lr
                for group in self.optimizer.param_groups:
                    group['lr'] = currentLr

                images = to_device(images, self.device)
                labels = to_device(labels, self.device)

                self.optimizer.zero_grad()
                logits = self.netClassifier(self.netFeat(images))
                loss = self.criterion(logits, labels)

                loss.backward()
                self.optimizer.step()

                acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
                nbSample = images.size(0)
                lossMeter.update(loss.item(), nbSample)
                top1Meter.update(acc1[0].item(), nbSample)
                top5Meter.update(acc5[0].item(), nbSample)

                progress_bar(
                    batchIdx, len(self.trainLoader),
                    'Loss: {:.3f} | Lr : {:.5f} | Top1: {:.3f}% | Top5: {:.3f}%'.format(
                        lossMeter.avg, currentLr, top1Meter.avg,
                        top5Meter.avg))

        with torch.no_grad():
            valTop1 = self.test(0)

        # Fresh optimizer (resets momentum state) starting at the last ramp lr.
        self.optimizer = buildOptimizer(currentLr)

        self.lrScheduler = MultiStepLR(self.optimizer,
                                       milestones=self.milestones,
                                       gamma=0.1)
        return valTop1
コード例 #6
0
ファイル: runner_sib.py プロジェクト: qianrusun1015/E3BM-1
    def train(self, trainLoader, valLoader, lr=None, coeffGrad=0.0):
        """
        Train ``self.netSIB`` for ``self.nbIter`` episodes, validating and
        checkpointing every 1000 episodes.

        :param trainLoader: object providing ``getBatch()``
        :param valLoader: loader forwarded to ``self.validate``
        :param lr: value passed as the last argument of ``self.netSIB`` and
            stored in the checkpoints; when None, the lr of the optimizer's
            first param group is used
        :param float coeffGrad: when > 0, the loss is taken from
            ``self.compute_grad_loss`` and its gradient term is weighted by
            this coefficient; otherwise ``self.criterion`` is used
        :return: ``(bestAcc, acc, history)`` where ``acc`` is the most recent
            validation accuracy and ``history`` holds the values recorded at
            each periodic validation
        """
        bestAcc, ci = self.validate(valLoader, lr)
        self.logger.info(
            'Acc improved over validation set from 0% ---> {:.3f} +- {:.3f}%'.
            format(bestAcc, ci))

        self.netSIB.train()
        self.netFeat.eval()

        losses = AverageMeter()
        top1 = AverageMeter()
        history = {'trainLoss': [], 'trainAcc': [], 'valAcc': []}

        # Fix: ``acc`` was previously bound only inside the periodic
        # validation branch, so the final ``return`` raised UnboundLocalError
        # whenever no periodic validation ran (e.g. ``self.nbIter`` < 1000).
        # Until then the most recent validation accuracy is the initial one.
        acc = bestAcc

        # Loop-invariant: resolve the default lr once instead of re-checking
        # on every episode.
        if lr is None:
            lr = self.optimizer.param_groups[0]['lr']

        for episode in range(self.nbIter):
            data = trainLoader.getBatch()
            data = to_device(data, self.device)

            # The feature extractor is not trained here: no gradient needed.
            with torch.no_grad():
                SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                        data['SupportTensor'], data['SupportLabel'], data['QueryTensor'], data['QueryLabel']
                nC, nH, nW = SupportTensor.shape[2:]

                # Flatten the leading dimensions for netFeat, then regroup
                # the features per batch element.
                SupportFeat = self.netFeat(
                    SupportTensor.reshape(-1, nC, nH, nW))
                SupportFeat = SupportFeat.view(self.batchSize, -1, self.nFeat)

                QueryFeat = self.netFeat(QueryTensor.reshape(-1, nC, nH, nW))
                QueryFeat = QueryFeat.view(self.batchSize, -1, self.nFeat)

            self.optimizer.zero_grad()

            clsScore = self.netSIB(SupportFeat, SupportLabel, QueryFeat, lr)
            clsScore = clsScore.view(QueryFeat.shape[0] * QueryFeat.shape[1],
                                     -1)
            QueryLabel = QueryLabel.view(-1)

            if coeffGrad > 0:
                loss, gradLoss = self.compute_grad_loss(clsScore, QueryLabel)
                loss = loss + gradLoss * coeffGrad
            else:
                loss = self.criterion(clsScore, QueryLabel)

            loss.backward()
            self.optimizer.step()

            acc1 = accuracy(clsScore, QueryLabel, topk=(1, ))
            top1.update(acc1[0].item(), clsScore.shape[0])
            losses.update(loss.item(), QueryFeat.shape[1])
            msg = 'Loss: {:.3f} | Top1: {:.3f}% '.format(losses.avg, top1.avg)
            if coeffGrad > 0:
                msg = msg + '| gradLoss: {:.3f}%'.format(gradLoss.item())
            progress_bar(episode, self.nbIter, msg)

            # Validate and checkpoint at episodes 999, 1999, ...
            if episode % 1000 == 999:
                acc, _ = self.validate(valLoader, lr)

                # The same snapshot is written to both files; build it once.
                checkpoint = {
                    'lr': lr,
                    'netFeat': self.netFeat.state_dict(),
                    'SIB': self.netSIB.state_dict(),
                    'nbStep': self.nStep,
                }

                if acc > bestAcc:
                    msg = 'Acc improved over validation set from {:.3f}% ---> {:.3f}%'.format(
                        bestAcc, acc)
                    self.logger.info(msg)

                    bestAcc = acc
                    self.logger.info('Saving Best')
                    torch.save(checkpoint,
                               os.path.join(self.outDir, 'netSIBBest.pth'))

                self.logger.info('Saving Last')
                torch.save(checkpoint,
                           os.path.join(self.outDir, 'netSIBLast.pth'))

                msg = 'Iter {:d}, Train Loss {:.3f}, Train Acc {:.3f}%, Val Acc {:.3f}%'.format(
                    episode, losses.avg, top1.avg, acc)
                self.logger.info(msg)
                history['trainLoss'].append(losses.avg)
                history['trainAcc'].append(top1.avg)
                history['valAcc'].append(acc)

                # Training meters restart after every validation round.
                losses = AverageMeter()
                top1 = AverageMeter()

        return bestAcc, acc, history
コード例 #7
0
ファイル: algorithm.py プロジェクト: fzohra/despur
    def train(self, trainLoader, valLoader, lr=None, coeffGrad=0.0):
        """
        Train ``self.netSIB`` for ``self.nbIter`` episodes, validating and
        checkpointing every 1000 episodes.

        :param trainLoader: the dataloader of train-set
        :type trainLoader: class `TrainLoader`
        :param valLoader: the dataloader of val-set
        :type valLoader: class `ValLoader`
        :param float lr: learning rate for synthetic GD; when None, the lr of
            the optimizer's first param group is used
        :param float coeffGrad: marked deprecated by the original docs; when
            > 0 the loss is still taken from ``self.compute_grad_loss`` with
            its gradient term weighted by this coefficient
        :return: ``(bestAcc, acc, history)`` where ``acc`` is the most recent
            validation accuracy and ``history`` holds the values recorded at
            each periodic validation
        """
        bestAcc, ci = self.validate(valLoader, lr, 'test')
        self.logger.info(
            'Acc improved over validation set from 0% ---> {:.3f} +- {:.3f}%'.
            format(bestAcc, ci))

        self.netSIB.train()
        self.netFeat.eval()

        losses = AverageMeter()
        top1 = AverageMeter()
        history = {'trainLoss': [], 'trainAcc': [], 'valAcc': []}

        # Fix: ``acc`` was previously bound only inside the periodic
        # validation branch, so the final ``return`` raised UnboundLocalError
        # whenever no periodic validation ran (e.g. ``self.nbIter`` < 1000).
        # Until then the most recent validation accuracy is the initial one.
        acc = bestAcc

        # Loop-invariant: resolve the default lr once instead of re-checking
        # on every episode.
        if lr is None:
            lr = self.optimizer.param_groups[0]['lr']

        for episode in range(self.nbIter):
            data = trainLoader.getBatch()
            data = to_device(data, self.device)

            # Feature extraction is not trained here: no gradient needed.
            with torch.no_grad():
                SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                        data['SupportTensor'], data['SupportLabel'], data['QueryTensor'], data['QueryLabel']
                nC, nH, nW = SupportTensor.shape[2:]

                # Features come from self.pretrain.get_features rather than
                # self.netFeat; inputs are flattened first and the features
                # regrouped per batch element afterwards.
                SupportFeat = self.pretrain.get_features(
                    SupportTensor.reshape(-1, nC, nH, nW))
                SupportFeat = SupportFeat.view(self.batchSize, -1, self.nFeat)

                QueryFeat = self.pretrain.get_features(
                    QueryTensor.reshape(-1, nC, nH, nW))
                QueryFeat = QueryFeat.view(self.batchSize, -1, self.nFeat)

            self.optimizer.zero_grad()

            # All queries are scored in a single netSIB call. (A disabled
            # per-query "inductive" variant that was kept here as a dead
            # string literal has been removed.)
            clsScore = self.netSIB(SupportFeat, SupportLabel, QueryFeat, lr)
            clsScore = clsScore.view(QueryFeat.shape[0] * QueryFeat.shape[1],
                                     -1)

            QueryLabel = QueryLabel.view(-1)

            if coeffGrad > 0:
                loss, gradLoss = self.compute_grad_loss(clsScore, QueryLabel)
                loss = loss + gradLoss * coeffGrad
            else:
                loss = self.criterion(clsScore, QueryLabel)

            loss.backward()
            self.optimizer.step()

            acc1 = accuracy(clsScore, QueryLabel, topk=(1, ))
            top1.update(acc1[0].item(), clsScore.shape[0])
            losses.update(loss.item(), QueryFeat.shape[1])
            msg = 'Loss: {:.3f} | Top1: {:.3f}% '.format(losses.avg, top1.avg)
            if coeffGrad > 0:
                msg = msg + '| gradLoss: {:.3f}%'.format(gradLoss.item())
            progress_bar(episode, self.nbIter, msg)

            # Validate and checkpoint at episodes 999, 1999, ...
            if episode % 1000 == 999:
                acc, _ = self.validate(valLoader, lr, 'test')

                # The same snapshot is written to both files; build it once.
                checkpoint = {
                    'lr': lr,
                    'netFeat': self.netFeat.state_dict(),
                    'SIB': self.netSIB.state_dict(),
                    'nbStep': self.nStep,
                }

                if acc > bestAcc:
                    msg = 'Acc improved over validation set from {:.3f}% ---> {:.3f}%'.format(
                        bestAcc, acc)
                    self.logger.info(msg)

                    bestAcc = acc
                    self.logger.info('Saving Best')
                    torch.save(checkpoint,
                               os.path.join(self.outDir, 'netSIBBest.pth'))

                self.logger.info('Saving Last')
                torch.save(checkpoint,
                           os.path.join(self.outDir, 'netSIBLast.pth'))

                msg = 'Iter {:d}, Train Loss {:.3f}, Train Acc {:.3f}%, Val Acc {:.3f}%, Best Acc {:.3f}'.format(
                    episode, losses.avg, top1.avg, acc, bestAcc)
                self.logger.info(msg)
                self.write_output_message(msg)
                history['trainLoss'].append(losses.avg)
                history['trainAcc'].append(top1.avg)
                history['valAcc'].append(acc)

                # Training meters restart after every validation round.
                losses = AverageMeter()
                top1 = AverageMeter()

        return bestAcc, acc, history
コード例 #8
0
ファイル: algorithm.py プロジェクト: fzohra/despur
    def validate(self, valLoader, lr=None, mode='val'):
        """
        Evaluate on a sequence of episodes and report the mean top-1 accuracy.

        :param valLoader: the dataloader of val-set
        :type valLoader: class `ValLoader`
        :param float lr: learning rate for synthetic GD; when None, the lr of
            the optimizer's first param group is used
        :param string mode: 'test' (``self.nEpisode`` episodes drawn with
            ``getEpisode()``) or 'val' (one pass over ``valLoader``)
        :return: the ``(mean, ci95)`` pair produced by ``getCi`` from the
            per-episode accuracies
        :raises ValueError: if ``mode`` is neither 'test' nor 'val'
        """
        randomEpisodes = mode == 'test'
        if randomEpisodes:
            nbEpisode = self.nEpisode
            self.logger.info(
                '\n\nTest mode: randomly sample {:d} episodes...'.format(
                    nbEpisode))
            episodeIter = None
        elif mode == 'val':
            nbEpisode = len(valLoader)
            self.logger.info(
                '\n\nValidation mode: pre-defined {:d} episodes...'.format(
                    nbEpisode))
            episodeIter = iter(valLoader)
        else:
            raise ValueError('mode is wrong!')

        accPerEpisode = []
        meter = AverageMeter()

        # Only netFeat is switched to eval mode. netSIB is deliberately left
        # as-is (original note: updating bn helps to estimate better gradient).
        self.netFeat.eval()

        if lr is None:
            lr = self.optimizer.param_groups[0]['lr']

        for episodeIdx in range(nbEpisode):
            if randomEpisodes:
                episode = valLoader.getEpisode()
            else:
                episode = next(episodeIter)
            episode = to_device(episode, self.device)

            # Drop the leading dimension of every episode entry.
            supportX = episode['SupportTensor'].squeeze(0)
            supportY = episode['SupportLabel'].squeeze(0)
            queryX = episode['QueryTensor'].squeeze(0)
            queryY = episode['QueryLabel'].squeeze(0)

            # Features come from self.pretrain.get_features rather than
            # self.netFeat; no gradient is needed for them.
            with torch.no_grad():
                supportFeat = self.pretrain.get_features(supportX)
                queryFeat = self.pretrain.get_features(queryX)
                supportFeat = supportFeat.unsqueeze(0)
                queryFeat = queryFeat.unsqueeze(0)
                supportY = supportY.unsqueeze(0)

            # All queries are scored in a single netSIB call (the per-query
            # "inductive" alternative is not used).
            scores = self.netSIB(supportFeat, supportY, queryFeat, lr)
            scores = scores.view(queryFeat.size(0) * queryFeat.size(1), -1)

            queryY = queryY.view(-1)

            # Optional difficulty scores, forwarded to accuracy() when
            # self.davg is set (log-odd based helper).
            diffScores = None
            if self.davg:
                diffScores = self._evaluate_hardness_logodd(
                    self.pretrain, supportFeat.squeeze(0),
                    queryFeat.squeeze(0), supportY.squeeze(0), queryY)

            episodeAcc = accuracy(scores,
                                  queryY,
                                  topk=(1, ),
                                  diff_scores=diffScores)[0].item()
            meter.update(episodeAcc, scores.size(0))

            progress_bar(episodeIdx, nbEpisode,
                         'Top1: {:.3f}%'.format(meter.avg))
            accPerEpisode.append(episodeAcc)

        mean, ci95 = getCi(accPerEpisode)
        summary = 'Final Perf with 95% confidence intervals: {:.3f}%, {:.3f}%'.format(
            mean, ci95)
        self.logger.info(summary)
        self.write_output_message(summary)
        return mean, ci95
コード例 #9
0
    def validate(self, valLoader, mode='val'):
        """
        Evaluate the refine + classifier pipeline on a sequence of episodes.

        ``netFeat``, ``netRefine`` and ``netClassifier`` are put in eval mode
        for the evaluation; ``netRefine`` and ``netClassifier`` are switched
        back to train mode before returning.

        :param valLoader: episode source. In 'test' mode ``getEpisode()`` is
            called on it once per episode; in 'val' mode it is iterated.
        :param str mode: 'test' runs ``self.nEpisode`` episodes, 'val' runs
            ``len(valLoader)`` episodes
        :return: the ``(mean, ci95)`` pair produced by ``getCi`` from the
            per-episode top-1 accuracies
        :raises ValueError: if ``mode`` is neither 'test' nor 'val'
        """
        randomEpisodes = mode == 'test'
        if randomEpisodes:
            nbEpisode = self.nEpisode
            self.logger.info(
                '\n\nTest mode: randomly sample {:d} episodes...'.format(
                    nbEpisode))
            episodeIter = None
        elif mode == 'val':
            nbEpisode = len(valLoader)
            self.logger.info(
                '\n\nValidation mode: pre-defined {:d} episodes...'.format(
                    nbEpisode))
            episodeIter = iter(valLoader)
        else:
            raise ValueError('mode is wrong!')

        accPerEpisode = []
        meter = AverageMeter()

        self.netFeat.eval()
        self.netRefine.eval()
        self.netClassifier.eval()

        for episodeIdx in range(nbEpisode):
            if randomEpisodes:
                episode = valLoader.getEpisode()
            else:
                episode = next(episodeIter)
            episode = to_device(episode, self.device)

            # Drop the leading dimension of every episode entry.
            supportX = episode['SupportTensor'].squeeze(0)
            supportY = episode['SupportLabel'].squeeze(0)
            queryX = episode['QueryTensor'].squeeze(0)
            queryY = episode['QueryLabel'].squeeze(0)

            # The whole forward pass is gradient-free here.
            with torch.no_grad():
                supportFeat = self.netFeat(supportX)
                queryFeat = self.netFeat(queryX)
                supportFeat = supportFeat.unsqueeze(0)
                queryFeat = queryFeat.unsqueeze(0)
                supportY = supportY.unsqueeze(0)
                nbSupport = supportFeat.size(1)
                nbQuery = queryFeat.size(1)

                # Refine support and query features jointly (residual add),
                # then split them back along dim 1.
                jointFeat = torch.cat((supportFeat, queryFeat), dim=1)
                refined = self.netRefine(jointFeat)
                refined = jointFeat + refined
                refinedSupport = refined.narrow(1, 0, nbSupport)
                refinedQuery = refined.narrow(1, nbSupport, nbQuery)

                scores = self.netClassifier(refinedSupport, supportY,
                                            refinedQuery)
                scores = scores.squeeze(0)

            queryY = queryY.view(-1)
            episodeAcc = accuracy(scores, queryY, topk=(1, ))[0].item()
            meter.update(episodeAcc, scores.size(0))

            progress_bar(episodeIdx, nbEpisode,
                         'Top1: {:.3f}%'.format(meter.avg))
            accPerEpisode.append(episodeAcc)

        mean, ci95 = getCi(accPerEpisode)
        self.logger.info(
            'Final Perf with 95% confidence intervals: {:.3f}%, {:.3f}%'.
            format(mean, ci95))

        # Hand the trainable modules back in train mode.
        self.netRefine.train()
        self.netClassifier.train()
        return mean, ci95
コード例 #10
0
    def train(self, trainLoader, valLoader, lr=None, coeffGrad=0.0):
        """
        Train the refine + classifier pipeline for ``self.nbIter`` episodes,
        validating and checkpointing every 1000 episodes.

        :param trainLoader: object providing ``getBatch()``
        :param valLoader: loader forwarded to ``self.validate``
        :param lr: not used for training in this method; only stored in the
            saved checkpoints
        :param float coeffGrad: not used in this method
        :return: ``(bestAcc, acc, history)`` where ``acc`` is the most recent
            validation accuracy and ``history`` holds the values recorded at
            each periodic validation
        """
        bestAcc, ci = self.validate(valLoader)
        self.logger.info(
            'Acc improved over validation set from 0% ---> {:.3f} +- {:.3f}%'.
            format(bestAcc, ci))

        self.netRefine.train()
        self.netFeat.eval()

        losses = AverageMeter()
        top1 = AverageMeter()
        history = {'trainLoss': [], 'trainAcc': [], 'valAcc': []}

        # Fix: ``acc`` was previously bound only inside the periodic
        # validation branch, so the final ``return`` raised UnboundLocalError
        # whenever no periodic validation ran (e.g. ``self.nbIter`` < 1000).
        # Until then the most recent validation accuracy is the initial one.
        acc = bestAcc

        for episode in range(self.nbIter):
            data = trainLoader.getBatch()
            data = to_device(data, self.device)

            # The feature extractor is not trained here: no gradient needed.
            with torch.no_grad():
                SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                        data['SupportTensor'], data['SupportLabel'], data['QueryTensor'], data['QueryLabel']

                # Flatten to (-1, 3, inputW, inputH) for netFeat, then regroup
                # the features per batch element.
                SupportFeat = self.netFeat(SupportTensor.contiguous().view(
                    -1, 3, self.inputW, self.inputH))
                QueryFeat = self.netFeat(QueryTensor.contiguous().view(
                    -1, 3, self.inputW, self.inputH))

                SupportFeat, QueryFeat = SupportFeat.contiguous().view(self.batchSize, -1, self.nFeat), \
                        QueryFeat.contiguous().view(self.batchSize, -1, self.nFeat)

            self.optimizer.zero_grad()

            # Refine support and query features jointly (residual add), then
            # split them back along dim 1.
            nbSupport, nbQuery = SupportFeat.size()[1], QueryFeat.size()[1]
            feat = torch.cat((SupportFeat, QueryFeat), dim=1)
            refine_feat = self.netRefine(feat)
            refine_feat = feat + refine_feat
            refine_support, refine_query = refine_feat.narrow(
                1, 0, nbSupport), refine_feat.narrow(1, nbSupport, nbQuery)
            clsScore = self.netClassifier(refine_support, SupportLabel,
                                          refine_query)

            clsScore = clsScore.view(
                refine_query.size()[0] * refine_query.size()[1], -1)
            QueryLabel = QueryLabel.view(-1)

            loss = self.criterion(clsScore, QueryLabel)

            loss.backward()
            self.optimizer.step()

            acc1 = accuracy(clsScore, QueryLabel, topk=(1, ))
            top1.update(acc1[0].item(), clsScore.size()[0])
            losses.update(loss.item(), QueryFeat.size()[1])
            msg = 'Loss: {:.3f} | Top1: {:.3f}% '.format(losses.avg, top1.avg)
            progress_bar(episode, self.nbIter, msg)

            # Validate and checkpoint at episodes 999, 1999, ...
            if episode % 1000 == 999:
                acc, _ = self.validate(valLoader)

                # The same snapshot is written to both files; build it once.
                checkpoint = {
                    'lr': lr,
                    'netFeat': self.netFeat.state_dict(),
                    'netRefine': self.netRefine.state_dict(),
                    'netClassifier': self.netClassifier.state_dict(),
                }

                if acc > bestAcc:
                    msg = 'Acc improved over validation set from {:.3f}% ---> {:.3f}%'.format(
                        bestAcc, acc)
                    self.logger.info(msg)

                    bestAcc = acc
                    self.logger.info('Saving Best')
                    torch.save(checkpoint,
                               os.path.join(self.outDir, 'netBest.pth'))

                self.logger.info('Saving Last')
                torch.save(checkpoint,
                           os.path.join(self.outDir, 'netLast.pth'))

                msg = 'Iter {:d}, Train Loss {:.3f}, Train Acc {:.3f}%, Val Acc {:.3f}%'.format(
                    episode, losses.avg, top1.avg, acc)
                self.logger.info(msg)
                history['trainLoss'].append(losses.avg)
                history['trainAcc'].append(top1.avg)
                history['valAcc'].append(acc)

                # Training meters restart after every validation round.
                losses = AverageMeter()
                top1 = AverageMeter()

        return bestAcc, acc, history