예제 #1
0
class FCNSegmentor(object):
    """
      Runner for training and validating an FCN-style segmentation model.

      The network, loss, data loaders and optimizer/scheduler come from the
      project factories (ModelManager, DataLoader, Trainer) in `_init_model`.
      `train` runs one epoch over the train loader. `val` evaluates the model,
      records mean IoU and validation loss in `runner_state`, and passes both
      to RunnerHelper.save_net.
    """
    def __init__(self, configer):
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()
        self.seg_running_score = SegRunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)

        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        # NOTE(review): train() increments runner_state['iters'] and
        # runner_state['epoch'] without creating them; presumably they are
        # initialised elsewhere (e.g. RunnerHelper.load_net or the caller).
        self.runner_state = dict()

        self._init_model()

    def _init_model(self):
        """Build the network, optimizer/scheduler, data loaders and loss."""
        self.seg_net = self.seg_model_manager.get_seg_model()
        self.seg_net = RunnerHelper.load_net(self, self.seg_net)

        self.optimizer, self.scheduler = Trainer.init(
            self._get_parameters(), self.configer.get('solver'))

        self.train_loader = self.seg_data_loader.get_trainloader()
        self.val_loader = self.seg_data_loader.get_valloader()

        self.loss = self.seg_model_manager.get_seg_loss()

    def _get_parameters(self):
        """
          Build optimizer parameter groups.

          Parameters whose name contains 'backbone' go into the first group.
          All other parameters go into the second group. Both groups use
          solver.lr.base_lr; the second group is scaled by 1.0, so it gets
          no extra multiplier.

          Returns:
            A list of two dicts, each with 'params' and 'lr' keys.
        """
        base_lr = self.configer.get('solver', 'lr')['base_lr']
        backbone_params = []
        head_params = []
        for name, param in self.seg_net.named_parameters():
            if 'backbone' in name:
                backbone_params.append(param)
            else:
                head_params.append(param)

        return [
            {'params': backbone_params, 'lr': base_lr},
            # Kept as a separate group so a different multiplier can be
            # applied to the non-backbone layers; it is 1.0 for now.
            {'params': head_params, 'lr': base_lr * 1.0},
        ]

    def train(self):
        """
          Run one epoch over the training loader.

          Increments runner_state['iters'] after every batch and
          runner_state['epoch'] once at the end. Along the way it:
            - logs timings and losses every solver.display_iter iterations;
            - saves the net every solver.save_iters iterations (only when
              local_rank is 0);
            - stops early when the solver's lr metric is 'iters' and
              solver.max_iters is reached;
            - runs `val` every solver.test_interval iterations, unless
              network.distributed is set.
        """
        self.seg_net.train()
        start_time = time.time()
        # Trainer.update (presumably the LR schedule / warm-up step) is called
        # once per iteration, not once per epoch.
        for data_dict in self.train_loader:
            Trainer.update(self,
                           warm_list=(0, ),
                           solver_dict=self.configer.get('solver'))
            self.data_time.update(time.time() - start_time)

            # Forward pass.
            data_dict = RunnerHelper.to_device(self, data_dict)
            out = self.seg_net(data_dict)
            # Compute the loss of the train batch & backward.
            loss_dict = self.loss(out)
            loss = loss_dict['loss']
            self.train_losses.update(
                {key: value.item() for key, value in loss_dict.items()},
                data_dict['img'].size(0))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1

            # Print the log info & reset the states.
            if self.runner_state['iters'] % self.configer.get(
                    'solver', 'display_iter') == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {4}\tLoss = {3}\n'.format(
                        self.runner_state['epoch'],
                        self.runner_state['iters'],
                        self.configer.get('solver', 'display_iter'),
                        self.train_losses.info(),
                        RunnerHelper.get_lr(self.optimizer),
                        batch_time=self.batch_time,
                        data_time=self.data_time))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            if self.runner_state['iters'] % self.configer.get('solver.save_iters') == 0 \
                    and self.configer.get('local_rank') == 0:
                RunnerHelper.save_net(self, self.seg_net)

            if self.configer.get('solver', 'lr')['metric'] == 'iters' \
                    and self.runner_state['iters'] == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            if self.runner_state['iters'] % self.configer.get('solver', 'test_interval') == 0 \
                    and not self.configer.get('network.distributed'):
                self.val()
                # Restart the timer so the validation pass is not counted as
                # data-loading time for the next training batch.
                start_time = time.time()

        self.runner_state['epoch'] += 1

    def val(self, data_loader=None):
        """
          Evaluate the model on `data_loader` (defaults to the val loader).

          Accumulates per-batch losses and the running segmentation score. It
          then stores 'performance' (mean IoU) and 'val_loss' in
          runner_state, passes both to RunnerHelper.save_net, and logs them
          with pixel accuracy. Finally it resets the meters and puts the
          network back in train mode.
        """
        self.seg_net.eval()
        start_time = time.time()

        data_loader = self.val_loader if data_loader is None else data_loader
        for data_dict in data_loader:
            data_dict = RunnerHelper.to_device(self, data_dict)
            with torch.no_grad():
                # Forward pass.
                out = self.seg_net(data_dict)
                # Compute the loss of the val batch.
                loss_dict = self.loss(out)
                out_dict, _ = RunnerHelper.gather(self, out)

            self.val_losses.update(
                {key: value.item() for key, value in loss_dict.items()},
                data_dict['img'].size(0))
            self._update_running_score(out_dict['out'],
                                       DCHelper.tolist(data_dict['meta']))

            # Update the vars of the val phase.
            # NOTE(review): batch_time is shared with train(), so timings of any
            # train batches not yet logged are folded into "Test Time" below.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        mean_iou = self.seg_running_score.get_mean_iou()
        val_loss = self.val_losses.avg['loss']
        self.runner_state['performance'] = mean_iou
        self.runner_state['val_loss'] = val_loss
        RunnerHelper.save_net(
            self,
            self.seg_net,
            performance=mean_iou,
            val_loss=val_loss)

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                 'Loss = {0}\n'.format(self.val_losses.info(),
                                       batch_time=self.batch_time))
        Log.info('Mean IOU: {}\n'.format(mean_iou))
        Log.info('Pixel ACC: {}\n'.format(
            self.seg_running_score.get_pixel_acc()))
        self.batch_time.reset()
        self.val_losses.reset()
        self.seg_running_score.reset()
        self.seg_net.train()

    def _update_running_score(self, pred, metas):
        """
          Add a batch of predictions to the running segmentation score.

          `pred` is permuted with (0, 2, 3, 1) and arg-maxed over the last
          axis, so it is expected as (N, C, H, W) logits. For each sample:
            - the logits are cropped to meta['border_wh'] (width, height);
            - they are resized to meta['ori_img_wh'];
            - the arg-max label map is compared against meta['ori_target'].
        """
        pred = pred.permute(0, 2, 3, 1)
        for i in range(pred.size(0)):
            meta = metas[i]
            border_w, border_h = meta['border_wh'][0], meta['border_wh'][1]
            logits = cv2.resize(
                pred[i, :border_h, :border_w].cpu().numpy(),
                tuple(meta['ori_img_wh']),
                interpolation=cv2.INTER_CUBIC)
            labelmap = np.argmax(logits, axis=-1)
            self.seg_running_score.update(labelmap[None],
                                          meta['ori_target'][None])
