Code example #1
class PosNegMixupTrainer(object):
    def __init__(self, cfg, model, train_dl, val_dl, loss_func, num_query,
                 num_gpus):
        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.loss_func = loss_func
        self.num_query = num_query

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.train_epoch = 1
        self.batch_cnt = 0

        self.logger = logging.getLogger('reid_baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR
        self.device = cfg.MODEL.DEVICE
        self.epochs = cfg.SOLVER.MAX_EPOCHS

        if num_gpus > 1:

            # Multi-GPU model without FP16
            self.model = nn.DataParallel(self.model)
            if cfg.SOLVER.SYNCBN:
                # convert to use sync_bn
                self.logger.info(
                    'More than one gpu used, convert model to use SyncBN.')
                self.model = convert_model(self.model)
                self.logger.info('Using pytorch SyncBN implementation')
            self.model.cuda()

            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)

            self.mix_precision = False
            self.logger.info(self.model)
            self.logger.info(self.optim)
            self.logger.info('Trainer Built')
            return

        else:
            # Single GPU model
            self.model.cuda()

            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.model)
            self.logger.info(self.optim)
            self.mix_precision = False
            return

    def handle_new_batch(self):
        self.batch_cnt += 1
        if self.batch_cnt % self.cfg.SOLVER.LOG_PERIOD == 0:
            self.logger.info('Epoch[{}] Iteration[{}/{}] Loss: {:.3f},'
                             'Acc: {:.3f}, Base Lr: {:.2e}'.format(
                                 self.train_epoch, self.batch_cnt,
                                 len(self.train_dl), self.loss_avg.avg,
                                 self.acc_avg.avg,
                                 self.scheduler.get_lr()[0]))

    def handle_new_epoch(self):

        self.batch_cnt = 1

        lr = self.scheduler.get_lr()[0]
        self.logger.info('Epoch {} done'.format(self.train_epoch))
        self.logger.info('-' * 20)

        torch.save(
            self.model.state_dict(),
            osp.join(self.output_dir, self.cfg.MODEL.NAME + '_epoch_last.pth'))
        torch.save(
            self.optim.state_dict(),
            osp.join(self.output_dir,
                     self.cfg.MODEL.NAME + '_epoch_last_optim.pth'))

        if self.train_epoch > self.cfg.SOLVER.START_SAVE_EPOCH and self.train_epoch % self.checkpoint_period == 0:
            self.save()
        if (self.train_epoch > 0 and self.train_epoch % self.eval_period
                == 0) or self.train_epoch == 50:
            self.evaluate()
        self.scheduler.step()
        self.train_epoch += 1

    # Sample negative examples for the ce and tpl losses by mixing up images from different identities
    def posneg_mixup(self,
                     imgs,
                     targets,
                     num_instance,
                     neg_instance,
                     alpha=0.75):
        sample_imgs = []
        sample_targets1 = []
        sample_targets2 = []
        sample_lambdas = []

        lamb = np.random.beta(alpha, alpha)

        for p in range(self.cfg.SOLVER.IMS_PER_BATCH // num_instance):
            # lamb = np.random.beta(alpha, alpha)
            # lambs = [lamb for i in range(neg_instance)]

            ps = [
                i for i in range(self.cfg.SOLVER.IMS_PER_BATCH)
                if i // num_instance != p
            ]
            ps = np.random.choice(ps, size=neg_instance * 2, replace=True)
            ps = ps.reshape((2, -1))

            sample_imgs.append(lamb * imgs[ps[0]] + (1 - lamb) * imgs[ps[1]])
            sample_targets1.append(targets[ps[0]])
            sample_targets2.append(targets[ps[1]])

            # sample_lambdas.extend(lambs)
        # return torch.cat(sample_imgs,dim=0),torch.cat(sample_targets1,dim=0),torch.cat(sample_targets2,dim=0),sample_lambdas
        return torch.cat(sample_imgs,
                         dim=0), torch.cat(sample_targets1,
                                           dim=0), torch.cat(sample_targets2,
                                                             dim=0), lamb

    def step(self, batch):
        self.model.train()
        self.optim.zero_grad()
        #
        img, target = batch
        img, target = img.cuda(), target.cuda()
        outputs = self.model(img)
        #
        if self.cfg.SOLVER.MIXUP.USE:

            mx_img, mx_target1, mx_target2, lamb = self.posneg_mixup(
                img, target, self.cfg.DATALOADER.NUM_INSTANCE,
                self.cfg.SOLVER.MIXUP.NEG_INSTANCE,
                self.cfg.SOLVER.MIXUP.ALPHA)
            mx_outputs = self.model(mx_img)
            loss = self.loss_func(outputs, target, mx_outputs, mx_target1,
                                  mx_target2, lamb)

        else:
            loss = self.loss_func(outputs, target)

        if self.mix_precision:
            with amp.scale_loss(loss, self.optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optim.step()

        # acc = (score.max(1)[1] == target).float().mean()
        acc = calculate_acc(self.cfg, outputs, target)

        self.loss_avg.update(loss.cpu().item())
        self.acc_avg.update(acc.cpu().item())

        return self.loss_avg.avg, self.acc_avg.avg

    def evaluate(self):
        self.model.eval()
        num_query = self.num_query
        feats, pids, camids = [], [], []
        with torch.no_grad():
            for batch in tqdm(self.val_dl, total=len(self.val_dl),
                              leave=False):
                data, pid, camid, _ = batch
                data = data.cuda()

                # ff = torch.FloatTensor(data.size(0), 2048).zero_()
                # for i in range(2):
                #     if i == 1:
                #         data = data.index_select(3, torch.arange(data.size(3) - 1, -1, -1).long().to('cuda'))
                #     outputs = self.model(data)
                #     f = outputs.data.cpu()
                #     ff = ff + f

                ff = self.model(data).data.cpu()
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))

                feats.append(ff)
                pids.append(pid)
                camids.append(camid)
        feats = torch.cat(feats, dim=0)
        pids = torch.cat(pids, dim=0)
        camids = torch.cat(camids, dim=0)

        query_feat = feats[:num_query]
        query_pid = pids[:num_query]
        query_camid = camids[:num_query]

        gallery_feat = feats[num_query:]
        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        distmat = euclidean_dist(query_feat, gallery_feat)

        cmc, mAP, _ = eval_func(
            distmat.numpy(),
            query_pid.numpy(),
            gallery_pid.numpy(),
            query_camid.numpy(),
            gallery_camid.numpy(),
        )
        self.logger.info('Validation Result:')
        self.logger.info('mAP: {:.2%}'.format(mAP))
        for r in self.cfg.TEST.CMC:
            self.logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))

        self.logger.info('average of mAP and rank1: {:.2%}'.format(
            (mAP + cmc[0]) / 2.0))

        self.logger.info('-' * 20)

    def save(self):
        torch.save(
            self.model.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '.pth'))
        torch.save(
            self.optim.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '_optim.pth'))
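
The posneg_mixup method above blends pairs of images drawn from different identity groups with a single coefficient sampled from a Beta(alpha, alpha) distribution, and the loss function is then expected to weight the two source labels by that same coefficient. The snippet below is a minimal, self-contained sketch of that interpolation and of a mixed-target cross-entropy term; the helper names mixup_pair and mixed_cross_entropy are illustrative and are not part of the repository the example comes from.

import numpy as np
import torch
import torch.nn.functional as F

def mixup_pair(x1, x2, alpha=0.75):
    # Draw one interpolation coefficient from Beta(alpha, alpha),
    # as posneg_mixup does, and blend the two image tensors.
    lam = np.random.beta(alpha, alpha)
    return lam * x1 + (1 - lam) * x2, lam

def mixed_cross_entropy(logits, target1, target2, lam):
    # A mixed sample is penalised against both source labels,
    # weighted by the mixing coefficient.
    return lam * F.cross_entropy(logits, target1) + \
        (1 - lam) * F.cross_entropy(logits, target2)

# Tiny usage example with random data (4 classes, 3x32x32 images).
x1, x2 = torch.randn(8, 3, 32, 32), torch.randn(8, 3, 32, 32)
y1, y2 = torch.randint(0, 4, (8,)), torch.randint(0, 4, (8,))
mixed, lam = mixup_pair(x1, x2)
logits = torch.randn(8, 4, requires_grad=True)  # stand-in for model(mixed)
loss = mixed_cross_entropy(logits, y1, y2, lam)
loss.backward()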
Code example #2
class HistLabelTrainer(object):
    def __init__(self, cfg, model, train_dl, val_dl, loss_func, num_query,
                 num_gpus):
        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.loss_func = loss_func
        self.num_query = num_query

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.train_epoch = 1
        self.batch_cnt = 0

        self.logger = logging.getLogger('reid_baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR
        self.device = cfg.MODEL.DEVICE
        self.epochs = cfg.SOLVER.MAX_EPOCHS

        if cfg.SOLVER.TENSORBOARD.USE:
            summary_dir = os.path.join(cfg.OUTPUT_DIR, 'summaries/')
            os.makedirs(summary_dir, exist_ok=True)
            self.summary_writer = SummaryWriter(log_dir=summary_dir)
        else:
            # Keep the attribute defined so later `if self.summary_writer:`
            # checks do not raise AttributeError when TensorBoard is disabled.
            self.summary_writer = None
        self.current_iteration = 0

        self.model.cuda()
        self.logger.info(self.model)

        if num_gpus > 1:

            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.optim)

            self.mix_precision = (cfg.MODEL.OPT_LEVEL != "O0")
            if self.mix_precision:
                self.model, self.optim = amp.initialize(
                    self.model, self.optim, opt_level=cfg.MODEL.OPT_LEVEL)
                self.logger.info(
                    'Using apex for mix_precision with opt_level {}'.format(
                        cfg.MODEL.OPT_LEVEL))

            self.model = nn.DataParallel(self.model)
            if cfg.SOLVER.SYNCBN:
                if self.mix_precision:
                    self.model = apex.parallel.convert_syncbn_model(self.model)
                    self.logger.info(
                        'More than one gpu used, convert model to use SyncBN.')
                    self.logger.info('Using apex SyncBN implementation')
                else:
                    self.model = convert_model(self.model)
                    self.model.cuda()
                    self.logger.info(
                        'More than one gpu used, convert model to use SyncBN.')
                    self.logger.info('Using pytorch SyncBN implementation')
                    self.logger.info(self.model)

            self.logger.info('Trainer Built')

            return

        else:
            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.optim)

            self.mix_precision = (cfg.MODEL.OPT_LEVEL != "O0")
            if self.mix_precision:
                self.model, self.optim = amp.initialize(
                    self.model, self.optim, opt_level=cfg.MODEL.OPT_LEVEL)
                self.logger.info(
                    'Using apex for mix_precision with opt_level {}'.format(
                        cfg.MODEL.OPT_LEVEL))

            return

    def handle_new_batch(self):
        if self.current_iteration % self.cfg.SOLVER.TENSORBOARD.LOG_PERIOD == 0:
            if self.summary_writer:
                self.summary_writer.add_scalar('Train/lr',
                                               self.scheduler.get_lr()[0],
                                               self.current_iteration)
                self.summary_writer.add_scalar('Train/loss', self.loss_avg.avg,
                                               self.current_iteration)
                self.summary_writer.add_scalar('Train/acc', self.acc_avg.avg,
                                               self.current_iteration)

        self.batch_cnt += 1
        self.current_iteration += 1
        if self.batch_cnt % self.cfg.SOLVER.LOG_PERIOD == 0:
            self.logger.info('Epoch[{}] Iteration[{}/{}] Loss: {:.3f},'
                             'Acc: {:.3f}, Base Lr: {:.2e}'.format(
                                 self.train_epoch, self.batch_cnt,
                                 len(self.train_dl), self.loss_avg.avg,
                                 self.acc_avg.avg,
                                 self.scheduler.get_lr()[0]))

    def handle_new_epoch(self):

        self.batch_cnt = 1

        lr = self.scheduler.get_lr()[0]
        self.logger.info('Epoch {} done'.format(self.train_epoch))
        self.logger.info('-' * 20)

        torch.save(
            self.model.state_dict(),
            osp.join(self.output_dir, self.cfg.MODEL.NAME + '_epoch_last.pth'))
        torch.save(
            self.optim.state_dict(),
            osp.join(self.output_dir,
                     self.cfg.MODEL.NAME + '_epoch_last_optim.pth'))

        if self.train_epoch > self.cfg.SOLVER.START_SAVE_EPOCH and self.train_epoch % self.checkpoint_period == 0:
            self.save()
        if (self.train_epoch > 0 and self.train_epoch % self.eval_period
                == 0) or self.train_epoch == 50:
            self.evaluate()
        self.scheduler.step()
        self.train_epoch += 1

    def step(self, batch):
        self.model.train()
        self.optim.zero_grad()
        img, target, histlabels = batch
        img, target, histlabels = img.cuda(), target.cuda(), histlabels.cuda()
        outputs = self.model(img)

        loss, tpl, ce, hlce = self.loss_func(outputs,
                                             target,
                                             histlabels,
                                             in_detail=True)

        if self.current_iteration % self.cfg.SOLVER.TENSORBOARD.LOG_PERIOD == 0:
            if self.summary_writer:
                self.summary_writer.add_scalar('Train/tpl', tpl,
                                               self.current_iteration)
                self.summary_writer.add_scalar('Train/ce', ce,
                                               self.current_iteration)
                self.summary_writer.add_scalar('Train/hlce', hlce,
                                               self.current_iteration)

        if self.mix_precision:
            with amp.scale_loss(loss, self.optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optim.step()

        # acc = (score.max(1)[1] == target).float().mean()
        acc = calculate_acc(self.cfg, outputs, target)

        self.loss_avg.update(loss.cpu().item())
        self.acc_avg.update(acc.cpu().item())

        return self.loss_avg.avg, self.acc_avg.avg

    def evaluate(self):
        self.model.eval()
        num_query = self.num_query
        feats, pids, camids = [], [], []
        histlabels = []
        histpreds = []
        with torch.no_grad():
            for batch in tqdm(self.val_dl, total=len(self.val_dl),
                              leave=False):
                data, pid, camid, _, histlabel = batch
                data = data.cuda()
                # histlabel = histlabel.cuda()

                # ff = torch.FloatTensor(data.size(0), 2048).zero_()
                # for i in range(2):
                #     if i == 1:
                #         data = data.index_select(3, torch.arange(data.size(3) - 1, -1, -1).long().to('cuda'))
                #     outputs = self.model(data)
                #     f = outputs.data.cpu()
                #     ff = ff + f

                ff, histpred = self.model(data,
                                          output_feature='with_histlabel')
                ff = ff.data.cpu()
                histpred = histpred.data.cpu()
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))

                feats.append(ff)
                pids.append(pid)
                camids.append(camid)
                histlabels.append(histlabel)
                histpreds.append(histpred)
        feats = torch.cat(feats, dim=0)
        pids = torch.cat(pids, dim=0)
        camids = torch.cat(camids, dim=0)
        histpreds = torch.cat(histpreds, dim=0)
        histlabels = torch.cat(histlabels, dim=0)

        hist_acc = (histpreds[:histlabels.size()[0]].max(1)[1] == histlabels
                    ).float().mean().item()

        if self.cfg.TEST.RANDOMPERM <= 0:
            query_feat = feats[:num_query]
            query_pid = pids[:num_query]
            query_camid = camids[:num_query]

            gallery_feat = feats[num_query:]
            gallery_pid = pids[num_query:]
            gallery_camid = camids[num_query:]

            distmat = euclidean_dist(query_feat, gallery_feat)

            cmc, mAP, _ = eval_func(
                distmat.numpy(),
                query_pid.numpy(),
                gallery_pid.numpy(),
                query_camid.numpy(),
                gallery_camid.numpy(),
            )
        else:
            cmc = 0
            mAP = 0
            seed = torch.random.get_rng_state()
            torch.manual_seed(0)
            for i in range(self.cfg.TEST.RANDOMPERM):
                index = torch.randperm(feats.size()[0])
                # print(index[:10])
                query_feat = feats[index][:num_query]
                query_pid = pids[index][:num_query]
                query_camid = camids[index][:num_query]

                gallery_feat = feats[index][num_query:]
                gallery_pid = pids[index][num_query:]
                gallery_camid = camids[index][num_query:]

                distmat = euclidean_dist(query_feat, gallery_feat)

                _cmc, _mAP, _ = eval_func(
                    distmat.numpy(),
                    query_pid.numpy(),
                    gallery_pid.numpy(),
                    query_camid.numpy(),
                    gallery_camid.numpy(),
                )
                cmc += _cmc / self.cfg.TEST.RANDOMPERM
                mAP += _mAP / self.cfg.TEST.RANDOMPERM
            torch.random.set_rng_state(seed)

        self.logger.info('Validation Result:')
        self.logger.info('hist acc:{:.2%}'.format(hist_acc))
        self.logger.info('mAP: {:.2%}'.format(mAP))
        for r in self.cfg.TEST.CMC:
            self.logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))

        self.logger.info('average of mAP and rank1: {:.2%}'.format(
            (mAP + cmc[0]) / 2.0))
        self.logger.info('-' * 20)

        if self.summary_writer:
            self.summary_writer.add_scalar('Valid/hist_acc', hist_acc,
                                           self.train_epoch)
            self.summary_writer.add_scalar('Valid/rank1', cmc[0],
                                           self.train_epoch)
            self.summary_writer.add_scalar('Valid/mAP', mAP, self.train_epoch)
            self.summary_writer.add_scalar('Valid/rank1_mAP',
                                           (mAP + cmc[0]) / 2.0,
                                           self.train_epoch)

    def save(self):
        torch.save(
            self.model.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '.pth'))
        torch.save(
            self.optim.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '_optim.pth'))
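
Each evaluate() method in these trainers follows the same retrieval protocol: L2-normalise the extracted features, split them into query and gallery parts at num_query, build a Euclidean distance matrix, and hand it to eval_func for CMC/mAP scoring. Below is a condensed sketch of that pipeline with the distance written out explicitly; in the repositories the distance (euclidean_dist) and the ranking metrics come from their own utility modules.

import torch

def l2_normalize(feats):
    # Divide each feature vector by its L2 norm, as done batch-by-batch
    # in the evaluate() methods above.
    return feats / feats.norm(p=2, dim=1, keepdim=True)

def pairwise_euclidean(query, gallery):
    # ||q - g||^2 = ||q||^2 + ||g||^2 - 2 q.g, computed for all pairs.
    q2 = query.pow(2).sum(dim=1, keepdim=True)          # (num_query, 1)
    g2 = gallery.pow(2).sum(dim=1, keepdim=True).t()    # (1, num_gallery)
    dist = q2 + g2 - 2.0 * query @ gallery.t()
    return dist.clamp(min=0).sqrt()

# Usage: split normalised features into query/gallery before scoring.
feats = l2_normalize(torch.randn(10, 2048))
num_query = 4
distmat = pairwise_euclidean(feats[:num_query], feats[num_query:])
print(distmat.shape)  # torch.Size([4, 6])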
Code example #3
class ExemplarMemoryTrainer(object):
    def __init__(self, cfg, model, train_dl, val_dl, exemplar_dl, loss_func,
                 num_query, num_gpus):
        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.exemplar_dl = exemplar_dl
        self.loss_func = loss_func
        self.num_query = num_query

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.train_epoch = 1
        self.batch_cnt = 0

        self.logger = logging.getLogger('reid_baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR
        self.device = cfg.MODEL.DEVICE
        self.epochs = cfg.SOLVER.MAX_EPOCHS

        self.model.cuda()
        self.logger.info(self.model)
        # ex memory
        self.exemplar_memory = ExemplarMemoryLoss(
            cfg.DATASETS.EXEMPLAR.MEMORY.NUM_FEATS,
            len(exemplar_dl.dataset),
            beta=cfg.DATASETS.EXEMPLAR.MEMORY.BETA,
            knn=cfg.DATASETS.EXEMPLAR.MEMORY.KNN,
            alpha=cfg.DATASETS.EXEMPLAR.MEMORY.ALPHA,
            knn_start_epoch=cfg.DATASETS.EXEMPLAR.MEMORY.KNN_START_EPOCH)
        self.exemplar_memory.cuda()
        self.logger.info(self.exemplar_memory)
        # Target iter
        self.exemplar_iter = iter(exemplar_dl)

        if num_gpus > 1:

            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.optim)

            self.mix_precision = (cfg.MODEL.OPT_LEVEL != "O0")
            if self.mix_precision:
                self.model, self.optim = amp.initialize(
                    self.model, self.optim, opt_level=cfg.MODEL.OPT_LEVEL)
                self.logger.info(
                    'Using apex for mix_precision with opt_level {}'.format(
                        cfg.MODEL.OPT_LEVEL))

            self.model = nn.DataParallel(self.model)
            if cfg.SOLVER.SYNCBN:
                if self.mix_precision:
                    self.model = apex.parallel.convert_syncbn_model(self.model)
                    self.logger.info(
                        'More than one gpu used, convert model to use SyncBN.')
                    self.logger.info('Using apex SyncBN implementation')
                else:
                    self.model = convert_model(self.model)
                    self.logger.info(
                        'More than one gpu used, convert model to use SyncBN.')
                    self.logger.info('Using pytorch SyncBN implementation')

            self.logger.info('Trainer Built')

            return

        else:

            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.optim)

            self.mix_precision = (cfg.MODEL.OPT_LEVEL != "O0")
            if self.mix_precision:
                self.model, self.optim = amp.initialize(
                    self.model, self.optim, opt_level=cfg.MODEL.OPT_LEVEL)
                self.logger.info(
                    'Using apex for mix_precision with opt_level {}'.format(
                        cfg.MODEL.OPT_LEVEL))

            return

    def handle_new_batch(self):
        self.batch_cnt += 1
        if self.batch_cnt % self.cfg.SOLVER.LOG_PERIOD == 0:
            self.logger.info('Epoch[{}] Iteration[{}/{}] Loss: {:.3f},'
                             'Acc: {:.3f}, Base Lr: {:.2e}'.format(
                                 self.train_epoch, self.batch_cnt,
                                 len(self.train_dl), self.loss_avg.avg,
                                 self.acc_avg.avg,
                                 self.scheduler.get_lr()[0]))

    def handle_new_epoch(self):

        self.batch_cnt = 1

        lr = self.scheduler.get_lr()[0]
        self.logger.info('Epoch {} done'.format(self.train_epoch))
        self.logger.info('-' * 20)

        torch.save(
            self.model.state_dict(),
            osp.join(self.output_dir, self.cfg.MODEL.NAME + '_epoch_last.pth'))
        torch.save(
            self.optim.state_dict(),
            osp.join(self.output_dir,
                     self.cfg.MODEL.NAME + '_epoch_last_optim.pth'))

        if self.train_epoch > self.cfg.SOLVER.START_SAVE_EPOCH and self.train_epoch % self.checkpoint_period == 0:
            self.save()
        if (self.train_epoch > 0 and self.train_epoch % self.eval_period
                == 0) or self.train_epoch == 50:
            self.evaluate()
        self.scheduler.step()
        self.train_epoch += 1

    def step(self, batch):
        self.model.train()
        self.optim.zero_grad()
        img, target = batch
        img, target = img.cuda(), target.cuda()
        #
        # Target inputs
        try:
            inputs_exemplar = next(self.exemplar_iter)
        except StopIteration:
            self.exemplar_iter = iter(self.exemplar_dl)
            inputs_exemplar = next(self.exemplar_iter)

        img_exemplar, target_exemplar = inputs_exemplar
        img_exemplar, target_exemplar = img_exemplar.cuda(
        ), target_exemplar.cuda()
        # source
        outputs = self.model(img)
        loss = self.loss_func(outputs, target)
        #
        exemplar_outputs = self.model(img_exemplar, 'exemplar_feat')
        loss_un = self.exemplar_memory(exemplar_outputs,
                                       target_exemplar,
                                       epoch=self.train_epoch)

        loss = (1 - self.cfg.DATASETS.EXEMPLAR.MEMORY.LAMBDA
                ) * loss + self.cfg.DATASETS.EXEMPLAR.MEMORY.LAMBDA * loss_un

        if self.mix_precision:
            with amp.scale_loss(loss, self.optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optim.step()

        # acc = (score.max(1)[1] == target).float().mean()
        acc = calculate_acc(self.cfg, outputs, target)

        self.loss_avg.update(loss.cpu().item())
        self.acc_avg.update(acc.cpu().item())

        return self.loss_avg.avg, self.acc_avg.avg

    def evaluate(self):
        self.model.eval()
        num_query = self.num_query
        feats, pids, camids = [], [], []
        with torch.no_grad():
            for batch in tqdm(self.val_dl, total=len(self.val_dl),
                              leave=False):
                data, pid, camid, _ = batch
                data = data.cuda()

                # ff = torch.FloatTensor(data.size(0), 2048).zero_()
                # for i in range(2):
                #     if i == 1:
                #         data = data.index_select(3, torch.arange(data.size(3) - 1, -1, -1).long().to('cuda'))
                #     outputs = self.model(data)
                #     f = outputs.data.cpu()
                #     ff = ff + f

                ff = self.model(data).data.cpu()
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))

                feats.append(ff)
                pids.append(pid)
                camids.append(camid)
        feats = torch.cat(feats, dim=0)
        pids = torch.cat(pids, dim=0)
        camids = torch.cat(camids, dim=0)

        query_feat = feats[:num_query]
        query_pid = pids[:num_query]
        query_camid = camids[:num_query]

        gallery_feat = feats[num_query:]
        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        distmat = euclidean_dist(query_feat, gallery_feat)

        cmc, mAP, _ = eval_func(
            distmat.numpy(),
            query_pid.numpy(),
            gallery_pid.numpy(),
            query_camid.numpy(),
            gallery_camid.numpy(),
        )
        self.logger.info('Validation Result:')
        self.logger.info('mAP: {:.2%}'.format(mAP))
        for r in self.cfg.TEST.CMC:
            self.logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))

        self.logger.info('average of mAP and rank1: {:.2%}'.format(
            (mAP + cmc[0]) / 2.0))

        self.logger.info('-' * 20)

    def save(self):
        torch.save(
            self.model.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '.pth'))
        torch.save(
            self.optim.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '_optim.pth'))
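
ExemplarMemoryLoss is only instantiated and called in this snippet, so its implementation is not visible here. As a rough orientation, the sketch below shows the usual shape of an exemplar memory for unlabelled target data: a feature bank with one slot per exemplar image, a non-parametric softmax over similarities to every slot, and a momentum update of the slots seen in the batch. The parameter names echo the constructor arguments above, but their meanings as used here (beta as softmax temperature, alpha as update rate) and the whole implementation are assumptions, not the project's code; the k-NN smoothing controlled by knn and knn_start_epoch is omitted entirely.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleExemplarMemory(nn.Module):
    # Hypothetical sketch of an exemplar memory: one feature slot per
    # unlabelled image, classified non-parametrically against all slots.
    def __init__(self, num_feats, num_samples, beta=0.05, alpha=0.01):
        super().__init__()
        self.beta = beta      # softmax temperature (assumed meaning)
        self.alpha = alpha    # memory update momentum (assumed meaning)
        self.register_buffer('memory', torch.zeros(num_samples, num_feats))

    def forward(self, feats, indices):
        feats = F.normalize(feats, dim=1)
        # Similarity of each batch feature to every memory slot.
        logits = feats @ self.memory.t() / self.beta
        loss = F.cross_entropy(logits, indices)
        # Momentum update of the slots belonging to this batch.
        with torch.no_grad():
            updated = (1 - self.alpha) * self.memory[indices] + self.alpha * feats
            self.memory[indices] = F.normalize(updated, dim=1)
        return loss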
Code example #4
File: sgd_trainer.py  Project: SZLSP/naic_reid_2019
class SGDTrainer(object):
    def __init__(self, cfg, model, train_dl, val_dl, loss_func, num_query,
                 num_gpus):
        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.loss_func = loss_func
        self.num_query = num_query

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.train_epoch = 1
        self.batch_cnt = 0

        self.logger = logging.getLogger('reid_baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR
        self.device = cfg.MODEL.DEVICE
        self.epochs = cfg.SOLVER.MAX_EPOCHS

        if num_gpus > 1:
            # convert to use sync_bn
            self.logger.info(
                'More than one gpu used, convert model to use SyncBN.')
            # Multi-GPU model without FP16
            self.model = nn.DataParallel(self.model)
            self.model = convert_model(self.model)
            self.model.cuda()
            self.logger.info('Using pytorch SyncBN implementation')

            self.scheduler = LRScheduler(
                base_lr=cfg.SOLVER.BASE_LR,
                step=cfg.SOLVER.STEPS,
                factor=cfg.SOLVER.GAMMA,
                warmup_epoch=cfg.SOLVER.WARMUP_EPOCH,
                warmup_begin_lr=cfg.SOLVER.WARMUP_BEGAIN_LR,
                warmup_mode=cfg.SOLVER.WARMUP_METHOD)
            lr = self.scheduler.update(0)
            self.optim = optim.SGD(self.model.parameters(),
                                   lr=lr,
                                   weight_decay=5e-4,
                                   momentum=0.9,
                                   nesterov=True)
            # self.optim = make_optimizer(self.model,opt=self.cfg.SOLVER.OPTIMIZER_NAME,lr=lr,weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,momentum=0.9)

            self.mix_precision = False
            self.logger.info(self.optim)
            self.logger.info('Trainer Built')
            return

        else:
            # Single GPU model
            self.model.cuda()

            self.scheduler = LRScheduler(
                base_lr=cfg.SOLVER.BASE_LR,
                step=cfg.SOLVER.STEPS,
                factor=cfg.SOLVER.GAMMA,
                warmup_epoch=cfg.SOLVER.WARMUP_EPOCH,
                warmup_begin_lr=cfg.SOLVER.WARMUP_BEGAIN_LR,
                warmup_mode=cfg.SOLVER.WARMUP_METHOD)
            lr = self.scheduler.update(0)
            self.optim = optim.SGD(self.model.parameters(),
                                   lr=lr,
                                   weight_decay=5e-4,
                                   momentum=0.9,
                                   nesterov=True)
            # self.optim = make_optimizer(self.model,opt=self.cfg.SOLVER.OPTIMIZER_NAME,lr=lr,weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,momentum=0.9)
            self.logger.info(self.optim)
            self.mix_precision = False
            return

    def handle_new_batch(self):
        self.batch_cnt += 1
        if self.batch_cnt % self.cfg.SOLVER.LOG_PERIOD == 0:
            self.logger.info('Epoch[{}] Iteration[{}/{}] Loss: {:.3f},'
                             'Acc: {:.3f}, Base Lr: {:.2e}'.format(
                                 self.train_epoch, self.batch_cnt,
                                 len(self.train_dl), self.loss_avg.avg,
                                 self.acc_avg.avg,
                                 self.scheduler.learning_rate))

    def handle_new_epoch(self):

        self.batch_cnt = 1
        lr = self.scheduler.update(self.train_epoch)
        self.optim = optim.SGD(self.model.parameters(),
                               lr=lr,
                               weight_decay=5e-4,
                               momentum=0.9,
                               nesterov=True)
        # self.optim = make_optimizer(self.model,opt=self.cfg.SOLVER.OPTIMIZER_NAME,lr=lr,weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,momentum=0.9)
        for param_group in self.optim.param_groups:
            param_group['lr'] = lr
        self.logger.info('Epoch {} done'.format(self.train_epoch))
        self.logger.info('-' * 20)

        torch.save(
            self.model.state_dict(),
            osp.join(self.output_dir, self.cfg.MODEL.NAME + '_epoch_last.pth'))
        torch.save(
            self.optim.state_dict(),
            osp.join(self.output_dir,
                     self.cfg.MODEL.NAME + '_epoch_last_optim.pth'))

        if self.train_epoch > self.cfg.SOLVER.START_SAVE_EPOCH and self.train_epoch % self.checkpoint_period == 0:
            self.save()
        if (self.train_epoch > 0 and self.train_epoch % self.eval_period
                == 0) or self.train_epoch == 50:
            self.evaluate()
        self.train_epoch += 1

    def step(self, batch):
        self.model.train()
        self.optim.zero_grad()
        img, target = batch
        img, target = img.cuda(), target.cuda()
        outputs = self.model(img)

        loss = self.loss_func(outputs, target)

        if self.mix_precision:
            with amp.scale_loss(loss, self.optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optim.step()

        # acc = (score.max(1)[1] == target).float().mean()
        acc = calculate_acc(self.cfg, outputs, target)

        self.loss_avg.update(loss.cpu().item())
        self.acc_avg.update(acc.cpu().item())

        return self.loss_avg.avg, self.acc_avg.avg

    def evaluate(self):
        self.model.eval()
        num_query = self.num_query
        feats, pids, camids = [], [], []
        with torch.no_grad():
            for batch in tqdm(self.val_dl, total=len(self.val_dl),
                              leave=False):
                data, pid, camid, _ = batch
                data = data.cuda()

                # ff = torch.FloatTensor(data.size(0), 2048).zero_()
                # for i in range(2):
                #     if i == 1:
                #         data = data.index_select(3, torch.arange(data.size(3) - 1, -1, -1).long().to('cuda'))
                #     outputs = self.model(data)
                #     f = outputs.data.cpu()
                #     ff = ff + f

                ff = self.model(data).data.cpu()
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))

                feats.append(ff)
                pids.append(pid)
                camids.append(camid)
        feats = torch.cat(feats, dim=0)
        pids = torch.cat(pids, dim=0)
        camids = torch.cat(camids, dim=0)

        query_feat = feats[:num_query]
        query_pid = pids[:num_query]
        query_camid = camids[:num_query]

        gallery_feat = feats[num_query:]
        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        distmat = euclidean_dist(query_feat, gallery_feat)

        cmc, mAP, _ = eval_func(
            distmat.numpy(),
            query_pid.numpy(),
            gallery_pid.numpy(),
            query_camid.numpy(),
            gallery_camid.numpy(),
        )
        self.logger.info('Validation Result:')
        self.logger.info('mAP: {:.2%}'.format(mAP))
        for r in self.cfg.TEST.CMC:
            self.logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))

        self.logger.info('average of mAP and rank1: {:.2%}'.format(
            (mAP + cmc[0]) / 2.0))

        self.logger.info('-' * 20)

    def save(self):
        torch.save(
            self.model.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '.pth'))
        torch.save(
            self.optim.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '_optim.pth'))
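
SGDTrainer manages the learning rate by hand rather than through a PyTorch scheduler: scheduler.update(epoch) returns the rate for the coming epoch, handle_new_batch() logs scheduler.learning_rate, and handle_new_epoch() writes the new value into the optimizer's param groups. The LRScheduler class itself is not part of this snippet; the sketch below is one plausible implementation of that interface (linear warmup followed by step decay), shown only to make the calls above concrete.

class SketchLRScheduler:
    # Plausible stand-in for the LRScheduler used by SGDTrainer:
    # linear warmup for `warmup_epoch` epochs, then multiply the base
    # rate by `factor` at every milestone in `step`.
    def __init__(self, base_lr, step, factor,
                 warmup_epoch=0, warmup_begin_lr=0.0):
        self.base_lr = base_lr
        self.step = sorted(step)
        self.factor = factor
        self.warmup_epoch = warmup_epoch
        self.warmup_begin_lr = warmup_begin_lr
        self.learning_rate = warmup_begin_lr

    def update(self, epoch):
        if epoch < self.warmup_epoch:
            # Interpolate linearly from the warmup start rate to base_lr.
            frac = epoch / max(1, self.warmup_epoch)
            self.learning_rate = (self.warmup_begin_lr +
                                  (self.base_lr - self.warmup_begin_lr) * frac)
        else:
            decays = sum(1 for m in self.step if epoch >= m)
            self.learning_rate = self.base_lr * (self.factor ** decays)
        return self.learning_rate

# Usage mirroring handle_new_epoch(): refresh the optimizer's lr each epoch.
sched = SketchLRScheduler(base_lr=0.01, step=[40, 70], factor=0.1,
                          warmup_epoch=10, warmup_begin_lr=1e-4)
for epoch in range(0, 80, 20):
    print(epoch, sched.update(epoch))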
Code example #5
File: trainer.py  Project: samsgood0310/Tricks
class BaseTrainer(object):
    def __init__(self, cfg, model, train_dl, val_dl, loss_func, num_gpus,
                 device):

        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.loss_func = loss_func

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.f1_avg = AvgerageMeter()

        self.val_loss_avg = AvgerageMeter()
        self.val_acc_avg = AvgerageMeter()
        self.device = device

        self.train_epoch = 1

        if cfg.SOLVER.USE_WARMUP:
            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR * 0.1,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
        else:
            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
        if cfg.SOLVER.RESUME:
            print("Resume from checkpoint...")
            checkpoint = torch.load(cfg.SOLVER.RESUME_CHECKPOINT)
            param_dict = checkpoint['model_state_dict']
            self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
            for state in self.optim.state.values():
                for k, v in state.items():
                    print(type(v))
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)
            self.train_epoch = checkpoint['epoch'] + 1
            for i in param_dict:
                if i.startswith("module"):
                    new_i = i[7:]
                else:
                    new_i = i
                if 'classifier' in i or 'fc' in i:
                    continue
                self.model.state_dict()[new_i].copy_(param_dict[i])

        self.batch_cnt = 0

        self.logger = logging.getLogger('baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR

        self.epochs = cfg.SOLVER.MAX_EPOCHS

        if cfg.SOLVER.TENSORBOARD.USE:
            summary_dir = os.path.join(cfg.OUTPUT_DIR, 'summaries/')
            os.makedirs(summary_dir, exist_ok=True)
            self.summary_writer = SummaryWriter(log_dir=summary_dir)
        else:
            # Keep the attribute defined so later `if self.summary_writer:`
            # checks do not raise AttributeError when TensorBoard is disabled.
            self.summary_writer = None
        self.current_iteration = 0

        self.logger.info(self.model)

        if self.cfg.SOLVER.USE_WARMUP:

            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optim, self.epochs, eta_min=cfg.SOLVER.MIN_LR)
            self.scheduler = GradualWarmupScheduler(
                self.optim,
                multiplier=10,
                total_epoch=cfg.SOLVER.WARMUP_EPOCH,
                after_scheduler=scheduler_cosine)
            # self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
            #                                cfg.SOLVER.WARMUP_EPOCH, cfg.SOLVER.WARMUP_METHOD)
        else:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optim, self.epochs, eta_min=cfg.SOLVER.MIN_LR)

        if num_gpus > 1:

            self.logger.info(self.optim)
            self.model = nn.DataParallel(self.model)
            if cfg.SOLVER.SYNCBN:
                self.model = convert_model(self.model)
                self.model = self.model.to(device)
                self.logger.info(
                    'More than one gpu used, convert model to use SyncBN.')
                self.logger.info('Using pytorch SyncBN implementation')
                self.logger.info(self.model)

            self.logger.info('Trainer Built')

            return

        else:
            self.model = self.model.to(device)
            self.logger.info('Cpu used.')
            self.logger.info(self.model)
            self.logger.info('Trainer Built')

            return

    def handle_new_batch(self):

        lr = self.scheduler.get_lr()[0]
        if self.current_iteration % self.cfg.SOLVER.TENSORBOARD.LOG_PERIOD == 0:
            if self.summary_writer:
                self.summary_writer.add_scalar('Train/lr', lr,
                                               self.current_iteration)
                self.summary_writer.add_scalar('Train/loss', self.loss_avg.avg,
                                               self.current_iteration)
                self.summary_writer.add_scalar('Train/acc', self.acc_avg.avg,
                                               self.current_iteration)
                self.summary_writer.add_scalar('Train/f1', self.f1_avg.avg,
                                               self.current_iteration)

        self.batch_cnt += 1
        self.current_iteration += 1
        if self.batch_cnt % self.cfg.SOLVER.LOG_PERIOD == 0:

            self.logger.info('Epoch[{}] Iteration[{}/{}] Loss: {:.3f},'
                             'acc: {:.3f}, f1: {:.3f}, Base Lr: {:.2e}'.format(
                                 self.train_epoch, self.batch_cnt,
                                 len(self.train_dl), self.loss_avg.avg,
                                 self.acc_avg.avg, self.f1_avg.avg, lr))

    def handle_new_epoch(self):

        self.batch_cnt = 1

        self.logger.info('Epoch {} done'.format(self.train_epoch))
        self.logger.info('-' * 20)
        checkpoint = {
            'epoch': self.train_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
        }
        torch.save(
            checkpoint,
            osp.join(self.output_dir, self.cfg.MODEL.NAME + '_epoch_last.pth'))
        # torch.save(self.optim.state_dict(), osp.join(self.output_dir,
        #                                              self.cfg.MODEL.NAME +"_k_f"+str(self.cfg.DATALOADER.VAL_FOLDER)+  '_epoch_last_optim.pth'))

        if self.train_epoch > self.cfg.SOLVER.START_SAVE_EPOCH and self.train_epoch % self.checkpoint_period == 0:
            self.save()
        if (self.train_epoch > 0 and self.train_epoch % self.eval_period
                == 0) or self.train_epoch == 1:
            self.evaluate()

        self.acc_avg.reset()
        self.f1_avg.reset()
        self.loss_avg.reset()
        self.val_loss_avg.reset()

        self.scheduler.step()
        self.train_epoch += 1

    def step(self, batch):
        self.model.train()
        self.optim.zero_grad()
        data, target = batch
        data, target = data.to(self.device), target.to(self.device)

        if self.cfg.INPUT.USE_MIX_UP:
            data, target_a, target_b, lam = mixup_data(data, target, 0.4, True)
        self.use_cut_mix = False
        if self.cfg.INPUT.USE_RICAP:
            I_x, I_y = data.size()[2:]

            w = int(
                np.round(I_x *
                         np.random.beta(args.ricap_beta, args.ricap_beta)))
            h = int(
                np.round(I_y *
                         np.random.beta(args.ricap_beta, args.ricap_beta)))
            w_ = [w, I_x - w, w, I_x - w]
            h_ = [h, h, I_y - h, I_y - h]

            cropped_images = {}
            c_ = {}
            W_ = {}
            for k in range(4):
                idx = torch.randperm(data.size(0))
                x_k = np.random.randint(0, I_x - w_[k] + 1)
                y_k = np.random.randint(0, I_y - h_[k] + 1)
                cropped_images[k] = data[idx][:, :, x_k:x_k + w_[k],
                                              y_k:y_k + h_[k]]
                c_[k] = target[idx].cuda()
                W_[k] = w_[k] * h_[k] / (I_x * I_y)

            patched_images = torch.cat((torch.cat(
                (cropped_images[0], cropped_images[1]),
                2), torch.cat((cropped_images[2], cropped_images[3]), 2)), 3)
            data = patched_images.to(self.device)

        if self.cfg.INPUT.USE_CUT_MIX:
            r = np.random.rand(1)
            if r < 0.5:
                self.use_cut_mix = True
                lam = np.random.beta(1.0, 1.0)
                rand_index = torch.randperm(data.size()[0]).cuda()
                target_a = target
                target_b = target[rand_index]
                bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
                data[:, :, bbx1:bbx2, bby1:bby2] = data[rand_index, :,
                                                        bbx1:bbx2, bby1:bby2]
                # adjust lambda to exactly match pixel ratio
                lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                           (data.size()[-1] * data.size()[-2]))
                # compute output
                data = torch.autograd.Variable(data, requires_grad=True)
                target_a_var = torch.autograd.Variable(target_a)
                target_b_var = torch.autograd.Variable(target_b)

        outputs = self.model(data)

        # loss = self.loss_func(outputs, target)
        if self.cfg.INPUT.USE_RICAP:
            loss = sum(
                [W_[k] * self.loss_func(outputs, c_[k]) for k in range(4)])
        elif self.cfg.INPUT.USE_MIX_UP:
            loss1 = self.loss_func(outputs, target_a)
            loss2 = self.loss_func(outputs, target_b)
            loss = lam * loss1 + (1 - lam) * loss2
        elif self.cfg.INPUT.USE_CUT_MIX and self.use_cut_mix:
            loss1 = self.loss_func(outputs, target_a_var)
            loss2 = self.loss_func(outputs, target_b_var)
            loss = lam * loss1 + (1 - lam) * loss2
        else:
            loss = self.loss_func(outputs, target)

        if self.current_iteration % self.cfg.SOLVER.TENSORBOARD.LOG_PERIOD == 0:
            if self.summary_writer:
                self.summary_writer.add_scalar('Train/loss', loss,
                                               self.current_iteration)
        loss.backward()
        self.optim.step()

        if isinstance(outputs, tuple) and len(outputs) > 1:
            _output = outputs[0]
            for output in outputs[1:]:
                _output = _output + output
            outputs = _output / len(outputs)

        target = target.data.cpu()
        outputs = outputs.data.cpu()

        f1, acc = calculate_score(self.cfg, outputs, target)

        self.loss_avg.update(loss.cpu().item())
        self.acc_avg.update(acc)
        self.f1_avg.update(f1)

        return self.loss_avg.avg, self.acc_avg.avg, self.f1_avg.avg

    def evaluate(self):
        self.model.eval()
        print(len(self.val_dl))

        with torch.no_grad():

            all_outputs = list()
            all_targets = list()

            for batch in tqdm(self.val_dl, total=len(self.val_dl),
                              leave=False):
                data, target = batch
                data = data.to(self.device)
                target = target.to(self.device)
                outputs = self.model(data)
                loss = self.loss_func(outputs, target)
                if isinstance(outputs, tuple) and len(outputs) > 1:
                    _output = outputs[0]
                    for output in outputs[1:]:
                        _output = _output + output
                    outputs = _output / len(outputs)
                target = target.data.cpu()
                outputs = outputs.data.cpu()

                self.val_loss_avg.update(loss.cpu().item())

                all_outputs.append(outputs)
                all_targets.append(target)

            all_outputs = torch.cat(all_outputs, 0)
            all_targets = torch.cat(all_targets, 0)

        val_f1, val_acc = calculate_score(self.cfg, all_outputs, all_targets)

        self.logger.info('Validation Result:')

        self.logger.info('VAL_LOSS: %s, VAL_ACC: %s VAL_F1: %s \n' %
                         (self.val_loss_avg.avg, val_acc, val_f1))

        self.logger.info('-' * 20)

        if self.summary_writer:

            self.summary_writer.add_scalar('Valid/loss', self.val_loss_avg.avg,
                                           self.train_epoch)
            self.summary_writer.add_scalar('Valid/acc', np.mean(val_acc),
                                           self.train_epoch)
            self.summary_writer.add_scalar('Valid/f1', np.mean(val_f1),
                                           self.train_epoch)

    def save(self):
        torch.save(
            self.model.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '.pth'))
        torch.save(
            self.optim.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '_optim.pth'))
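
The cut-mix branch of step() calls a rand_bbox helper that is not included in this snippet. The version below follows the formulation from the CutMix paper, cutting a patch whose area is roughly a (1 - lam) fraction of the image around a uniformly random centre; it is a sketch under that assumption rather than the project's own helper. With these coordinates, the lam adjustment in step() (one minus the pasted pixel ratio) recovers the effective mixing weight used for the two loss terms.

import numpy as np

def rand_bbox(size, lam):
    # `size` is the NCHW shape of the batch; the patch covers roughly a
    # (1 - lam) fraction of the image area.
    W, H = size[2], size[3]
    cut_rat = np.sqrt(1.0 - lam)
    cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)

    # Uniformly random patch centre, clipped to the image borders.
    cx, cy = np.random.randint(W), np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2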
Code example #6
class BaseTrainer(object):
    def __init__(self, cfg, model, train_dl, val_dl, loss_func, num_query,
                 num_gpus):
        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.loss_func = loss_func
        self.num_query = num_query

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.train_epoch = 1
        self.batch_cnt = 0

        self.logger = logging.getLogger('reid_baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR
        self.device = cfg.MODEL.DEVICE
        self.epochs = cfg.SOLVER.MAX_EPOCHS

        if num_gpus > 1:
            # convert to use sync_bn
            self.logger.info(
                'More than one gpu used, convert model to use SyncBN.')
            if cfg.SOLVER.FP16:
                # TODO: Multi-GPU model with FP16
                raise NotImplementedError
                self.logger.info(
                    'Using apex to perform SyncBN and FP16 training')
                torch.distributed.init_process_group(backend='nccl',
                                                     init_method='env://')
                self.model = apex.parallel.convert_syncbn_model(self.model)
            else:
                # Multi-GPU model without FP16
                self.model = nn.DataParallel(self.model)
                self.model = convert_model(self.model)
                self.model.cuda()
                self.logger.info('Using pytorch SyncBN implementation')

                self.optim = make_optimizer(cfg, self.model, num_gpus)
                self.scheduler = WarmupMultiStepLR(self.optim,
                                                   cfg.SOLVER.STEPS,
                                                   cfg.SOLVER.GAMMA,
                                                   cfg.SOLVER.WARMUP_FACTOR,
                                                   cfg.SOLVER.WARMUP_ITERS,
                                                   cfg.SOLVER.WARMUP_METHOD)
                self.scheduler.step()
                self.mix_precision = False
                self.logger.info('Trainer Built')
                return
        else:
            # Single GPU model
            self.model.cuda()
            self.optim = make_optimizer(cfg, self.model, num_gpus)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_ITERS,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.scheduler.step()
            self.mix_precision = False
            if cfg.SOLVER.FP16:
                # Single model using FP16
                self.model, self.optim = amp.initialize(self.model,
                                                        self.optim,
                                                        opt_level='O1')
                self.mix_precision = True
                self.logger.info('Using fp16 training')
            self.logger.info('Trainer Built')
            return

        # TODO: Multi-GPU model with FP16
        raise NotImplementedError
        self.model.to(self.device)
        self.optim = make_optimizer(cfg, self.model, num_gpus)
        self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                           cfg.SOLVER.GAMMA,
                                           cfg.SOLVER.WARMUP_FACTOR,
                                           cfg.SOLVER.WARMUP_ITERS,
                                           cfg.SOLVER.WARMUP_METHOD)
        self.scheduler.step()

        self.model, self.optim = amp.initialize(self.model,
                                                self.optim,
                                                opt_level='O1')
        self.mix_precision = True
        self.logger.info('Using fp16 training')

        self.model = DDP(self.model, delay_allreduce=True)
        self.logger.info('Convert model using apex')
        self.logger.info('Trainer Built')

    def handle_new_batch(self):
        self.batch_cnt += 1
        if self.batch_cnt % self.cfg.SOLVER.LOG_PERIOD == 0:
            self.logger.info('Epoch[{}] Iteration[{}/{}] Loss: {:.3f},'
                             'Acc: {:.3f}, Base Lr: {:.2e}'.format(
                                 self.train_epoch, self.batch_cnt,
                                 len(self.train_dl), self.loss_avg.avg,
                                 self.acc_avg.avg,
                                 self.scheduler.get_lr()[0]))

    def handle_new_epoch(self):
        self.batch_cnt = 1
        self.scheduler.step()
        self.logger.info('Epoch {} done'.format(self.train_epoch))
        self.logger.info('-' * 20)
        if self.train_epoch % self.checkpoint_period == 0:
            self.save()
        if self.train_epoch % self.eval_period == 0:
            self.evaluate()
        self.train_epoch += 1

    def step(self, batch):
        self.model.train()
        self.optim.zero_grad()
        img, target = batch
        img, target = img.cuda(), target.cuda()
        score, feat = self.model(img)
        loss = self.loss_func(score, feat, target)
        if self.mix_precision:
            with amp.scale_loss(loss, self.optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optim.step()

        acc = (score.max(1)[1] == target).float().mean()

        self.loss_avg.update(loss.cpu().item())
        self.acc_avg.update(acc.cpu().item())

        return self.loss_avg.avg, self.acc_avg.avg

    def evaluate(self):
        self.model.eval()
        num_query = self.num_query
        feats, pids, camids = [], [], []
        with torch.no_grad():
            for batch in tqdm(self.val_dl, total=len(self.val_dl),
                              leave=False):
                data, pid, camid, _ = batch
                data = data.cuda()
                feat = self.model(data).detach().cpu()
                feats.append(feat)
                pids.append(pid)
                camids.append(camid)
        feats = torch.cat(feats, dim=0)
        pids = torch.cat(pids, dim=0)
        camids = torch.cat(camids, dim=0)

        query_feat = feats[:num_query]
        query_pid = pids[:num_query]
        query_camid = camids[:num_query]

        gallery_feat = feats[num_query:]
        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        distmat = euclidean_dist(query_feat, gallery_feat)

        cmc, mAP, _ = eval_func(distmat.numpy(),
                                query_pid.numpy(),
                                gallery_pid.numpy(),
                                query_camid.numpy(),
                                gallery_camid.numpy(),
                                use_cython=self.cfg.SOLVER.CYTHON)
        self.logger.info('Validation Result:')
        for r in self.cfg.TEST.CMC:
            self.logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
        self.logger.info('mAP: {:.2%}'.format(mAP))
        self.logger.info('-' * 20)

    def save(self):
        torch.save(
            self.model.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '.pth'))
        torch.save(
            self.optim.state_dict(),
            osp.join(
                self.output_dir, self.cfg.MODEL.NAME + '_epoch' +
                str(self.train_epoch) + '_optim.pth'))
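
All six trainers expose the same lifecycle: step(batch) performs one optimisation step and returns running averages, handle_new_batch() does per-iteration logging, and handle_new_epoch() checkpoints, evaluates and advances the epoch counter. A driver loop along the following lines is assumed by these classes; the build_* helpers in the usage comment are placeholders for whatever the surrounding project uses to construct the config, model, data loaders and loss, not functions from the repositories above.

# Hypothetical driver loop for any of the trainers above.
def train(trainer):
    for _ in range(trainer.train_epoch, trainer.epochs + 1):
        for batch in trainer.train_dl:
            trainer.step(batch)          # one optimisation step, returns running loss/acc
            trainer.handle_new_batch()   # periodic iteration logging
        trainer.handle_new_epoch()       # checkpoint, evaluate, advance scheduler/epoch

# Usage sketch (build_* are placeholders):
# cfg = build_cfg(); model = build_model(cfg)
# train_dl, val_dl, num_query = build_loaders(cfg)
# trainer = BaseTrainer(cfg, model, train_dl, val_dl, build_loss(cfg),
#                       num_query, num_gpus=1)
# train(trainer)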