Example 1
    def __init__(self,
                 config,
                 model,
                 criterion,
                 train_loader,
                 validate_loader,
                 post_process=None):
        super(Trainer, self).__init__(config, model, criterion)
        self.show_images_iter = self.config['trainer']['show_images_iter']
        self.train_loader = train_loader
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.train_loader_len = len(train_loader)
        if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
            warmup_iters = config['lr_scheduler']['args'][
                'warmup_epoch'] * self.train_loader_len
            if self.start_epoch > 1:
                self.config['lr_scheduler']['args']['last_epoch'] = (
                    self.start_epoch - 1) * self.train_loader_len
            self.scheduler = WarmupPolyLR(self.optimizer,
                                          max_iters=self.epochs *
                                          self.train_loader_len,
                                          warmup_iters=warmup_iters,
                                          **config['lr_scheduler']['args'])
        if self.validate_loader is not None:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'
                .format(len(self.train_loader.dataset), self.train_loader_len,
                        len(self.validate_loader.dataset),
                        len(self.validate_loader)))
        else:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader'.format(
                    len(self.train_loader.dataset), self.train_loader_len))
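
Every example here builds a WarmupPolyLR scheduler from the config, but the class itself is not part of these snippets. Below is a minimal sketch of what such a scheduler could look like, assuming linear warmup followed by polynomial decay toward a target_lr; extra config keys such as warmup_epoch are simply swallowed. The real implementation used by these projects may differ.

from torch.optim.lr_scheduler import _LRScheduler


class WarmupPolyLR(_LRScheduler):
    """Sketch: linear warmup, then polynomial decay toward target_lr."""

    def __init__(self, optimizer, max_iters, warmup_iters=0,
                 warmup_factor=1.0 / 3, target_lr=0.0, power=0.9,
                 last_epoch=-1, **kwargs):
        # **kwargs absorbs extra config entries (e.g. warmup_epoch) that the
        # trainers above pass through **config['lr_scheduler']['args']
        self.max_iters = max_iters
        self.warmup_iters = warmup_iters
        self.warmup_factor = warmup_factor
        self.target_lr = target_lr
        self.power = power
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_iters:
            # ramp linearly from base_lr * warmup_factor up to base_lr
            alpha = step / max(1, self.warmup_iters)
            factor = self.warmup_factor * (1 - alpha) + alpha
            return [base_lr * factor for base_lr in self.base_lrs]
        # polynomial decay from base_lr down to target_lr
        remaining = max(0, self.max_iters - step)
        total = max(1, self.max_iters - self.warmup_iters)
        factor = (remaining / total) ** self.power
        return [self.target_lr + (base_lr - self.target_lr) * factor
                for base_lr in self.base_lrs]
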
Example 2
    def __init__(self, config, model, criterion, train_loader, validate_loader, metric_cls, logger=None,
                 post_process=None):
        super(Trainer, self).__init__(config, model, criterion, logger)
        self.index = 0
        self.show_images_iter = self.config['trainer']['show_images_iter']
        self.train_loader = train_loader
        if validate_loader is not None:
            assert post_process is not None and metric_cls is not None
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.metric_cls = metric_cls
        self.train_loader_len = len(train_loader)
        if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
            warmup_iters = config['lr_scheduler']['args']['warmup_epoch'] * self.train_loader_len
            if self.start_epoch > 1:
                self.config['lr_scheduler']['args']['last_epoch'] = (self.start_epoch - 1) * self.train_loader_len
            self.scheduler = WarmupPolyLR(self.optimizer, max_iters=self.epochs * self.train_loader_len,
                                          warmup_iters=warmup_iters, **config['lr_scheduler']['args'])
        if self.validate_loader is not None:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'.format(
                    len(self.train_loader.dataset), self.train_loader_len, len(self.validate_loader.dataset), len(self.validate_loader)))
        else:
            self.logger_info('train dataset has {} samples,{} in dataloader'.format(len(self.train_loader.dataset), self.train_loader_len))

        if self.config['nni']['flag']:
            dummy_input = next(iter(self.train_loader))
            dummy_input = dummy_input['img'].to(self.device)
            self.pruner = self._initialize('nni', nni.compression.torch, self.model, optimizer=self.optimizer,
                                           dependency_aware=False,
                                           dummy_input=dummy_input)
            # dummy_input = next(iter(self.train_loader))
            # dummy_input = dummy_input['img'].to(self.device)
            # self.pruner = create_pruner(self.model, 'level', self.optimizer, False, dummy_input)
            self.model = self.pruner.compress()
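
The trainers above read a handful of nested config keys (trainer.show_images_iter, lr_scheduler.type and lr_scheduler.args, optimizer.args.lr, and in this example an nni pruning flag). A hypothetical config fragment consistent with those reads is shown below; the concrete values and any extra keys are assumptions, not taken from the original projects.

# Hypothetical config fragment; values are illustrative only.
config = {
    'trainer': {
        'epochs': 100,
        'log_iter': 10,
        'show_images_iter': 50,
    },
    'optimizer': {
        'type': 'Adam',
        'args': {'lr': 0.001},
    },
    'lr_scheduler': {
        'type': 'WarmupPolyLR',
        'args': {'warmup_epoch': 3},
    },
    'nni': {'flag': False},  # example 2 only builds a pruner when this is True
}
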
Example 3
    def __init__(self,
                 config,
                 model,
                 criterion,
                 train_loader,
                 validate_loader,
                 metric,
                 converter,
                 post_process=None):
        super(RecTrainer, self).__init__(config, model, criterion)
        self.show_images_iter = self.config['trainer']['show_images_iter']
        self.train_loader = train_loader
        if validate_loader is not None:
            assert post_process is not None and metric is not None
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.metric = metric
        self.train_loader_len = len(train_loader)
        # model.preprocess.TPS_SpatialTransformerNetwork: use a separate learning rate for the TPS module
        if config.arch.get('preprocess', False) and config.arch.preprocess.get(
                'lr_scale', False):
            self.loc_lr = config.arch.preprocess.get('lr_scale') * self.lr
            params_list = filter_params_assign_lr(model,
                                                  {'preprocess': self.loc_lr})
            self.optimizer = self._initialize('optimizer', torch.optim,
                                              params_list)
        if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
            warmup_iters = int(config['lr_scheduler']['args']['warmup_epoch'] *
                               self.train_loader_len)
            if self.start_epoch > 1:
                self.config['lr_scheduler']['args']['last_epoch'] = (
                    self.start_epoch - 1) * self.train_loader_len
            self.scheduler = WarmupPolyLR(self.optimizer,
                                          max_iters=self.epochs *
                                          self.train_loader_len,
                                          warmup_iters=warmup_iters,
                                          **config['lr_scheduler']['args'])
        if self.validate_loader is not None:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'
                .format(len(self.train_loader.dataset), self.train_loader_len,
                        len(self.validate_loader.dataset),
                        len(self.validate_loader)))
        else:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader'.format(
                    len(self.train_loader.dataset), self.train_loader_len))

        self.converter = converter
        self.best_acc = 0
Example 4
    def __init__(self,
                 config,
                 model,
                 criterion,
                 train_loader,
                 weights_init=None):
        super(Trainer, self).__init__(config, model, criterion, weights_init)
        self.show_images_interval = self.config['trainer'][
            'show_images_interval']
        self.test_path = self.config['data_loader']['args']['dataset'][
            'val_data_path']
        self.train_loader = train_loader
        self.train_loader_len = len(train_loader)
        if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
            base_lr = config['optimizer']['args']['lr']
            self.scheduler = WarmupPolyLR(self.optimizer,
                                          target_lr=base_lr * 1e-2,
                                          max_iters=self.epochs *
                                          self.train_loader_len)

        self.logger.info(
            'train dataset has {} samples,{} in dataloader'.format(
                len(self.train_loader.dataset), self.train_loader_len))
Example 5
class RecTrainer(BaseTrainer):
    def __init__(self,
                 config,
                 model,
                 criterion,
                 train_loader,
                 validate_loader,
                 metric,
                 converter,
                 post_process=None):
        super(RecTrainer, self).__init__(config, model, criterion)
        self.show_images_iter = self.config['trainer']['show_images_iter']
        self.train_loader = train_loader
        if validate_loader is not None:
            assert post_process is not None and metric is not None
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.metric = metric
        self.train_loader_len = len(train_loader)
        # model.preprocess.TPS_SpatialTransformerNetwork: use a separate learning rate for the TPS module
        if config.arch.get('preprocess', False) and config.arch.preprocess.get(
                'lr_scale', False):
            self.loc_lr = config.arch.preprocess.get('lr_scale') * self.lr
            params_list = filter_params_assign_lr(model,
                                                  {'preprocess': self.loc_lr})
            self.optimizer = self._initialize('optimizer', torch.optim,
                                              params_list)
        if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
            warmup_iters = int(config['lr_scheduler']['args']['warmup_epoch'] *
                               self.train_loader_len)
            if self.start_epoch > 1:
                self.config['lr_scheduler']['args']['last_epoch'] = (
                    self.start_epoch - 1) * self.train_loader_len
            self.scheduler = WarmupPolyLR(self.optimizer,
                                          max_iters=self.epochs *
                                          self.train_loader_len,
                                          warmup_iters=warmup_iters,
                                          **config['lr_scheduler']['args'])
        if self.validate_loader is not None:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'
                .format(len(self.train_loader.dataset), self.train_loader_len,
                        len(self.validate_loader.dataset),
                        len(self.validate_loader)))
        else:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader'.format(
                    len(self.train_loader.dataset), self.train_loader_len))

        self.converter = converter
        self.best_acc = 0

    def _train_epoch(self, epoch):
        self.model.train()
        epoch_start = time.time()
        batch_start = time.time()
        train_loss = 0.
        # running_metric_text = runningScore(2)
        lr = self.optimizer.param_groups[0]['lr']
        self.metric.reset()
        for i, batch in enumerate(self.train_loader):
            if i >= self.train_loader_len:
                # if i >= 1:
                break
            self.global_step += 1
            lr = self.optimizer.param_groups[0]['lr']
            # parse the labels
            batch['text'], batch['length'] = self.converter.encode(
                batch['labels'])
            # convert data and move it to the GPU
            for key, value in batch.items():
                if value is not None:
                    if isinstance(value, torch.Tensor):
                        batch[key] = value.to(self.device)
            cur_batch_size = batch['img'].size()[0]
            preds = self.model(batch['img'])
            loss_dict = self.criterion(preds, batch['text'], batch['length'],
                                       cur_batch_size)
            # backward
            self.optimizer.zero_grad()
            loss_dict['loss'].backward()
            self.optimizer.step()
            if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
                self.scheduler.step()

            # log loss and acc
            loss_str = 'loss: {:.4f}, '.format(loss_dict['loss'].item())
            for idx, (key, value) in enumerate(loss_dict.items()):
                loss_dict[key] = value.item()
                if key == 'loss':
                    continue
                loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
                if idx < len(loss_dict) - 1:
                    loss_str += ', '

            train_loss += loss_dict['loss']
            preds_prob = F.softmax(preds, dim=2)
            preds_prob, pred_index = preds_prob.max(dim=2)
            pred_str = self.converter.decode(pred_index)
            self.metric.measure(pred_str, batch['labels'], preds_prob)
            acc = self.metric.avg['acc']['true']
            edit_distance = self.metric.avg['edit']

            if self.global_step % self.log_iter == 0:
                batch_time = time.time() - batch_start
                self.logger_info(
                    '[{}/{}], [{}/{}], global_step: {}, speed: {:.1f} samples/sec, acc: {:.4f}, edit_distance: {:.4f}, {}, lr:{:.6}, time:{:.2f}'
                    .format(epoch, self.epochs, i + 1, self.train_loader_len,
                            self.global_step,
                            self.log_iter * cur_batch_size / batch_time, acc,
                            edit_distance, loss_str, lr, batch_time))
                batch_start = time.time()

            # if self.tensorboard_enable and self.config['local_rank'] == 0:
            #     # write tensorboard
            #     for key, value in loss_dict.items():
            #         self.writer.add_scalar('TRAIN/LOSS/{}'.format(key), value, self.global_step)
            #     self.writer.add_scalar('TRAIN/ACC_DIS/acc', acc, self.global_step)
            #     self.writer.add_scalar('TRAIN/ACC_DIS/edit_distance', edit_distance, self.global_step)
            #     self.writer.add_scalar('TRAIN/lr', lr, self.global_step)
            #     if self.global_step % self.show_images_iter == 0:
            #         # show images on tensorboard
            #         self.inverse_normalize(batch['img'])
            #         self.writer.add_images('TRAIN/imgs', batch['img'], self.global_step)
            #         # shrink_labels and threshold_labels
            #         shrink_labels = batch['labels']
            #         threshold_labels = batch['threshold_map']
            #         shrink_labels[shrink_labels <= 0.5] = 0
            #         shrink_labels[shrink_labels > 0.5] = 1
            #         show_label = torch.cat([shrink_labels, threshold_labels])
            #         show_label = vutils.make_grid(show_label.unsqueeze(1), nrow=cur_batch_size, normalize=False, padding=20, pad_value=1)
            #         self.writer.add_image('TRAIN/gt', show_label, self.global_step)
            #         # model output
            #         show_pred = []
            #         for kk in range(preds.shape[1]):
            #             show_pred.append(preds[:, kk, :, :])
            #         show_pred = torch.cat(show_pred)
            #         show_pred = vutils.make_grid(show_pred.unsqueeze(1), nrow=cur_batch_size, normalize=False, padding=20, pad_value=1)
            #         self.writer.add_image('TRAIN/preds', show_pred, self.global_step)
        return {
            'train_loss': train_loss / self.train_loader_len,
            'lr': lr,
            'time': time.time() - epoch_start,
            'epoch': epoch
        }

    def _eval(self, epoch):
        self.model.eval()
        total_frame = 0.0
        total_time = 0.0
        self.metric.reset()
        for i, batch in tqdm(enumerate(self.validate_loader),
                             total=len(self.validate_loader),
                             desc='test model'):
            with torch.no_grad():
                # convert data and move it to the GPU
                for key, value in batch.items():
                    if value is not None:
                        if isinstance(value, torch.Tensor):
                            batch[key] = value.to(self.device)
                start = time.time()
                preds = self.model(batch['img'])
                preds_prob = F.softmax(preds, dim=2)
                preds_prob, pred_index = preds_prob.max(dim=2)
                pred_str = self.converter.decode(pred_index)
                self.metric.measure(pred_str, batch['labels'], preds_prob)
                total_frame += batch['img'].size()[0]
                total_time += time.time() - start
        acc = self.metric.avg['acc']['true']
        edit = self.metric.avg['edit']
        self.logger_info('FPS:{}'.format(total_frame / total_time))

        return acc, edit

    def _on_epoch_finish(self):
        self.logger_info(
            '[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
                self.epoch_result['epoch'], self.epochs,
                self.epoch_result['train_loss'], self.epoch_result['time'],
                self.epoch_result['lr']))
        net_save_path = '{}/model_latest.pth'.format(self.checkpoint_dir)

        if self.config['local_rank'] == 0:
            save_best = False
            if self.validate_loader is not None and self.metric is not None:  # use accuracy as the best-model metric
                acc, edit = self._eval(self.epoch_result['epoch'])

                # if self.tensorboard_enable:
                #     self.writer.add_scalar('EVAL/recall', recall, self.global_step)
                #     self.writer.add_scalar('EVAL/precision', precision, self.global_step)
                #     self.writer.add_scalar('EVAL/hmean', hmean, self.global_step)
                self.logger_info(
                    'test: acc: {:.6f}, edit_distance: {:.4f}'.format(
                        acc, edit))

                if acc >= self.best_acc:
                    self.best_acc = acc
                    save_best = True
            else:
                if self.epoch_result['train_loss'] <= self.metrics[
                        'train_loss']:
                    save_best = True
                    self.metrics['train_loss'] = self.epoch_result[
                        'train_loss']
            self._save_checkpoint(self.epoch_result['epoch'], net_save_path,
                                  save_best)

    def _on_train_finish(self):
        for k, v in self.metrics.items():
            self.logger_info('{}:{}'.format(k, v))
        self.logger_info('finish train')

    def inverse_normalize(self, batch_img):
        if self.UN_Normalize:
            batch_img[:, 0, :, :] = batch_img[:, 0, :, :] * self.normalize_std[
                0] + self.normalize_mean[0]
            batch_img[:, 1, :, :] = batch_img[:, 1, :, :] * self.normalize_std[
                1] + self.normalize_mean[1]
            batch_img[:, 2, :, :] = batch_img[:, 2, :, :] * self.normalize_std[
                2] + self.normalize_mean[2]
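
Examples 3 and 5 use a filter_params_assign_lr helper (not shown here) to give the TPS preprocess module its own learning rate. A minimal sketch of the idea, assuming the helper just splits parameters into optimizer param groups by name prefix, is given below; the original helper may differ.

def filter_params_assign_lr(model, lr_overrides):
    """Sketch: split model parameters into optimizer param groups.

    lr_overrides maps a parameter-name prefix (e.g. 'preprocess') to the lr
    those parameters should use; everything else goes into a default group
    that keeps the optimizer's own lr.
    """
    groups = {prefix: [] for prefix in lr_overrides}
    default = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        for prefix in lr_overrides:
            if name.startswith(prefix):
                groups[prefix].append(param)
                break
        else:
            default.append(param)
    params_list = [{'params': default}]
    params_list += [{'params': params, 'lr': lr_overrides[prefix]}
                    for prefix, params in groups.items() if params]
    return params_list


# usage sketch: the per-group lr overrides the optimizer's default lr
# params_list = filter_params_assign_lr(model, {'preprocess': 1e-4})
# optimizer = torch.optim.Adam(params_list, lr=1e-3)
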
Example 6
class Trainer(BaseTrainer):
    def __init__(self,
                 config,
                 model,
                 criterion,
                 train_loader,
                 validate_loader,
                 post_process=None):
        super(Trainer, self).__init__(config, model, criterion)
        self.show_images_iter = self.config['trainer']['show_images_iter']
        self.train_loader = train_loader
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.train_loader_len = len(train_loader)
        if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
            warmup_iters = config['lr_scheduler']['args'][
                'warmup_epoch'] * self.train_loader_len
            if self.start_epoch > 1:
                self.config['lr_scheduler']['args']['last_epoch'] = (
                    self.start_epoch - 1) * self.train_loader_len
            self.scheduler = WarmupPolyLR(self.optimizer,
                                          max_iters=self.epochs *
                                          self.train_loader_len,
                                          warmup_iters=warmup_iters,
                                          **config['lr_scheduler']['args'])
        if self.validate_loader is not None:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'
                .format(len(self.train_loader.dataset), self.train_loader_len,
                        len(self.validate_loader.dataset),
                        len(self.validate_loader)))
        else:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader'.format(
                    len(self.train_loader.dataset), self.train_loader_len))

    def _train_epoch(self, epoch):
        self.model.train()
        epoch_start = time.time()
        batch_start = time.time()
        train_loss = 0.
        running_metric_melons = runningScore(3)
        lr = self.optimizer.param_groups[0]['lr']

        for i, batch in enumerate(self.train_loader):
            if i >= self.train_loader_len:
                break
            self.global_step += 1
            lr = self.optimizer.param_groups[0]['lr']
            print(self.optimizer, self.config['local_rank'])

            # convert data and move it to the GPU
            for key, value in batch.items():
                if value is not None:
                    if isinstance(value, torch.Tensor):
                        batch[key] = value.to(self.device)

            cur_batch_size = batch['img'].size()[0]
            # print('image name :',batch['img_name'])
            self.optimizer.zero_grad()
            preds = self.model(batch['img'])
            loss_dict = self.criterion(preds, batch)
            # backward
            if isinstance(preds, tuple):
                preds = preds[0]
            # print('preds:', preds.shape)

            # enable autograd anomaly detection for the backward pass
            # print(loss_dict['loss'])
            # exit()
            reduce_loss = self.all_reduce_tensor(loss_dict['loss'])
            with torch.autograd.detect_anomaly():
                # loss.backward()
                loss_dict['loss'].backward()
            self.optimizer.step()
            if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
                self.scheduler.step()
            # acc iou
            target = batch['label']
            h, w = target.size(1), target.size(2)
            scale_pred = F.interpolate(input=preds,
                                       size=(h, w),
                                       mode='bilinear',
                                       align_corners=True)
            label_preds = torch.argmax(scale_pred, dim=1)
            running_metric_melons.update(target.data.cpu().numpy(),
                                         label_preds.data.cpu().numpy())
            score_, _ = running_metric_melons.get_scores()

            # log loss and acc
            loss_str = 'loss: {:.4f}, '.format(reduce_loss.item())
            for idx, (key, value) in enumerate(loss_dict.items()):
                loss_dict[key] = value.item()
                if key == 'loss':
                    continue
                loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
                if idx < len(loss_dict) - 1:
                    loss_str += ', '

            train_loss += loss_dict['loss']
            print(train_loss / self.train_loader_len,
                  self.config['local_rank'])
            acc = score_['Mean Acc']
            iou_Mean_map = score_['Mean IoU']
            if self.global_step % self.log_iter == 0:
                batch_time = time.time() - batch_start
                self.logger_info(
                    '[{}/{}], [{}/{}], global_step: {}, speed: {:.1f} samples/sec, acc: {:.4f}, iou_Mean_map: {:.4f}, {}, lr:{:.6}, time:{:.2f}'
                    .format(epoch, self.epochs, i + 1, self.train_loader_len,
                            self.global_step,
                            self.log_iter * cur_batch_size / batch_time, acc,
                            iou_Mean_map, loss_str, lr, batch_time))
                batch_start = time.time()
            # print('loss_str', loss_str)

            if self.tensorboard_enable and self.config['local_rank'] == 0:
                # write tensorboard
                for key, value in loss_dict.items():
                    self.writer.add_scalar('TRAIN/LOSS/{}'.format(key), value,
                                           self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/acc', acc,
                                       self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/iou_Mean_map',
                                       iou_Mean_map, self.global_step)
                self.writer.add_scalar('TRAIN/lr', lr, self.global_step)
                if self.global_step % self.show_images_iter == 0:
                    # show images on tensorboard
                    self.inverse_normalize(batch['img'])
                    preds_colors = decode_predictions(preds, cur_batch_size, 3)
                    self.writer.add_images('TRAIN/imgs',
                                           batch['img'][0].unsqueeze(0),
                                           self.global_step)
                    target = batch['label']
                    # (8, 256, 320, 3)

                    targets_colors = decode_labels(target, cur_batch_size, 3)
                    self.writer.add_image('TRAIN/labels',
                                          targets_colors[0],
                                          self.global_step,
                                          dataformats='HWC')
                    self.writer.add_image('TRAIN/preds',
                                          preds_colors[0],
                                          self.global_step,
                                          dataformats='HWC')
        return {
            'train_loss': train_loss / self.train_loader_len,
            'lr': lr,
            'time': time.time() - epoch_start,
            'epoch': epoch,
            'MeanIoU': iou_Mean_map
        }

    def all_reduce_tensor(self, tensor2, norm=True):
        if self.distributed:
            return self.all_reduce_tensor2(tensor2,
                                           world_size=self.world_size,
                                           norm=norm)
        else:
            return torch.mean(tensor2)

    def all_reduce_tensor2(self,
                           tensor2,
                           op=dist.ReduceOp.SUM,
                           world_size=1,
                           norm=True):
        # reduce a clone so the caller's tensor is left untouched
        tensor = tensor2.clone()
        dist.all_reduce(tensor, op)
        if norm:
            tensor.div_(world_size)

        return tensor

    def _eval(self, epoch):
        self.model.eval()
        # torch.cuda.empty_cache()  # speed up evaluating after training finished
        total_frame = 0.0
        total_time = 0.0
        running_metric_melons = runningScore(3)
        mean_acc = []
        mean_iou = []
        for i, batch in tqdm(enumerate(self.validate_loader),
                             total=len(self.validate_loader),
                             desc='test model'):
            with torch.no_grad():
                # convert data and move it to the GPU
                for key, value in batch.items():
                    if value is not None:
                        if isinstance(value, torch.Tensor):
                            batch[key] = value.to(self.device)
                start = time.time()
                # print(batch['img'].shape)
                # exit()
                preds = self.model(batch['img'])

                if isinstance(preds, tuple):
                    preds = preds[0]
                target = batch['label']
                h, w = target.size(1), target.size(2)
                scale_pred = F.interpolate(input=preds,
                                           size=(h, w),
                                           mode='bilinear',
                                           align_corners=True)
                label_preds = torch.argmax(scale_pred, dim=1)

                running_metric_melons.update(target.data.cpu().numpy(),
                                             label_preds.data.cpu().numpy())
                score_, _ = running_metric_melons.get_scores()
                total_time += time.time() - start
                total_frame += batch['img'].size()[0]
                acc = score_['Mean Acc']
                iou_Mean_map = score_['Mean IoU']
                mean_acc.append(acc)
                mean_iou.append(iou_Mean_map)

        print('FPS:{}'.format(total_frame / total_time))
        return np.array(mean_acc).mean(), np.array(mean_iou).mean()

    def _on_epoch_finish(self):
        self.logger_info(
            '[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
                self.epoch_result['epoch'], self.epochs,
                self.epoch_result['train_loss'], self.epoch_result['time'],
                self.epoch_result['lr']))
        net_save_path = '{}/model_latest.pth'.format(self.checkpoint_dir)

        if self.config['local_rank'] == 0:
            save_best = False
            if self.validate_loader is not None:  # use mean IoU as the best-model metric
                acc, MeanIoU = self._eval(self.epoch_result['epoch'])

                if self.tensorboard_enable:
                    self.writer.add_scalar('EVAL/acc', acc, self.global_step)
                    self.writer.add_scalar('EVAL/MeanIoU', MeanIoU,
                                           self.global_step)
                self.logger_info('test: acc: {:.6f}, MeanIoU: {:.6f}'.format(
                    acc, MeanIoU))

                if MeanIoU >= self.metrics['MeanIoU']:
                    save_best = True
                    self.metrics['train_loss'] = self.epoch_result[
                        'train_loss']
                    self.metrics['MeanIoU'] = MeanIoU
                    self.metrics['Mean Acc'] = acc
                    self.metrics['best_model_epoch'] = self.epoch_result[
                        'epoch']

            else:
                if self.epoch_result['MeanIoU'] >= self.metrics['MeanIoU']:  # higher mean IoU is better
                    save_best = True
                    self.metrics['MeanIoU'] = self.epoch_result['MeanIoU']
                    self.metrics['best_model_epoch'] = self.epoch_result[
                        'epoch']
            best_str = 'current best, '
            for k, v in self.metrics.items():
                best_str += '{}: {:.6f}, '.format(k, v)
            self.logger_info(best_str)
            self._save_checkpoint(self.epoch_result['epoch'], net_save_path,
                                  save_best)

    def _on_train_finish(self):
        # for k, v in self.metrics.items():
        #     self.logger_info('{}:{}'.format(k, v))
        self.logger_info('finish train')
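
Examples 6, 7 and 9 track pixel accuracy and mean IoU with a runningScore helper that is not included in these snippets. A minimal sketch, assuming the usual confusion-matrix style implementation, is given below; the exact class used by these projects may differ.

import numpy as np


class runningScore:
    """Sketch: confusion-matrix based accuracy / mean-IoU tracker."""

    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def update(self, label_trues, label_preds):
        # accumulate one confusion matrix per (gt, pred) pair in the batch
        for lt, lp in zip(label_trues, label_preds):
            mask = (lt >= 0) & (lt < self.n_classes)
            hist = np.bincount(
                self.n_classes * lt[mask].astype(int) + lp[mask],
                minlength=self.n_classes ** 2,
            ).reshape(self.n_classes, self.n_classes)
            self.confusion_matrix += hist

    def get_scores(self):
        hist = self.confusion_matrix
        acc = np.diag(hist).sum() / max(hist.sum(), 1)
        acc_cls = np.diag(hist) / np.maximum(hist.sum(axis=1), 1)
        iu = np.diag(hist) / np.maximum(
            hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist), 1)
        return {'Overall Acc': acc,
                'Mean Acc': np.nanmean(acc_cls),
                'Mean IoU': np.nanmean(iu)}, {'IoU': iu}

    def reset(self):
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
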
Example 7
class Trainer(BaseTrainer):
    def __init__(self,
                 config,
                 model,
                 criterion,
                 train_loader,
                 validate_loader,
                 metric_cls,
                 post_process=None):
        super(Trainer, self).__init__(config, model, criterion)
        self.show_images_iter = self.config['trainer']['show_images_iter']
        self.train_loader = train_loader
        if validate_loader is not None:
            assert post_process is not None and metric_cls is not None
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.metric_cls = metric_cls
        self.train_loader_len = len(train_loader)
        if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
            warmup_iters = config['lr_scheduler']['args'][
                'warmup_epoch'] * self.train_loader_len
            if self.start_epoch > 1:
                self.config['lr_scheduler']['args']['last_epoch'] = (
                    self.start_epoch - 1) * self.train_loader_len
            self.scheduler = WarmupPolyLR(self.optimizer,
                                          max_iters=self.epochs *
                                          self.train_loader_len,
                                          warmup_iters=warmup_iters,
                                          **config['lr_scheduler']['args'])
        if self.validate_loader is not None:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'
                .format(len(self.train_loader.dataset), self.train_loader_len,
                        len(self.validate_loader.dataset),
                        len(self.validate_loader)))
        else:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader'.format(
                    len(self.train_loader.dataset), self.train_loader_len))

    def _train_epoch(self, epoch):
        self.model.train()
        epoch_start = time.time()
        batch_start = time.time()
        train_loss = 0.
        running_metric_text = runningScore(2)
        lr = self.optimizer.param_groups[0]['lr']

        for i, batch in enumerate(self.train_loader):
            if i >= self.train_loader_len:
                break
            self.global_step += 1
            lr = self.optimizer.param_groups[0]['lr']

            # convert data and move it to the GPU
            for key, value in batch.items():
                if value is not None:
                    if isinstance(value, torch.Tensor):
                        batch[key] = value.to(self.device)
            cur_batch_size = batch['img'].size()[0]

            preds = self.model(batch['img'])
            loss_dict = self.criterion(preds, batch)
            # backward
            self.optimizer.zero_grad()
            loss_dict['loss'].backward()
            self.optimizer.step()
            if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
                self.scheduler.step()
            # acc iou
            score_shrink_map = cal_text_score(
                preds[:, 0, :, :],
                batch['shrink_map'],
                batch['shrink_mask'],
                running_metric_text,
                thred=self.config['post_processing']['args']['thresh'])

            # log loss and acc
            loss_str = 'loss: {:.4f}, '.format(loss_dict['loss'].item())
            for idx, (key, value) in enumerate(loss_dict.items()):
                loss_dict[key] = value.item()
                if key == 'loss':
                    continue
                loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
                if idx < len(loss_dict) - 1:
                    loss_str += ', '

            train_loss += loss_dict['loss']
            acc = score_shrink_map['Mean Acc']
            iou_shrink_map = score_shrink_map['Mean IoU']

            if self.global_step % self.log_iter == 0:
                batch_time = time.time() - batch_start
                self.logger_info(
                    '[{}/{}], [{}/{}], global_step: {}, speed: {:.1f} samples/sec, acc: {:.4f}, iou_shrink_map: {:.4f}, {}, lr:{:.6}, time:{:.2f}'
                    .format(epoch, self.epochs, i + 1, self.train_loader_len,
                            self.global_step,
                            self.log_iter * cur_batch_size / batch_time, acc,
                            iou_shrink_map, loss_str, lr, batch_time))
                batch_start = time.time()

            if self.tensorboard_enable and self.config['local_rank'] == 0:
                # write tensorboard
                for key, value in loss_dict.items():
                    self.writer.add_scalar('TRAIN/LOSS/{}'.format(key), value,
                                           self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/acc', acc,
                                       self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/iou_shrink_map',
                                       iou_shrink_map, self.global_step)
                self.writer.add_scalar('TRAIN/lr', lr, self.global_step)
                if self.global_step % self.show_images_iter == 0:
                    # show images on tensorboard
                    self.inverse_normalize(batch['img'])
                    self.writer.add_images('TRAIN/imgs', batch['img'],
                                           self.global_step)
                    # shrink_labels and threshold_labels
                    shrink_labels = batch['shrink_map']
                    threshold_labels = batch['threshold_map']
                    shrink_labels[shrink_labels <= 0.5] = 0
                    shrink_labels[shrink_labels > 0.5] = 1
                    show_label = torch.cat([shrink_labels, threshold_labels])
                    show_label = vutils.make_grid(show_label.unsqueeze(1),
                                                  nrow=cur_batch_size,
                                                  normalize=False,
                                                  padding=20,
                                                  pad_value=1)
                    self.writer.add_image('TRAIN/gt', show_label,
                                          self.global_step)
                    # model output
                    show_pred = []
                    for kk in range(preds.shape[1]):
                        show_pred.append(preds[:, kk, :, :])
                    show_pred = torch.cat(show_pred)
                    show_pred = vutils.make_grid(show_pred.unsqueeze(1),
                                                 nrow=cur_batch_size,
                                                 normalize=False,
                                                 padding=20,
                                                 pad_value=1)
                    self.writer.add_image('TRAIN/preds', show_pred,
                                          self.global_step)
        return {
            'train_loss': train_loss / self.train_loader_len,
            'lr': lr,
            'time': time.time() - epoch_start,
            'epoch': epoch
        }

    def _eval(self, epoch):
        self.model.eval()
        # torch.cuda.empty_cache()  # speed up evaluating after training finished
        raw_metrics = []
        total_frame = 0.0
        total_time = 0.0
        for i, batch in tqdm(enumerate(self.validate_loader),
                             total=len(self.validate_loader),
                             desc='test model'):
            with torch.no_grad():
                # convert data and move it to the GPU
                for key, value in batch.items():
                    if value is not None:
                        if isinstance(value, torch.Tensor):
                            batch[key] = value.to(self.device)
                start = time.time()
                preds = self.model(batch['img'])
                boxes, scores = self.post_process(
                    batch,
                    preds,
                    is_output_polygon=self.metric_cls.is_output_polygon)
                total_frame += batch['img'].size()[0]
                total_time += time.time() - start
                raw_metric = self.metric_cls.validate_measure(
                    batch, (boxes, scores))
                raw_metrics.append(raw_metric)
        metrics = self.metric_cls.gather_measure(raw_metrics)
        self.logger_info('FPS:{}'.format(total_frame / total_time))
        return metrics['recall'].avg, metrics['precision'].avg, metrics[
            'fmeasure'].avg

    def _on_epoch_finish(self):
        self.logger_info(
            '[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
                self.epoch_result['epoch'], self.epochs,
                self.epoch_result['train_loss'], self.epoch_result['time'],
                self.epoch_result['lr']))
        net_save_path = '{}/model_latest.pth'.format(self.checkpoint_dir)

        if self.config['local_rank'] == 0:
            save_best = False
            if self.validate_loader is not None and self.metric_cls is not None:  # use f1 (hmean) as the best-model metric
                recall, precision, hmean = self._eval(
                    self.epoch_result['epoch'])

                if self.tensorboard_enable:
                    self.writer.add_scalar('EVAL/recall', recall,
                                           self.global_step)
                    self.writer.add_scalar('EVAL/precision', precision,
                                           self.global_step)
                    self.writer.add_scalar('EVAL/hmean', hmean,
                                           self.global_step)
                self.logger_info(
                    'test: recall: {:.6f}, precision: {:.6f}, f1: {:.6f}'.
                    format(recall, precision, hmean))

                if hmean >= self.metrics['hmean']:
                    save_best = True
                    self.metrics['train_loss'] = self.epoch_result[
                        'train_loss']
                    self.metrics['hmean'] = hmean
                    self.metrics['precision'] = precision
                    self.metrics['recall'] = recall
                    self.metrics['best_model_epoch'] = self.epoch_result[
                        'epoch']
            else:
                if self.epoch_result['train_loss'] <= self.metrics[
                        'train_loss']:
                    save_best = True
                    self.metrics['train_loss'] = self.epoch_result[
                        'train_loss']
                    self.metrics['best_model_epoch'] = self.epoch_result[
                        'epoch']
            best_str = 'current best, '
            for k, v in self.metrics.items():
                best_str += '{}: {:.6f}, '.format(k, v)
            self.logger_info(best_str)
            self._save_checkpoint(self.epoch_result['epoch'], net_save_path,
                                  save_best)

    def _on_train_finish(self):
        for k, v in self.metrics.items():
            self.logger_info('{}:{}'.format(k, v))
        self.logger_info('finish train')
Example 8
    def __init__(self, config, model, criterion, train_loader, validate_loader, metric_cls, post_process=None):
        config['trainer']['output_dir'] = os.path.join(str(pathlib.Path(os.path.abspath(__name__)).parent.parent),
                                                       config['trainer']['output_dir'])
        self.save_dir = config['trainer']['output_dir']
        self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')

        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        self.global_step = 0
        self.start_epoch = 0
        self.config = config
        self.model = model
        self.criterion = criterion

        # logger
        self.epochs = self.config['trainer']['epochs']
        self.log_iter = self.config['trainer']['log_iter']

        anyconfig.dump(config, os.path.join(self.save_dir, 'config.yaml'))
        self.logger = setup_logger(os.path.join(self.save_dir, 'train.log'))
        self.logger_info(pformat(self.config))

        # device
        if self.config['trainer']['CUDA_VISIBLE_DEVICES'] is not None:
            os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(self.config['trainer']['CUDA_VISIBLE_DEVICES']))
            if torch.cuda.is_available():
                self.with_cuda = True
                torch.backends.cudnn.benchmark = True
                self.device = torch.device('cuda')
                if torch.cuda.device_count() > 1:
                    self.is_distributed = True
                    torch.cuda.manual_seed_all(self.config['trainer']['seed'])
                    self.model = torch.nn.DataParallel(self.model)
                else:
                    self.is_distributed = False
                    torch.cuda.manual_seed(self.config['trainer']['seed'])
            else:
                self.is_distributed = False
                self.with_cuda = False
                self.device = torch.device("cpu")
                torch.manual_seed(self.config['trainer']['seed'])
        else:
            self.is_distributed = False
            self.with_cuda = False
            self.device = torch.device("cpu")
            torch.manual_seed(self.config['trainer']['seed'])
        self.logger_info('train with device {} {} and pytorch {}'.format(self.device,
                                                                         'distributed' if self.is_distributed is not None and self.is_distributed else 'single',
                                                                         torch.__version__))
        self.model.to(self.device)

        # metrics and optimizer
        self.metrics = {'recall': 0, 'precision': 0, 'hmean': 0, 'train_loss': float('inf'), 'best_model_epoch': 0}
        self.optimizer = self._initialize('optimizer', torch.optim, model.parameters())

        # checkpoint
        if self.config['trainer']['resume_checkpoint'] != '':
            self._load_checkpoint(self.config['trainer']['resume_checkpoint'], False)
            self.net_save_path_best = ''
        else:
            net_save_path_latest = os.path.join(self.checkpoint_dir, "model_latest.pth")
            if os.path.isfile(net_save_path_latest):
                self._load_checkpoint(net_save_path_latest, False)

            self.net_save_path_best = os.path.join(self.checkpoint_dir, "model_best*.pth")
            if glob.glob(self.net_save_path_best):
                self.net_save_path_best = glob.glob(self.net_save_path_best)[0]
                self._load_checkpoint(self.net_save_path_best, True)
            else:
                self.net_save_path_best = ''

        # normalize
        self.UN_Normalize = False
        for t in self.config['dataset']['train']['dataset']['args']['transforms']:
            if t['type'] == 'Normalize':
                self.normalize_mean = t['args']['mean']
                self.normalize_std = t['args']['std']
                self.UN_Normalize = True

        self.show_images_iter = self.config['trainer']['show_images_iter']
        self.train_loader = train_loader
        if validate_loader is not None:
            assert post_process is not None and metric_cls is not None
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.metric_cls = metric_cls
        self.train_loader_len = len(train_loader)

        # lr_scheduler
        warmup_iters = config['lr_scheduler']['args']['warmup_epoch'] * self.train_loader_len
        if self.start_epoch > 1:
            self.config['lr_scheduler']['args']['last_epoch'] = (self.start_epoch - 1) * self.train_loader_len
        self.scheduler = WarmupPolyLR(self.optimizer, max_iters=self.epochs * self.train_loader_len,
                                      warmup_iters=warmup_iters, **config['lr_scheduler']['args'])

        self.logger_info(
            'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'.format(
                len(self.train_loader.dataset), self.train_loader_len, len(self.validate_loader.dataset),
                len(self.validate_loader)))

        self.epoch_result = {'train_loss': 0, 'lr': 0, 'time': 0, 'epoch': 0}
Example 9
class Trainer:

    def __init__(self, config, model, criterion, train_loader, validate_loader, metric_cls, post_process=None):
        config['trainer']['output_dir'] = os.path.join(str(pathlib.Path(os.path.abspath(__name__)).parent.parent),
                                                       config['trainer']['output_dir'])
        self.save_dir = config['trainer']['output_dir']
        self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')

        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        self.global_step = 0
        self.start_epoch = 0
        self.config = config
        self.model = model
        self.criterion = criterion

        # logger
        self.epochs = self.config['trainer']['epochs']
        self.log_iter = self.config['trainer']['log_iter']

        anyconfig.dump(config, os.path.join(self.save_dir, 'config.yaml'))
        self.logger = setup_logger(os.path.join(self.save_dir, 'train.log'))
        self.logger_info(pformat(self.config))

        # device
        if self.config['trainer']['CUDA_VISIBLE_DEVICES'] is not None:
            os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(self.config['trainer']['CUDA_VISIBLE_DEVICES']))
            if torch.cuda.is_available():
                self.with_cuda = True
                torch.backends.cudnn.benchmark = True
                self.device = torch.device('cuda')
                if torch.cuda.device_count() > 1:
                    self.is_distributed = True
                    torch.cuda.manual_seed_all(self.config['trainer']['seed'])
                    self.model = torch.nn.DataParallel(self.model)
                else:
                    self.is_distributed = False
                    torch.cuda.manual_seed(self.config['trainer']['seed'])
            else:
                self.is_distributed = False
                self.with_cuda = False
                self.device = torch.device("cpu")
                torch.manual_seed(self.config['trainer']['seed'])
        else:
            self.is_distributed = False
            self.with_cuda = False
            self.device = torch.device("cpu")
            torch.manual_seed(self.config['trainer']['seed'])
        self.logger_info('train with device {} {} and pytorch {}'.format(self.device,
                                                                         'distributed' if self.is_distributed is not None and self.is_distributed else 'single',
                                                                         torch.__version__))
        self.model.to(self.device)

        # metrics and optimizer
        self.metrics = {'recall': 0, 'precision': 0, 'hmean': 0, 'train_loss': float('inf'), 'best_model_epoch': 0}
        self.optimizer = self._initialize('optimizer', torch.optim, model.parameters())

        # checkpoint
        if self.config['trainer']['resume_checkpoint'] != '':
            self._load_checkpoint(self.config['trainer']['resume_checkpoint'], False)
            self.net_save_path_best = ''
        else:
            net_save_path_latest = os.path.join(self.checkpoint_dir, "model_latest.pth")
            if os.path.isfile(net_save_path_latest):
                self._load_checkpoint(net_save_path_latest, False)

            self.net_save_path_best = os.path.join(self.checkpoint_dir, "model_best*.pth")
            if glob.glob(self.net_save_path_best):
                self.net_save_path_best = glob.glob(self.net_save_path_best)[0]
                self._load_checkpoint(self.net_save_path_best, True)
            else:
                self.net_save_path_best = ''

        # normalize
        self.UN_Normalize = False
        for t in self.config['dataset']['train']['dataset']['args']['transforms']:
            if t['type'] == 'Normalize':
                self.normalize_mean = t['args']['mean']
                self.normalize_std = t['args']['std']
                self.UN_Normalize = True

        self.show_images_iter = self.config['trainer']['show_images_iter']
        self.train_loader = train_loader
        if validate_loader is not None:
            assert post_process is not None and metric_cls is not None
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.metric_cls = metric_cls
        self.train_loader_len = len(train_loader)

        # lr_scheduler
        warmup_iters = config['lr_scheduler']['args']['warmup_epoch'] * self.train_loader_len
        if self.start_epoch > 1:
            self.config['lr_scheduler']['args']['last_epoch'] = (self.start_epoch - 1) * self.train_loader_len
        self.scheduler = WarmupPolyLR(self.optimizer, max_iters=self.epochs * self.train_loader_len,
                                      warmup_iters=warmup_iters, **config['lr_scheduler']['args'])

        self.logger_info(
            'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'.format(
                len(self.train_loader.dataset), self.train_loader_len, len(self.validate_loader.dataset),
                len(self.validate_loader)))

        self.epoch_result = {'train_loss': 0, 'lr': 0, 'time': 0, 'epoch': 0}

    def train(self):
        for epoch in range(self.start_epoch + 1, self.epochs + 1):
            self.epoch_result = self._train_epoch(epoch)
            self._on_epoch_finish()
        self._on_train_finish()

    def _train_epoch(self, epoch):
        self.model.train()
        epoch_start = time.time()
        batch_start = time.time()
        train_loss = 0.
        running_metric_text = runningScore(2)
        lr = self.optimizer.param_groups[0]['lr']

        for i, batch in enumerate(self.train_loader):
            if i >= self.train_loader_len:
                break
            self.global_step += 1
            lr = self.optimizer.param_groups[0]['lr']

            # convert data and move it to the GPU
            for key, value in batch.items():
                if value is not None:
                    if isinstance(value, torch.Tensor):
                        batch[key] = value.to(self.device)
            cur_batch_size = batch['img'].size()[0]

            preds = self.model(batch['img'])
            loss_dict = self.criterion(preds, batch)
            # backward
            self.optimizer.zero_grad()
            loss_dict['loss'].backward()
            self.optimizer.step()
            self.scheduler.step()

            # acc iou
            score_shrink_map = cal_text_score(preds[:, 0, :, :], batch['shrink_map'], batch['shrink_mask'],
                                              running_metric_text,
                                              thred=self.config['post_processing']['args']['thresh'])

            # log loss and acc
            loss_str = 'loss: {:.4f}, '.format(loss_dict['loss'].item())
            for idx, (key, value) in enumerate(loss_dict.items()):
                loss_dict[key] = value.item()
                if key == 'loss':
                    continue
                loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
                if idx < len(loss_dict) - 1:
                    loss_str += ', '

            train_loss += loss_dict['loss']
            acc = score_shrink_map['Mean Acc']
            iou_shrink_map = score_shrink_map['Mean IoU']

            if self.global_step % self.log_iter == 0:
                batch_time = time.time() - batch_start
                self.logger_info(
                    '[{}/{}], [{}/{}], global_step: {}, speed: {:.1f} samples/sec, acc: {:.4f}, iou_shrink_map: {:.4f}, {}lr:{:.6}, time:{:.2f}'.format(
                        epoch, self.epochs, i + 1, self.train_loader_len, self.global_step,
                        self.log_iter * cur_batch_size / batch_time, acc, iou_shrink_map, loss_str,
                        lr, batch_time))
                batch_start = time.time()

        return {'train_loss': train_loss / self.train_loader_len, 'lr': lr, 'time': time.time() - epoch_start,
                'epoch': epoch}

    def _eval(self):
        self.model.eval()
        torch.cuda.empty_cache()
        raw_metrics = []
        total_frame = 0.0
        total_time = 0.0
        for i, batch in tqdm(enumerate(self.validate_loader), total=len(self.validate_loader), desc='test model'):
            with torch.no_grad():
                for key, value in batch.items():
                    if value is not None:
                        if isinstance(value, torch.Tensor):
                            batch[key] = value.to(self.device)
                start = time.time()
                preds = self.model(batch['img'])
                boxes, scores = self.post_process(batch, preds, is_output_polygon=self.metric_cls.is_output_polygon)
                total_frame += batch['img'].size()[0]
                total_time += time.time() - start
                raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
                raw_metrics.append(raw_metric)
        metrics = self.metric_cls.gather_measure(raw_metrics)
        self.logger_info('FPS:{}'.format(total_frame / total_time))
        return metrics['recall'].avg, metrics['precision'].avg, metrics['fmeasure'].avg

    def _on_epoch_finish(self):
        self.logger_info('[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
            self.epoch_result['epoch'], self.epochs, self.epoch_result['train_loss'], self.epoch_result['time'],
            self.epoch_result['lr']))

        net_save_path_latest = '{}/model_latest.pth'.format(self.checkpoint_dir)
        self.logger_info("Saving latest checkpoint: {}".format(net_save_path_latest))
        self._save_checkpoint(self.epoch_result['epoch'], net_save_path_latest)

        recall, precision, hmean = self._eval()
        self.logger_info('test: recall: {}, precision: {}, hmean: {}'.format(recall, precision, hmean))

        if hmean > self.metrics['hmean']:
            if self.net_save_path_best != '':
                os.remove(self.net_save_path_best)

            self.metrics['train_loss'] = self.epoch_result['train_loss']
            self.metrics['hmean'] = hmean
            self.metrics['precision'] = precision
            self.metrics['recall'] = recall
            self.metrics['best_model_epoch'] = self.epoch_result['epoch']

            self.net_save_path_best = '{}/model_best_recall_{:.6f}_precision_{:.6f}_hmean_{:.6f}.pth'.format(
                self.checkpoint_dir,
                self.metrics['recall'],
                self.metrics['precision'],
                self.metrics['hmean'])
            self.logger_info("Saving best checkpoint: {}".format(self.net_save_path_best))
            self._save_checkpoint(self.epoch_result['epoch'], self.net_save_path_best)

        best_str = 'current best: '
        for k, v in self.metrics.items():
            best_str += '{}: {:.6f}, '.format(k, v)
        self.logger_info(best_str)

    def _on_train_finish(self):
        for k, v in self.metrics.items():
            self.logger_info('{}:{}'.format(k, v))
        self.logger_info('finish train')

    def _save_checkpoint(self, epoch, file_name):
        state_dict = self.model.state_dict()
        state = {
            'epoch': epoch,
            'global_step': self.global_step,
            'state_dict': state_dict,
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'config': self.config,
            'metrics': self.metrics
        }
        filename = os.path.join(self.checkpoint_dir, file_name)
        torch.save(state, filename)

    def _load_checkpoint(self, checkpoint_path, is_best):
        self.logger_info("Loading checkpoint: {} ...".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

        if is_best:
            self.metrics = checkpoint['metrics']
            self.logger_info("metrics resume from checkpoint {}".format(checkpoint_path))
            return

        state_dict = checkpoint['state_dict']
        if self.with_cuda:
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            if self.is_distributed:
                # DistributedDataParallel expects parameter names prefixed with 'module.'
                for k, v in state_dict.items():
                    key = 'module.' + k if not k.startswith('module.') else k
                    new_state_dict[key] = v
            else:
                # a plain (non-DDP) model expects the 'module.' prefix stripped
                for k, v in state_dict.items():
                    key = k[len('module.'):] if k.startswith('module.') else k
                    new_state_dict[key] = v
            self.model.load_state_dict(new_state_dict)
        else:
            self.model.load_state_dict(state_dict)
        self.global_step = checkpoint['global_step']
        self.start_epoch = checkpoint['epoch']
        self.config['lr_scheduler']['args']['last_epoch'] = self.start_epoch
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        if 'metrics' in checkpoint:
            self.metrics = checkpoint['metrics']
        if self.with_cuda:
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(self.device)
        self.logger_info("resume from checkpoint {} (epoch {})".format(checkpoint_path, self.start_epoch))

    def _initialize(self, name, module, *args, **kwargs):
        module_name = self.config[name]['type']
        module_args = self.config[name]['args']
        assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def logger_info(self, s):
        self.logger.info(s)
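
Note on the pattern above: _initialize is a small config-driven factory. It looks up config[name]['type'] as an attribute of the module that is passed in and calls it with config[name]['args'] plus any extra keyword arguments, refusing to overwrite keys that already come from the config. Below is a minimal, self-contained sketch of that pattern; the config values are hypothetical and torch.optim is only used for illustration.

import torch
import torch.nn as nn


def initialize(config, name, module, *args, **kwargs):
    # resolve e.g. config['optimizer'] = {'type': 'Adam', 'args': {'lr': 1e-3}}
    module_name = config[name]['type']
    module_args = dict(config[name]['args'])  # copy so the original config is not mutated
    assert all(k not in module_args for k in kwargs), \
        'Overwriting kwargs given in config file is not allowed'
    module_args.update(kwargs)
    return getattr(module, module_name)(*args, **module_args)


if __name__ == '__main__':
    cfg = {'optimizer': {'type': 'Adam', 'args': {'lr': 1e-3, 'weight_decay': 1e-4}}}
    model = nn.Linear(4, 2)
    optimizer = initialize(cfg, 'optimizer', torch.optim, model.parameters())
    print(optimizer)  # Adam with lr=0.001, weight_decay=0.0001
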
Ejemplo n.º 10
0
    def __init__(self,
                 config,
                 model,
                 criterion,
                 metric_cls,
                 train_loader,
                 validate_loader,
                 post_process=None):
        config['trainer']['output_dir'] = os.path.join(
            str(pathlib.Path(os.path.abspath(__name__)).parent),
            config['trainer']['output_dir'])
        config['name'] = config['name'] + '_' + model.name
        self.save_dir = os.path.join(config['trainer']['output_dir'],
                                     config['name'])
        self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')

        if config['trainer']['resume_checkpoint'] == '' and config['trainer'][
                'finetune_checkpoint'] == '':
            shutil.rmtree(self.save_dir, ignore_errors=True)
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        self.global_step = 0
        self.start_epoch = 0
        self.config = config
        self.model = model
        self.criterion = criterion
        self.metric_cls = metric_cls
        # logger and tensorboard
        self.epochs = self.config['trainer']['epochs']
        self.log_iter = self.config['trainer']['log_iter']
        self.tensorboard_enable = self.config['trainer']['tensorboard']
        if config['local_rank'] == 0:
            anyconfig.dump(config, os.path.join(self.save_dir, 'config.yaml'))
            self.logger = setup_logger(os.path.join(self.save_dir,
                                                    'train.log'))
            self.logger_info(pformat(self.config))

        # device
        torch.manual_seed(self.config['trainer']['seed'])  # set the random seed for the CPU
        if torch.cuda.device_count() > 0 and torch.cuda.is_available():
            self.with_cuda = True
            torch.backends.cudnn.benchmark = True
            self.device = torch.device("cuda")
            torch.cuda.manual_seed(
                self.config['trainer']['seed'])  # seed the current GPU
            torch.cuda.manual_seed_all(
                self.config['trainer']['seed'])  # seed all GPUs
        else:
            self.with_cuda = False
            self.device = torch.device("cpu")
        self.logger_info('train with device {} and pytorch {}'.format(
            self.device, torch.__version__))

        self.optimizer = self._initialize('optimizer', torch.optim,
                                          model.parameters())

        # resume or finetune
        if self.config['trainer']['resume_checkpoint'] != '':
            self._load_checkpoint(self.config['trainer']['resume_checkpoint'],
                                  resume=True)
        elif self.config['trainer']['finetune_checkpoint'] != '':
            self._load_checkpoint(
                self.config['trainer']['finetune_checkpoint'], resume=False)

        if self.config['lr_scheduler']['type'] != 'WarmupPolyLR':
            self.scheduler = self._initialize('lr_scheduler',
                                              torch.optim.lr_scheduler,
                                              self.optimizer)
        self.metrics = {
            'recall': 0,
            'precision': 0,
            'hmean': 0,
            'train_loss': float('inf'),
            'best_model_epoch': 0
        }
        self.model.to(self.device)

        # distributed training
        if torch.cuda.device_count() > 1:
            local_rank = config['local_rank']
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[local_rank],
                output_device=local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True)

        self.show_images_iter = self.config['trainer']['show_images_iter']
        self.train_loader = train_loader
        if validate_loader is not None:
            assert post_process is not None
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.train_loader_len = len(train_loader)
        if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
            warmup_iters = config['lr_scheduler']['args'][
                'warmup_epoch'] * self.train_loader_len
            if self.start_epoch > 1:
                self.config['lr_scheduler']['args']['last_epoch'] = (
                    self.start_epoch - 1) * self.train_loader_len
            self.scheduler = WarmupPolyLR(self.optimizer,
                                          max_iters=self.epochs *
                                          self.train_loader_len,
                                          warmup_iters=warmup_iters,
                                          **config['lr_scheduler']['args'])
        if self.validate_loader is not None:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'
                .format(len(self.train_loader.dataset), self.train_loader_len,
                        len(self.validate_loader.dataset),
                        len(self.validate_loader)))
        else:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader'.format(
                    len(self.train_loader.dataset), self.train_loader_len))

        if self.tensorboard_enable and config['local_rank'] == 0:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(self.save_dir)
            try:
                dummy_input = torch.zeros(1, 3, 640, 640).to(self.device)
                self.writer.add_graph(self.model, dummy_input)
                torch.cuda.empty_cache()
            except Exception:
                import traceback
                self.logger.error(traceback.format_exc())
                self.logger.warning('add graph to tensorboard failed')
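
Both trainer variants above rebuild the WarmupPolyLR scheduler manually and, when resuming, advance last_epoch by (start_epoch - 1) * train_loader_len so the per-iteration schedule continues from the right point. The WarmupPolyLR class itself is not part of these snippets; the sketch below only illustrates the learning-rate curve such a warmup-plus-polynomial-decay scheduler typically produces (the function name, the warmup_factor and the power=0.9 default are assumptions, not the repository's implementation).

def warmup_poly_lr(base_lr, cur_iter, max_iters, warmup_iters,
                   warmup_factor=1.0 / 3, power=0.9, target_lr=0.0):
    # hypothetical warmup + polynomial decay curve, not the repository's class
    if cur_iter < warmup_iters:
        # linear warmup from base_lr * warmup_factor up to base_lr
        alpha = cur_iter / warmup_iters
        return base_lr * (warmup_factor * (1 - alpha) + alpha)
    # polynomial decay from base_lr towards target_lr over the remaining iterations
    progress = (cur_iter - warmup_iters) / max(1, max_iters - warmup_iters)
    return target_lr + (base_lr - target_lr) * (1 - progress) ** power


if __name__ == '__main__':
    for it in (0, 500, 1000, 5000, 10000):
        print(it, round(warmup_poly_lr(1e-3, it, max_iters=10000, warmup_iters=1000), 6))
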
Ejemplo n.º 11
0
class Trainer():
    def __init__(self,
                 config,
                 model,
                 criterion,
                 metric_cls,
                 train_loader,
                 validate_loader,
                 post_process=None):
        config['trainer']['output_dir'] = os.path.join(
            str(pathlib.Path(os.path.abspath(__name__)).parent),
            config['trainer']['output_dir'])
        config['name'] = config['name'] + '_' + model.name
        self.save_dir = os.path.join(config['trainer']['output_dir'],
                                     config['name'])
        self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')

        if config['trainer']['resume_checkpoint'] == '' and config['trainer'][
                'finetune_checkpoint'] == '':
            shutil.rmtree(self.save_dir, ignore_errors=True)
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        self.global_step = 0
        self.start_epoch = 0
        self.config = config
        self.model = model
        self.criterion = criterion
        self.metric_cls = metric_cls
        # logger and tensorboard
        self.epochs = self.config['trainer']['epochs']
        self.log_iter = self.config['trainer']['log_iter']
        self.tensorboard_enable = self.config['trainer']['tensorboard']
        if config['local_rank'] == 0:
            anyconfig.dump(config, os.path.join(self.save_dir, 'config.yaml'))
            self.logger = setup_logger(os.path.join(self.save_dir,
                                                    'train.log'))
            self.logger_info(pformat(self.config))

        # device
        torch.manual_seed(self.config['trainer']['seed'])  # set the random seed for the CPU
        if torch.cuda.device_count() > 0 and torch.cuda.is_available():
            self.with_cuda = True
            torch.backends.cudnn.benchmark = True
            self.device = torch.device("cuda")
            torch.cuda.manual_seed(
                self.config['trainer']['seed'])  # seed the current GPU
            torch.cuda.manual_seed_all(
                self.config['trainer']['seed'])  # seed all GPUs
        else:
            self.with_cuda = False
            self.device = torch.device("cpu")
        self.logger_info('train with device {} and pytorch {}'.format(
            self.device, torch.__version__))

        self.optimizer = self._initialize('optimizer', torch.optim,
                                          model.parameters())

        # resume or finetune
        if self.config['trainer']['resume_checkpoint'] != '':
            self._load_checkpoint(self.config['trainer']['resume_checkpoint'],
                                  resume=True)
        elif self.config['trainer']['finetune_checkpoint'] != '':
            self._load_checkpoint(
                self.config['trainer']['finetune_checkpoint'], resume=False)

        if self.config['lr_scheduler']['type'] != 'WarmupPolyLR':
            self.scheduler = self._initialize('lr_scheduler',
                                              torch.optim.lr_scheduler,
                                              self.optimizer)
        self.metrics = {
            'recall': 0,
            'precision': 0,
            'hmean': 0,
            'train_loss': float('inf'),
            'best_model_epoch': 0
        }
        self.model.to(self.device)

        # distributed training
        if torch.cuda.device_count() > 1:
            local_rank = config['local_rank']
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[local_rank],
                output_device=local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True)

        self.show_images_iter = self.config['trainer']['show_images_iter']
        self.train_loader = train_loader
        if validate_loader is not None:
            assert post_process is not None
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.train_loader_len = len(train_loader)
        if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
            warmup_iters = config['lr_scheduler']['args'][
                'warmup_epoch'] * self.train_loader_len
            if self.start_epoch > 1:
                self.config['lr_scheduler']['args']['last_epoch'] = (
                    self.start_epoch - 1) * self.train_loader_len
            self.scheduler = WarmupPolyLR(self.optimizer,
                                          max_iters=self.epochs *
                                          self.train_loader_len,
                                          warmup_iters=warmup_iters,
                                          **config['lr_scheduler']['args'])
        if self.validate_loader is not None:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'
                .format(len(self.train_loader.dataset), self.train_loader_len,
                        len(self.validate_loader.dataset),
                        len(self.validate_loader)))
        else:
            self.logger_info(
                'train dataset has {} samples,{} in dataloader'.format(
                    len(self.train_loader.dataset), self.train_loader_len))

        if self.tensorboard_enable and config['local_rank'] == 0:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(self.save_dir)
            try:
                dummy_input = torch.zeros(1, 3, 640, 640).to(self.device)
                self.writer.add_graph(self.model, dummy_input)
                torch.cuda.empty_cache()
            except Exception:
                import traceback
                self.logger.error(traceback.format_exc())
                self.logger.warning('add graph to tensorboard failed')

    def train(self):
        """
        Full training logic
        """
        for epoch in range(self.start_epoch + 1, self.epochs + 1):
            if self.config['distributed']:
                # make the DistributedSampler reshuffle differently each epoch
                self.train_loader.sampler.set_epoch(epoch)
            self.epoch_result = self._train_epoch(epoch)
            if self.config['lr_scheduler']['type'] != 'WarmupPolyLR':
                self.scheduler.step()
            self._on_epoch_finish()
        if self.config['local_rank'] == 0 and self.tensorboard_enable:
            self.writer.close()
        self._on_train_finish()

    def _train_epoch(self, epoch):
        self.model.train()
        epoch_start = time.time()
        batch_start = time.time()
        train_loss = 0.
        lr = self.optimizer.param_groups[0]['lr']

        for i, batch in enumerate(self.train_loader):
            cur_batch_size = batch['img_tensor'].size()[0]
            if i >= self.train_loader_len:
                break
            self.global_step += 1
            lr = self.optimizer.param_groups[0]['lr']

            for key, value in batch.items():
                if value is not None:
                    if isinstance(value, torch.Tensor):
                        batch[key] = value.to(self.device)

            preds = self.model(batch['img_tensor'])
            loss = self.criterion(preds, batch['label'])
            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
                self.scheduler.step()

            train_loss += loss.item()  # detach so the graph is not kept across iterations

            if self.global_step % self.log_iter == 0:
                batch_time = time.time() - batch_start
                self.logger_info(
                    '[{}/{}], [{}/{}], global_step: {}, speed: {:.1f} samples/sec, loss: {:.6f},lr:{:.6}, time:{:.2f}'
                    .format(epoch, self.epochs, i + 1, self.train_loader_len,
                            self.global_step,
                            self.log_iter * cur_batch_size / batch_time, loss,
                            lr, batch_time))
                batch_start = time.time()

            if self.tensorboard_enable and self.config['local_rank'] == 0:
                self.writer.add_scalar('TRAIN/LOSS/{}'.format('loss'), loss,
                                       self.global_step)
                self.writer.add_scalar('TRAIN/lr', lr, self.global_step)
                if self.global_step % self.show_images_iter == 0:
                    boxes, pred_imgs, act_imgs = self.post_process(
                        batch['img'], preds, save_img=True)

                    show_act = vutils.make_grid(
                        [transforms.ToTensor()(x) for x in act_imgs],
                        nrow=cur_batch_size,
                        normalize=False,
                        padding=20,
                        pad_value=1)

                    show_label = vutils.make_grid(
                        [transforms.ToTensor()(x) for x in batch['draw_gt']],
                        nrow=cur_batch_size,
                        normalize=False,
                        padding=20,
                        pad_value=1)

                    show_pred = vutils.make_grid(
                        [transforms.ToTensor()(x) for x in pred_imgs],
                        nrow=cur_batch_size,
                        normalize=False,
                        padding=20,
                        pad_value=1)

                    self.writer.add_images('TRAIN/activation',
                                           show_act.unsqueeze(0),
                                           self.global_step)
                    self.writer.add_images('TRAIN/gt', show_label.unsqueeze(0),
                                           self.global_step)
                    self.writer.add_images('TRAIN/preds',
                                           show_pred.unsqueeze(0),
                                           self.global_step)

        return {
            'train_loss': train_loss / self.train_loader_len,
            'lr': lr,
            'time': time.time() - epoch_start,
            'epoch': epoch
        }

    def _eval(self):
        self.model.eval()
        raw_metrics = []
        total_frame = 0.0
        total_time = 0.0
        for i, batch in tqdm(enumerate(self.validate_loader),
                             total=len(self.validate_loader),
                             desc='test model'):
            with torch.no_grad():
                for key, value in batch.items():
                    if value is not None:
                        if isinstance(value, torch.Tensor):
                            batch[key] = value.to(self.device)
                start = time.time()
                preds = self.model(batch['img_tensor'])
                boxes, _, _ = self.post_process(batch['img'], preds)
                total_frame += batch['img_tensor'].size()[0]
                total_time += time.time() - start
                raw_metric = self.metric_cls.validate_measure(batch, boxes)
                raw_metrics.append(raw_metric)
        metrics = self.metric_cls.gather_measure(raw_metrics)
        self.logger_info('FPS:{}'.format(total_frame / total_time))
        return metrics['recall'].avg, metrics['precision'].avg, metrics[
            'fmeasure'].avg

    def _on_epoch_finish(self):
        self.logger_info(
            '[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
                self.epoch_result['epoch'], self.epochs,
                self.epoch_result['train_loss'], self.epoch_result['time'],
                self.epoch_result['lr']))
        net_save_path = '{}/model_latest.pth'.format(self.checkpoint_dir)
        net_save_path_best = '{}/model_best.pth'.format(self.checkpoint_dir)

        if self.config['local_rank'] == 0:
            self._save_checkpoint(self.epoch_result['epoch'], net_save_path)
            save_best = False
            if self.validate_loader is not None and self.metric_cls is not None:  # use f1 (hmean) as the model-selection metric
                recall, precision, hmean = self._eval()
                self.logger_info(
                    'test: recall: {:.6f}, precision: {:.6f}, f1: {:.6f}'.
                    format(recall, precision, hmean))

                if hmean >= self.metrics['hmean']:
                    save_best = True
                    self.metrics['train_loss'] = self.epoch_result[
                        'train_loss']
                    self.metrics['hmean'] = hmean
                    self.metrics['precision'] = precision
                    self.metrics['recall'] = recall
                    self.metrics['best_model_epoch'] = self.epoch_result[
                        'epoch']
            else:
                if self.epoch_result['train_loss'] <= self.metrics[
                        'train_loss']:
                    save_best = True
                    self.metrics['train_loss'] = self.epoch_result[
                        'train_loss']
                    self.metrics['best_model_epoch'] = self.epoch_result[
                        'epoch']
            best_str = 'current best, '
            for k, v in self.metrics.items():
                best_str += '{}: {:.6f}, '.format(k, v)
            self.logger_info(best_str)
            if save_best:
                shutil.copy(net_save_path, net_save_path_best)
                self.logger_info(
                    "Saving current best: {}".format(net_save_path_best))
            else:
                self.logger_info("Saving checkpoint: {}".format(net_save_path))

    def _on_train_finish(self):
        for k, v in self.metrics.items():
            self.logger_info('{}:{}'.format(k, v))
        self.logger_info('finish train')

    def _save_checkpoint(self, epoch, file_name):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar'
        """
        state_dict = self.model.module.state_dict(
        ) if self.config['distributed'] else self.model.state_dict()
        state = {
            'epoch': epoch,
            'global_step': self.global_step,
            'state_dict': state_dict,
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'config': self.config
        }
        filename = os.path.join(self.checkpoint_dir, file_name)
        torch.save(state, filename)

    def _load_checkpoint(self, checkpoint_path, resume):
        """
        Resume from saved checkpoints
        :param checkpoint_path: Checkpoint path to be resumed
        """
        self.logger_info("Loading checkpoint: {} ...".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path,
                                map_location=torch.device('cpu'))
        self.model.load_state_dict(checkpoint['state_dict'], strict=resume)
        if resume:
            self.global_step = checkpoint['global_step']
            self.start_epoch = checkpoint['epoch']
            self.config['lr_scheduler']['args'][
                'last_epoch'] = self.start_epoch
            # self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            if self.with_cuda:
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to(self.device)
            self.logger_info("resume from checkpoint {} (epoch {})".format(
                checkpoint_path, self.start_epoch))
        else:
            self.logger_info(
                "finetune from checkpoint {}".format(checkpoint_path))

    def _initialize(self, name, module, *args, **kwargs):
        module_name = self.config[name]['type']
        module_args = self.config[name]['args']
        assert all([
            k not in module_args for k in kwargs
        ]), 'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def logger_info(self, s):
        if self.config['local_rank'] == 0:
            self.logger.info(s)
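
A detail worth noting in the snippet above: _save_checkpoint stores self.model.module.state_dict() when training is distributed, so checkpoints are written without the 'module.' prefix that DistributedDataParallel adds, while the earlier trainer instead rewrites the keys at load time. A minimal sketch of that key normalization follows, under the assumption that the prefix is the only difference between the two forms (the helper name is mine, not the repository's).

from collections import OrderedDict

import torch


def normalize_state_dict(state_dict, for_ddp):
    # add or strip the 'module.' prefix DistributedDataParallel puts on parameter names
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if for_ddp:
            key = k if k.startswith('module.') else 'module.' + k
        else:
            key = k[len('module.'):] if k.startswith('module.') else k
        new_state_dict[key] = v
    return new_state_dict


if __name__ == '__main__':
    sd = {'module.fc.weight': torch.zeros(2, 2), 'module.fc.bias': torch.zeros(2)}
    print(list(normalize_state_dict(sd, for_ddp=False)))  # ['fc.weight', 'fc.bias']
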
Ejemplo n.º 12
0
class Trainer(BaseTrainer):
    def __init__(self,
                 config,
                 model,
                 criterion,
                 train_loader,
                 weights_init=None):
        super(Trainer, self).__init__(config, model, criterion, weights_init)
        self.show_images_interval = self.config['trainer'][
            'show_images_interval']
        self.test_path = self.config['data_loader']['args']['dataset'][
            'val_data_path']
        self.train_loader = train_loader
        self.train_loader_len = len(train_loader)
        if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
            base_lr = config['optimizer']['args']['lr']
            self.scheduler = WarmupPolyLR(self.optimizer,
                                          target_lr=base_lr * 1e-2,
                                          max_iters=self.epochs *
                                          self.train_loader_len)

        self.logger.info(
            'train dataset has {} samples,{} in dataloader'.format(
                len(self.train_loader.dataset), self.train_loader_len))

    def _train_epoch(self, epoch):
        self.model.train()
        epoch_start = time.time()
        batch_start = time.time()
        train_loss = 0.
        running_metric_text = runningScore(2)
        lr = self.optimizer.param_groups[0]['lr']
        for i, (images, shrink_labels,
                threshold_labels) in enumerate(self.train_loader):
            if i >= self.train_loader_len:
                break
            self.global_step += 1
            lr = self.optimizer.param_groups[0]['lr']

            # move the batch data to the device
            cur_batch_size = images.size()[0]
            images, shrink_labels, threshold_labels = images.to(
                self.device), shrink_labels.to(
                    self.device), threshold_labels.to(self.device)

            preds = self.model(images)
            loss_all, loss_shrink_map, loss_binary_map, loss_threshold_map = self.criterion(
                preds, shrink_labels, threshold_labels)
            # backward
            self.optimizer.zero_grad()
            loss_all.backward()
            self.optimizer.step()
            if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
                self.scheduler.step()
            # acc iou
            score_shrink_map = cal_text_score(preds[:, 0, :, :],
                                              shrink_labels,
                                              running_metric_text,
                                              thred=0.5)

            # record loss and acc in the log
            loss_all = loss_all.item()
            loss_shrink_map = loss_shrink_map.item()
            loss_binary_map = loss_binary_map.item()
            loss_threshold_map = loss_threshold_map.item()
            train_loss += loss_all
            acc = score_shrink_map['Mean Acc']
            iou_shrink_map = score_shrink_map['Mean IoU']

            if (i + 1) % self.display_interval == 0:
                batch_time = time.time() - batch_start
                self.logger.info(
                    '[{}/{}], [{}/{}], global_step: {}, Speed: {:.1f} samples/sec, acc: {:.4f}, iou_shrink_map: {:.4f}, loss_all: {:.4f}, loss_shrink_map: {:.4f}, loss_binary_map: {:.4f}, loss_threshold_map: {:.4f}, lr:{:.6}, time:{:.2f}'
                    .format(
                        epoch, self.epochs, i + 1, self.train_loader_len,
                        self.global_step,
                        self.display_interval * cur_batch_size / batch_time,
                        acc, iou_shrink_map, loss_all, loss_shrink_map,
                        loss_binary_map, loss_threshold_map, lr, batch_time))
                batch_start = time.time()

            if self.tensorboard_enable:
                # write tensorboard
                self.writer.add_scalar('TRAIN/LOSS/loss_all', loss_all,
                                       self.global_step)
                self.writer.add_scalar('TRAIN/LOSS/loss_shrink_map',
                                       loss_shrink_map, self.global_step)
                self.writer.add_scalar('TRAIN/LOSS/loss_binary_map',
                                       loss_binary_map, self.global_step)
                self.writer.add_scalar('TRAIN/LOSS/loss_threshold_map',
                                       loss_threshold_map, self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/acc', acc,
                                       self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/iou_shrink_map',
                                       iou_shrink_map, self.global_step)
                self.writer.add_scalar('TRAIN/lr', lr, self.global_step)
                if i % self.show_images_interval == 0:
                    # show images on tensorboard
                    self.writer.add_images('TRAIN/imgs', images,
                                           self.global_step)
                    # shrink_labels and threshold_labels
                    shrink_labels[shrink_labels <= 0.5] = 0
                    shrink_labels[shrink_labels > 0.5] = 1
                    show_label = torch.cat([shrink_labels, threshold_labels])
                    show_label = vutils.make_grid(show_label.unsqueeze(1),
                                                  nrow=cur_batch_size,
                                                  normalize=False,
                                                  padding=20,
                                                  pad_value=1)
                    self.writer.add_image('TRAIN/gt', show_label,
                                          self.global_step)
                    # model output
                    show_pred = torch.cat([
                        preds[:, 0, :, :], preds[:, 1, :, :], preds[:, 2, :, :]
                    ])
                    show_pred = vutils.make_grid(show_pred.unsqueeze(1),
                                                 nrow=cur_batch_size,
                                                 normalize=False,
                                                 padding=20,
                                                 pad_value=1)
                    self.writer.add_image('TRAIN/preds', show_pred,
                                          self.global_step)

        return {
            'train_loss': train_loss / self.train_loader_len,
            'lr': lr,
            'time': time.time() - epoch_start,
            'epoch': epoch
        }

    def _eval(self):
        self.model.eval()
        # torch.cuda.empty_cache()  # speed up evaluating after training finished
        img_path = os.path.join(self.test_path, 'img')
        gt_path = os.path.join(self.test_path, 'gt')
        result_save_path = os.path.join(self.save_dir, 'result')
        if os.path.exists(result_save_path):
            shutil.rmtree(result_save_path, ignore_errors=True)
        if not os.path.exists(result_save_path):
            os.makedirs(result_save_path)
        short_size = 736
        # run inference on all test images
        img_paths = [os.path.join(img_path, x) for x in os.listdir(img_path)]
        for img_path in tqdm(img_paths, desc='test models'):
            img_name = os.path.basename(img_path).split('.')[0]
            save_name = os.path.join(result_save_path,
                                     'res_' + img_name + '.txt')

            assert os.path.exists(img_path), 'file does not exist'
            img = cv2.imread(img_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            h, w = img.shape[:2]
            scale = short_size / min(h, w)
            img = cv2.resize(img, None, fx=scale, fy=scale)
            # convert the image from (h, w, c) to a (1, c, h, w) tensor
            tensor = transforms.ToTensor()(img)
            tensor = tensor.unsqueeze_(0)

            tensor = tensor.to(self.device)
            with torch.no_grad():
                torch.cuda.synchronize(self.device)
                preds = self.model(tensor)[0]
                torch.cuda.synchronize(self.device)
                preds, boxes_list = decode(preds)
                scale = (preds.shape[1] / w, preds.shape[0] / h)
                if len(boxes_list):
                    boxes_list = boxes_list / scale
            np.savetxt(save_name,
                       boxes_list.reshape(-1, 8),
                       delimiter=',',
                       fmt='%d')
        # compute recall, precision and f1
        result_dict = cal_recall_precison_f1(gt_path=gt_path,
                                             result_path=result_save_path)
        return result_dict['recall'], result_dict['precision'], result_dict[
            'hmean']

    def _on_epoch_finish(self):
        self.logger.info(
            '[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
                self.epoch_result['epoch'], self.epochs,
                self.epoch_result['train_loss'], self.epoch_result['time'],
                self.epoch_result['lr']))
        net_save_path = '{}/DBNet_latest.pth'.format(self.checkpoint_dir)

        save_best = False
        if self.config['trainer']['metrics'] == 'hmean':  # use f1 (hmean) as the model-selection metric
            recall, precision, hmean = self._eval()

            if self.tensorboard_enable:
                self.writer.add_scalar('EVAL/recall', recall, self.global_step)
                self.writer.add_scalar('EVAL/precision', precision,
                                       self.global_step)
                self.writer.add_scalar('EVAL/hmean', hmean, self.global_step)
            self.logger.info(
                'test: recall: {:.6f}, precision: {:.6f}, f1: {:.6f}'.format(
                    recall, precision, hmean))

            if hmean > self.metrics['hmean']:
                save_best = True
                self.metrics['train_loss'] = self.epoch_result['train_loss']
                self.metrics['hmean'] = hmean
                self.metrics['precision'] = precision
                self.metrics['recall'] = recall
                self.metrics['best_model'] = net_save_path
        else:
            if self.epoch_result['train_loss'] < self.metrics['train_loss']:
                save_best = True
                self.metrics['train_loss'] = self.epoch_result['train_loss']
                self.metrics['best_model'] = net_save_path
        self._save_checkpoint(self.epoch_result['epoch'], net_save_path,
                              save_best)

    def _on_train_finish(self):
        for k, v in self.metrics.items():
            self.logger.info('{}:{}'.format(k, v))
        self.logger.info('finish train')
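
In this last trainer the evaluation writes one res_<image_name>.txt per test image and scores the whole set with cal_recall_precison_f1; the hmean it tracks for model selection is the harmonic mean (F1) of precision and recall. A minimal sketch of that final aggregation step (the function below is mine, shown only to make the metric explicit, not the repository's implementation):

def hmean(precision, recall):
    # F1 score: harmonic mean of precision and recall, used as the model-selection metric
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


if __name__ == '__main__':
    print(round(hmean(0.85, 0.78), 4))  # 0.8135
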