Exemplo n.º 1
0
    def get_blob(self, data_dict, scale=None, flip=False):
        """Build a rescaled (and optionally flipped / padded) copy of the batch.

        Every image obtained from ``data_dict['img']`` is resized to
        ``[int(h * scale), int(w * scale)]``; that size is written into the
        matching meta dict under ``'border_hw'`` (the meta dict is modified in
        place). When ``flip`` is true the image is flipped along dim 2. When
        the ``test.fit_stride`` setting is positive, the image is zero-padded
        on the bottom/right so both spatial sizes are multiples of the stride.

        Args:
            data_dict: mapping with ``'img'`` and ``'meta'`` entries that
                ``DCHelper.tolist`` can unpack.
            scale: resize factor; must not be ``None``.
            flip: whether to flip each resized image along its last dimension.

        Returns:
            dict with ``'img'`` and ``'meta'`` entries re-wrapped through
            ``DCHelper.todc``.
        """
        assert scale is not None

        images = DCHelper.tolist(data_dict['img'])
        metas = DCHelper.tolist(data_dict['meta'])

        img_list = []
        meta_list = []
        for image, meta in zip(images, metas):
            c, h, w = image.size()
            scaled_h, scaled_w = int(h*scale), int(w*scale)
            border_hw = [scaled_h, scaled_w]
            meta['border_hw'] = border_hw
            image = TensorHelper.resize(image, border_hw, mode='bilinear', align_corners=True)
            if flip:
                image = image.flip([2])

            if self.configer.get('test.fit_stride', default=0) > 0:
                stride = self.configer.get('test', 'fit_stride')

                # Extra columns (right) / rows (bottom) needed to reach the
                # next multiple of the stride; zero when already aligned.
                rem_w = border_hw[1] % stride
                rem_h = border_hw[0] % stride
                pad_w = stride - rem_w if rem_w != 0 else 0
                pad_h = stride - rem_h if rem_h != 0 else 0

                padded_shape = (c, border_hw[0] + pad_h, border_hw[1] + pad_w)
                padded = torch.zeros(padded_shape).to(image.device)
                # The resized image occupies the top-left corner of the canvas.
                padded[:, 0:border_hw[0], 0:border_hw[1]] = image
                image = padded

            img_list.append(image)
            meta_list.append(meta)

        return dict(
            img=DCHelper.todc(img_list, stack=True, samples_per_gpu=True),
            meta=DCHelper.todc(meta_list, samples_per_gpu=True, cpu_only=True)
        )
