    # Needs: torch, plus the project's DCHelper / TensorHelper utilities.
    def get_blob(self, data_dict, scale=None, flip=False):
        assert scale is not None, 'get_blob requires an explicit scale.'

        img_list, meta_list = [], []
        for image, meta in zip(DCHelper.tolist(data_dict['img']),
                               DCHelper.tolist(data_dict['meta'])):
            c, h, w = image.size()
            border_hw = [int(h * scale), int(w * scale)]
            meta['border_wh'] = border_hw[::-1]  # stored as (w, h)
            image = TensorHelper.resize(image,
                                        border_hw,
                                        mode='bilinear',
                                        align_corners=True)
            if flip:
                image = image.flip([2])  # horizontal flip along the width axis

            # Zero-pad on the right/bottom so both sides are multiples of the fit stride.
            stride = self.configer.get('test.fit_stride', default=0)
            if stride > 0:
                pad_h = (stride - border_hw[0] % stride) % stride  # bottom
                pad_w = (stride - border_hw[1] % stride) % stride  # right

                expand_image = torch.zeros(
                    (c, border_hw[0] + pad_h,
                     border_hw[1] + pad_w)).to(image.device)
                expand_image[:, 0:border_hw[0], 0:border_hw[1]] = image
                image = expand_image

            img_list.append(image)
            meta_list.append(meta)

        # Re-pack the processed tensors and metadata with DCHelper for the
        # downstream forward pass.
        new_data_dict = dict(img=DCHelper.todc(img_list,
                                               stack=True,
                                               samples_per_gpu=True),
                             meta=DCHelper.todc(meta_list,
                                                samples_per_gpu=True,
                                                cpu_only=True))
        return new_data_dict
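
# TensorHelper.resize above is a project-specific helper. A minimal stand-in,
# assuming it simply wraps torch.nn.functional.interpolate on a CHW tensor
# (a hypothetical sketch, not the project's actual implementation):
import torch.nn.functional as F

def resize_chw(image, size_hw, mode='bilinear', align_corners=True):
    # interpolate expects a batch dimension, so add one and strip it afterwards.
    return F.interpolate(image.unsqueeze(0), size=tuple(size_hw),
                         mode=mode, align_corners=align_corners).squeeze(0)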
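
# Hypothetical usage sketch for get_blob in a multi-scale, flip-augmented test
# loop; `tester`, `data_dict`, and the scale list are illustrative assumptions,
# not names taken from this project:
for scale in [0.75, 1.0, 1.25]:
    blob = tester.get_blob(data_dict, scale=scale, flip=False)
    flipped_blob = tester.get_blob(data_dict, scale=scale, flip=True)
    # ...run the network on both, un-flip the flipped logits, and average.
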
    def _crop_predict(self, data_dict, crop_size, crop_stride_ratio):
        # Needs: numpy as np, cv2, torch, plus the project's DCHelper utility.
        split_batch = list()
        height_starts_list = list()
        width_starts_list = list()
        hw_list = list()
        for image in DCHelper.tolist(data_dict['img']):
            height, width = image.size()[1:]
            hw_list.append([height, width])
            np_image = image.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for slicing
            height_starts = self._decide_intersection(height, crop_size[1], crop_stride_ratio)
            width_starts = self._decide_intersection(width, crop_size[0], crop_stride_ratio)
            split_crops = []
            for y in height_starts:
                for x in width_starts:
                    # crop_size is (w, h): rows take crop_size[1], cols crop_size[0].
                    image_crop = np_image[y:y + crop_size[1], x:x + crop_size[0]]
                    split_crops.append(image_crop[np.newaxis, :])

            height_starts_list.append(height_starts)
            width_starts_list.append(width_starts)
            split_crops = np.concatenate(split_crops, axis=0)  # (n_crops, crop_h, crop_w, 3)
            inputs = torch.from_numpy(split_crops).permute(0, 3, 1, 2).to(self.device)
            split_batch.append(inputs)

        assert len(split_batch) == torch.cuda.device_count(), 'Only one image per GPU is supported.'
        out_list = list()
        with torch.no_grad():
            # One batched forward pass over every crop of every image.
            results = self.seg_net(dict(img=DCHelper.todc(split_batch, stack=False, samples_per_gpu=True, concat=True)))
            results = results if isinstance(results, (list, tuple)) else [results]
            for res in results:
                out_list.append(res['out'].permute(0, 2, 3, 1).cpu().numpy())

        # Accumulate logits and per-pixel hit counts so overlapping crops can be averaged.
        total_logits = [np.zeros((hw[0], hw[1],
                                  self.configer.get('data', 'num_classes')), np.float32) for hw in hw_list]
        count_predictions = [np.zeros((hw[0], hw[1],
                                       self.configer.get('data', 'num_classes')), np.float32) for hw in hw_list]
        for i in range(len(height_starts_list)):
            index = 0
            for y in height_starts_list[i]:
                for x in width_starts_list[i]:
                    total_logits[i][y:y + crop_size[1], x:x + crop_size[0]] += out_list[i][index]
                    count_predictions[i][y:y + crop_size[1], x:x + crop_size[0]] += 1
                    index += 1

        # Average the predictions where crops overlapped.
        for i in range(len(total_logits)):
            total_logits[i] /= count_predictions[i]

        # Strip the fit-stride padding, then resize the logits back to the original
        # image size (cv2.resize expects a (w, h) size tuple).
        for i, meta in enumerate(DCHelper.tolist(data_dict['meta'])):
            total_logits[i] = cv2.resize(total_logits[i][:meta['border_wh'][1], :meta['border_wh'][0]],
                                         tuple(meta['ori_img_size']), interpolation=cv2.INTER_CUBIC)

        return total_logits
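
    # _decide_intersection is used above but not shown here. A plausible sketch,
    # assuming it returns evenly spaced crop-start offsets whose stride is
    # crop_stride_ratio * crop_length, with a final window flush with the border:
    def _decide_intersection(self, total_length, crop_length, crop_stride_ratio):
        stride = max(1, int(crop_length * crop_stride_ratio))
        starts = list(range(0, max(total_length - crop_length, 0) + 1, stride))
        if starts[-1] + crop_length < total_length:
            # Make sure the last crop reaches the border of the image.
            starts.append(total_length - crop_length)
        return starts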