예제 #2
0
class FCNSegmentorTest(object):
    """
      Inference runner for a trained segmentation model.

      `test` reads images from a test directory and runs one of four
      inference strategies chosen by test.mode:
        - ss_test: single scale;
        - sscrop_test: single scale with sliding window;
        - ms_test: multi-scale plus horizontal flip;
        - mscrop_test: multi-scale plus flip, each with sliding window.
      It writes colorized visualisations and palette label images to an
      output directory.
    """
    def __init__(self, configer):
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_parser = SegParser(configer)
        self.seg_model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)
        self.test_loader = TestDataLoader(configer)
        self.device = torch.device(
            'cpu' if self.configer.get('gpu') is None else 'cuda')
        self.seg_net = None

        self._init_model()

    def _init_model(self):
        """Build the network, load its weights and switch it to eval mode."""
        self.seg_net = self.seg_model_manager.get_seg_model()
        self.seg_net = RunnerHelper.load_net(self, self.seg_net)
        self.seg_net.eval()

    def test(self, test_dir, out_dir):
        """
          Run inference on every batch from the test loader for `test_dir`.

          For each image it writes, under `out_dir`:
            - 'vis/<filename>.png': the colorized prediction drawn on the
              input image;
            - 'label/<filename>.png': a palette ('P') label image. It is
              remapped through data.label_list when that is set, and shifted
              by one when data.reduce_zero_label is set.

          An unknown test.mode is logged and terminates with exit status 1.
        """
        mode = self.configer.get('test', 'mode')
        if mode not in ('ss_test', 'sscrop_test', 'ms_test', 'mscrop_test'):
            Log.error('Invalid test mode:{}'.format(mode))
            raise SystemExit(1)

        relabel = self.configer.get('data.label_list',
                                    default=None) is not None
        reduce_zero_label = self.configer.get('data.reduce_zero_label',
                                              default=False)

        for data_dict in self.test_loader.get_testloader(test_dir=test_dir):
            if mode == 'ss_test':
                total_logits = self.ss_test(data_dict)
            elif mode == 'sscrop_test':
                total_logits = self.sscrop_test(
                    data_dict,
                    params_dict=self.configer.get('test', 'sscrop_test'))
            elif mode == 'ms_test':
                total_logits = self.ms_test(
                    data_dict,
                    params_dict=self.configer.get('test', 'ms_test'))
            else:  # 'mscrop_test' (validated above)
                total_logits = self.mscrop_test(
                    data_dict,
                    params_dict=self.configer.get('test', 'mscrop_test'))

            for i, meta in enumerate(DCHelper.tolist(data_dict['meta'])):
                label_img = np.array(np.argmax(total_logits[i], axis=-1),
                                     dtype=np.uint8)
                ori_img_bgr = ImageHelper.read_image(meta['img_path'],
                                                     tool='cv2',
                                                     mode='BGR')
                image_canvas = self.seg_parser.colorize(
                    label_img, image_canvas=ori_img_bgr)
                ImageHelper.save(image_canvas,
                                 save_path=os.path.join(
                                     out_dir,
                                     'vis/{}.png'.format(meta['filename'])))

                if relabel:
                    label_img = self.__relabel(label_img)

                if reduce_zero_label:
                    label_img = (label_img + 1).astype(np.uint8)

                label_img = Image.fromarray(label_img, 'P')
                label_path = os.path.join(
                    out_dir, 'label/{}.png'.format(meta['filename']))
                Log.info('Label Path: {}'.format(label_path))
                ImageHelper.save(label_img, label_path)

    def ss_test(self, in_data_dict):
        """Single-scale whole-image inference; returns per-image logits."""
        data_dict = self.blob_helper.get_blob(in_data_dict, scale=1.0)
        return self._predict(data_dict)

    def ms_test(self, in_data_dict, params_dict):
        """Sum whole-image logits over params_dict['scale_search'], plain and flipped."""
        return self._multi_scale_predict(in_data_dict, params_dict,
                                         self._predict)

    def sscrop_test(self, in_data_dict, params_dict):
        """Single-scale inference, sliding-window when the images are large enough."""
        data_dict = self.blob_helper.get_blob(in_data_dict, scale=1.0)
        return self._predict_maybe_crop(data_dict, params_dict)

    def mscrop_test(self, in_data_dict, params_dict):
        """Multi-scale + flip inference, sliding-window at each scale when possible."""
        return self._multi_scale_predict(
            in_data_dict, params_dict,
            lambda data_dict: self._predict_maybe_crop(data_dict, params_dict))

    def _predict_maybe_crop(self, data_dict, params_dict):
        """
          Choose between sliding-window and whole-image prediction.

          Uses `_crop_predict` with params_dict['crop_size'] ([width, height])
          and params_dict['crop_stride_ratio']. If any image in the batch is
          narrower or shorter than the crop, it falls back to a single
          `_predict` instead.
        """
        crop_size = params_dict['crop_size']
        if any(image.size()[2] < crop_size[0]
               or image.size()[1] < crop_size[1]
               for image in DCHelper.tolist(data_dict['img'])):
            return self._predict(data_dict)
        return self._crop_predict(data_dict, crop_size,
                                  params_dict['crop_stride_ratio'])

    def _multi_scale_predict(self, in_data_dict, params_dict, predict_fn):
        """
          Accumulate `predict_fn` logits over every scale in
          params_dict['scale_search'].

          The plain blobs are processed first, then blobs built with
          flip=True. Flipped results are un-flipped before summing by
          reversing axis 1 (width) of each (H, W, C) array.

          Returns one float32 array per image, shaped
          (ori_h, ori_w, num_classes).
        """
        num_classes = self.configer.get('data', 'num_classes')
        total_logits = [
            np.zeros((meta['ori_img_size'][1], meta['ori_img_size'][0],
                      num_classes), np.float32)
            for meta in DCHelper.tolist(in_data_dict['meta'])
        ]

        for scale in params_dict['scale_search']:
            results = predict_fn(
                self.blob_helper.get_blob(in_data_dict, scale=scale))
            for i, logits in enumerate(total_logits):
                logits += results[i]

        for scale in params_dict['scale_search']:
            results = predict_fn(
                self.blob_helper.get_blob(in_data_dict, scale=scale,
                                          flip=True))
            for i, logits in enumerate(total_logits):
                logits += results[i][:, ::-1]

        return total_logits

    def _crop_predict(self, data_dict, crop_size, crop_stride_ratio):
        """
          Sliding-window inference.

          Each image, indexed as (C, H, W), is cut into windows of
          crop_size[1] x crop_size[0] (height x width), placed by
          `_decide_intersection`. All windows of all images go through the
          network in one call. Overlapping window logits are then averaged,
          and each image's logits are cropped to meta['border_hw'] and
          resized to meta['ori_img_size'] (width, height).

          Returns a list of per-image (ori_h, ori_w, num_classes) arrays.
        """
        crop_w, crop_h = crop_size[0], crop_size[1]
        num_classes = self.configer.get('data', 'num_classes')

        split_batch = []
        windows = []  # per image: (height, width, height_starts, width_starts)
        for image in DCHelper.tolist(data_dict['img']):
            height, width = image.size()[1:]
            height_starts = self._decide_intersection(height, crop_h,
                                                      crop_stride_ratio)
            width_starts = self._decide_intersection(width, crop_w,
                                                     crop_stride_ratio)
            windows.append((height, width, height_starts, width_starts))
            # Slice the (C, H, W) tensor directly: a squeeze(0) here would
            # drop the channel axis of single-channel images.
            image = image.to(self.device)
            for top in height_starts:
                for left in width_starts:
                    split_batch.append(
                        image[:, top:top + crop_h, left:left + crop_w])

        crop_logits = []
        with torch.no_grad():
            results = self.seg_net(
                dict(img=DCHelper.todc(
                    split_batch, stack=True, samples_per_gpu=True)))
            # The network returns a list of output chunks that need not map
            # one-to-one onto images, so flatten them into a single
            # per-window list in submission order.
            for res in results:
                crop_logits.extend(
                    res['out'].permute(0, 2, 3, 1).cpu().numpy())

        total_logits = []
        offset = 0
        for (height, width, height_starts, width_starts), meta in zip(
                windows, DCHelper.tolist(data_dict['meta'])):
            logits = np.zeros((height, width, num_classes), np.float32)
            # One channel is enough; it broadcasts over the class axis.
            counts = np.zeros((height, width, 1), np.float32)
            for top in height_starts:
                for left in width_starts:
                    logits[top:top + crop_h,
                           left:left + crop_w] += crop_logits[offset]
                    counts[top:top + crop_h, left:left + crop_w] += 1
                    offset += 1
            logits /= counts
            total_logits.append(cv2.resize(
                logits[:meta['border_hw'][0], :meta['border_hw'][1]],
                (meta['ori_img_size'][0], meta['ori_img_size'][1]),
                interpolation=cv2.INTER_CUBIC))

        return total_logits

    def _decide_intersection(self, total_length, crop_length,
                             crop_stride_ratio):
        """
          Return the start offsets of windows of `crop_length` along an axis
          of `total_length`.

          Windows advance by int(crop_length * crop_stride_ratio), clamped to
          at least 1. If the regular stride stops short of the end, a final
          window flush with the end is appended so the whole axis is covered.
          An axis no longer than the window yields the single offset 0.
        """
        if total_length <= crop_length:
            return [0]

        stride = max(1, int(crop_length * crop_stride_ratio))
        starts = list(range(0, total_length - crop_length + 1, stride))
        if total_length - starts[-1] > crop_length:
            starts.append(total_length - crop_length)  # cover the tail
        return starts

    def _predict(self, data_dict):
        """
          Whole-image inference.

          Every sample in every output chunk (res['out'], (N, C, H, W)) is
          converted to an (H, W, C) array. Each array is cropped to
          meta['border_hw'] and resized to meta['ori_img_size'] (width,
          height).
        """
        with torch.no_grad():
            total_logits = []
            for res in self.seg_net(data_dict):
                total_logits.extend(
                    res['out'].permute(0, 2, 3, 1).cpu().numpy())

            for i, meta in enumerate(DCHelper.tolist(data_dict['meta'])):
                total_logits[i] = cv2.resize(
                    total_logits[i]
                    [:meta['border_hw'][0], :meta['border_hw'][1]],
                    (meta['ori_img_size'][0], meta['ori_img_size'][1]),
                    interpolation=cv2.INTER_CUBIC)

        return total_logits

    def __relabel(self, label_map):
        """
          Map predicted class indices to dataset label ids via data.label_list.

          Pixels whose value is not in range(num_classes) stay 0. Returns a
          new uint8 array; `label_map` must be 2-D.
        """
        height, width = label_map.shape
        label_list = self.configer.get('data', 'label_list')
        label_dst = np.zeros((height, width), dtype=np.uint8)
        for class_idx in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == class_idx] = label_list[class_idx]

        return label_dst