Exemplo n.º 2
0
    def _crop_predict(self, data_dict, crop_size, crop_stride_ratio):
        """Predict each image crop by crop and average the crop outputs.

        Each image from ``data_dict['img']`` is cut into windows of
        ``crop_size[1]`` rows by ``crop_size[0]`` columns at the offsets
        returned by ``self._decide_intersection``. The crops of one image are
        forwarded as one batch entry; the ``'out'`` predictions are summed
        back at their offsets and divided by the number of crops that covered
        each position. Finally every map is cut to ``meta['border_wh']`` and
        resized to ``meta['ori_img_size']`` with cubic interpolation.

        Args:
            data_dict: mapping with ``'img'`` and ``'meta'`` entries that
                ``DCHelper.tolist`` can unpack.
            crop_size: indexable as ``(width, height)`` of a crop.
            crop_stride_ratio: passed through to ``self._decide_intersection``.

        Returns:
            list with one ``numpy`` array of averaged predictions per image.
        """
        crop_w, crop_h = crop_size[0], crop_size[1]

        split_batch = []
        height_starts_list = []
        width_starts_list = []
        hw_list = []
        for image in DCHelper.tolist(data_dict['img']):
            img_h, img_w = image.size()[1:]
            hw_list.append([img_h, img_w])
            np_image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
            row_starts = self._decide_intersection(img_h, crop_h, crop_stride_ratio)
            col_starts = self._decide_intersection(img_w, crop_w, crop_stride_ratio)

            # Row-major order over the offsets; the accumulation loop below
            # consumes the network outputs in this same order.
            crops = [
                np_image[top:top + crop_h, left:left + crop_w][np.newaxis, :]
                for top in row_starts
                for left in col_starts
            ]

            height_starts_list.append(row_starts)
            width_starts_list.append(col_starts)
            stacked = np.concatenate(crops, axis=0)  # (n, crop_image_size, crop_image_size, 3)
            split_batch.append(torch.from_numpy(stacked).permute(0, 3, 1, 2).to(self.device))

        assert len(split_batch) == torch.cuda.device_count(), 'Only support one image per gpu.'
        with torch.no_grad():
            batch = DCHelper.todc(split_batch, stack=False, samples_per_gpu=True, concat=True)
            results = self.seg_net(dict(img=batch))
            if not isinstance(results, (list, tuple)):
                results = [results]
            out_list = [res['out'].permute(0, 2, 3, 1).cpu().numpy() for res in results]

        num_classes = self.configer.get('data', 'num_classes')
        total_logits = [np.zeros((hw[0], hw[1], num_classes), np.float32) for hw in hw_list]
        count_predictions = [np.zeros((hw[0], hw[1], num_classes), np.float32) for hw in hw_list]

        for i, (row_starts, col_starts) in enumerate(zip(height_starts_list, width_starts_list)):
            crop_idx = 0
            for top in row_starts:
                for left in col_starts:
                    total_logits[i][top:top + crop_h, left:left + crop_w] += out_list[i][crop_idx]
                    count_predictions[i][top:top + crop_h, left:left + crop_w] += 1
                    crop_idx += 1

        # In-place division turns the accumulated sums into per-position means.
        for logits, counts in zip(total_logits, count_predictions):
            logits /= counts

        for i, meta in enumerate(DCHelper.tolist(data_dict['meta'])):
            border_w, border_h = meta['border_wh'][0], meta['border_wh'][1]
            total_logits[i] = cv2.resize(total_logits[i][:border_h, :border_w],
                                         tuple(meta['ori_img_size']), interpolation=cv2.INTER_CUBIC)

        return total_logits
    def _crop_predict(self, data_dict, crop_size, crop_stride_ratio):
        """Predict each image crop by crop, forwarding the crops in chunks.

        Each image from ``data_dict['img']`` is cut into windows of
        ``crop_size[1]`` rows by ``crop_size[0]`` columns at the offsets
        returned by ``self._decide_intersection``. All crops of all images are
        collected into one flat list. If there are more than 64 crops they are
        forwarded through ``self.seg_net`` in chunks of at most 64 and the
        outputs are moved to the CPU after every chunk; otherwise they are
        forwarded in a single call. The ``'out'`` predictions are summed back
        at their offsets, divided by the number of crops covering each
        position, cut to ``meta['border_wh']`` and resized to
        ``meta['ori_img_size']`` with cubic interpolation.

        Args:
            data_dict: mapping with ``'img'`` and ``'meta'`` entries that
                ``DCHelper.tolist`` can unpack.
            crop_size: indexable as ``(width, height)`` of a crop.
            crop_stride_ratio: passed through to ``self._decide_intersection``.

        Returns:
            list with one ``numpy`` array of averaged predictions per image.
        """
        crop_w, crop_h = crop_size[0], crop_size[1]

        split_batch = []
        height_starts_list = []
        width_starts_list = []
        hw_list = []
        for image in DCHelper.tolist(data_dict['img']):
            img_h, img_w = image.size()[1:]
            hw_list.append([img_h, img_w])
            np_image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
            height_starts = self._decide_intersection(img_h, crop_h,
                                                      crop_stride_ratio)
            width_starts = self._decide_intersection(img_w, crop_w,
                                                     crop_stride_ratio)
            # Row-major order over the offsets; the accumulation loop below
            # consumes the network outputs in this same order.
            split_crops = [
                np_image[top:top + crop_h, left:left + crop_w][np.newaxis, :]
                for top in height_starts
                for left in width_starts
            ]

            height_starts_list.append(height_starts)
            width_starts_list.append(width_starts)
            # (n, crop_image_size, crop_image_size, 3)
            split_crops = np.concatenate(split_crops, axis=0)
            inputs = torch.from_numpy(split_crops).permute(0, 3, 1,
                                                           2).to(self.device)
            # One list entry per crop, so the crops can be chunked below.
            split_batch.extend(list(inputs))

        out_list = []
        # NOTE(review): debug output kept as-is; consider routing it through
        # the project's logger instead of stdout.
        print(data_dict['img'].data[0].shape, len(split_batch))
        with torch.no_grad():
            _len_base = 64  # maximum number of crops per forward pass
            if len(split_batch) > _len_base:
                results = []
                # Step up to len(split_batch) itself: stopping at
                # ``len(split_batch) - 1`` skipped the final crop whenever the
                # crop count was a multiple of the chunk size plus one.
                for start in range(0, len(split_batch), _len_base):
                    torch.cuda.empty_cache()
                    # Slicing clamps at the end of the list on its own.
                    tmp_results = self.seg_net(
                        dict(img=DCHelper.todc(
                            split_batch[start:start + _len_base],
                            stack=True,
                            samples_per_gpu=True)))
                    results.append(
                        torch.cat([
                            ele['out'].detach().cpu().permute(0, 2, 3, 1)
                            for ele in tmp_results
                        ]))
                    del tmp_results

                results = torch.cat(results)
                # Regroup the flat crop axis into (image, crop, ...); this
                # requires every image to have produced the same crop count.
                results = results.view(len(height_starts_list), -1,
                                       results.shape[1], results.shape[2],
                                       results.shape[3])
                out_list = [
                    results[i].numpy() for i in range(len(height_starts_list))
                ]
            else:
                results = self.seg_net(
                    dict(img=DCHelper.todc(
                        split_batch, stack=True, samples_per_gpu=True)))
                for res in results:
                    out_list.append(res['out'].detach().permute(
                        0, 2, 3, 1).cpu().numpy())

        num_classes = self.configer.get('data', 'num_classes')
        total_logits = [
            np.zeros((hw[0], hw[1], num_classes), np.float32)
            for hw in hw_list
        ]
        count_predictions = [
            np.zeros((hw[0], hw[1], num_classes), np.float32)
            for hw in hw_list
        ]
        for i, (row_starts, col_starts) in enumerate(
                zip(height_starts_list, width_starts_list)):
            index = 0
            for top in row_starts:
                for left in col_starts:
                    total_logits[i][top:top + crop_h,
                                    left:left + crop_w] += out_list[i][index]
                    count_predictions[i][top:top + crop_h,
                                         left:left + crop_w] += 1
                    index += 1

        # In-place division turns the accumulated sums into per-position means.
        for logits, counts in zip(total_logits, count_predictions):
            logits /= counts

        for i, meta in enumerate(DCHelper.tolist(data_dict['meta'])):
            total_logits[i] = cv2.resize(
                total_logits[i][:meta['border_wh'][1], :meta['border_wh'][0]],
                tuple(meta['ori_img_size']),
                interpolation=cv2.INTER_CUBIC)

        return total_logits