def get_blob(self, data_dict, scale=None, flip=False):
    """Resize every image by ``scale``, optionally mirror it, and optionally pad it to a stride.

    Args:
        data_dict: dict whose 'img' and 'meta' entries are unpacked with
            ``DCHelper.tolist`` and zipped pairwise; each image is indexed as
            (C, H, W) (``c, h, w = image.size()``).
        scale: required resize factor applied to both height and width.
        flip: when True, each resized image is mirrored along dim 2 (width).

    Returns:
        dict with 'img' (built by ``DCHelper.todc`` with ``stack=True``) and
        'meta' (built by ``DCHelper.todc`` with ``cpu_only=True``).

    Side effects:
        Stores the scaled size ``[int(H * scale), int(W * scale)]`` under the
        key 'border_hw' of each meta object yielded by ``DCHelper.tolist``.
    """
    assert scale is not None
    # Config is loop-invariant: read it once instead of once per image.
    pad_to_stride = self.configer.get('test.fit_stride', default=0) > 0
    stride = self.configer.get('test', 'fit_stride') if pad_to_stride else None

    img_list, meta_list = [], []
    for image, meta in zip(DCHelper.tolist(data_dict['img']), DCHelper.tolist(data_dict['meta'])):
        c, h, w = image.size()
        border_hw = [int(h * scale), int(w * scale)]
        meta['border_hw'] = border_hw
        image = TensorHelper.resize(image, border_hw, mode='bilinear', align_corners=True)
        if flip:
            image = image.flip([2])
        if pad_to_stride:
            # Zero-pad at the bottom/right so each side becomes a multiple of `stride`.
            pad_h = (-border_hw[0]) % stride
            pad_w = (-border_hw[1]) % stride
            # new_zeros allocates directly on image's device with image's dtype,
            # avoiding a host allocation + device copy and a silent cast to float32.
            expand_image = image.new_zeros((c, border_hw[0] + pad_h, border_hw[1] + pad_w))
            expand_image[:, :border_hw[0], :border_hw[1]] = image
            image = expand_image
        img_list.append(image)
        meta_list.append(meta)

    new_data_dict = dict(
        img=DCHelper.todc(img_list, stack=True, samples_per_gpu=True),
        meta=DCHelper.todc(meta_list, samples_per_gpu=True, cpu_only=True),
    )
    return new_data_dict
def _crop_predict(self, data_dict, crop_size, crop_stride_ratio):
    """Predict full-image logits by averaging overlapping sliding-window crops.

    Single-pass variant: every image's crops go through ``self.seg_net`` in a
    single call. The method asserts that the image count equals
    ``torch.cuda.device_count()``.

    NOTE(review): a second ``_crop_predict`` with the same signature follows
    this one in the file. If both live in the same class, Python keeps only the
    later definition, and this one is dead code.

    Args:
        data_dict: dict whose 'img' and 'meta' entries are unpacked with
            ``DCHelper.tolist``; each image is indexed as (C, H, W).
        crop_size: crop extent as (width, height). Index 0 is applied to the
            width axis and index 1 to the height axis.
        crop_stride_ratio: forwarded unchanged to ``self._decide_intersection``.

    Returns:
        list with one float32 numpy array per image. Each array is the averaged
        logits, trimmed to ``meta['border_wh']`` and resized with
        ``cv2.resize`` to ``meta['ori_img_size']``.
    """
    num_classes = self.configer.get('data', 'num_classes')
    crop_w, crop_h = crop_size[0], crop_size[1]

    # Cut every image into crops, remembering the crop offsets for reassembly.
    split_batch = []
    hw_list, height_starts_list, width_starts_list = [], [], []
    for image in DCHelper.tolist(data_dict['img']):
        img_h, img_w = image.size()[1:]
        hw_list.append([img_h, img_w])
        height_starts = self._decide_intersection(img_h, crop_h, crop_stride_ratio)
        width_starts = self._decide_intersection(img_w, crop_w, crop_stride_ratio)
        height_starts_list.append(height_starts)
        width_starts_list.append(width_starts)
        # Slice the tensor directly. This avoids a device -> numpy -> device round
        # trip and the old `squeeze(0)`, which dropped the channel axis when C == 1.
        crops = [image[:, y:y + crop_h, x:x + crop_w]
                 for y in height_starts for x in width_starts]
        # All crops of this image, stacked along a new leading axis.
        split_batch.append(torch.stack(crops).to(self.device))
    assert len(split_batch) == torch.cuda.device_count(), 'Only support one image per gpu.'

    with torch.no_grad():
        results = self.seg_net(dict(img=DCHelper.todc(
            split_batch, stack=False, samples_per_gpu=True, concat=True)))
        # seg_net may return a single output dict or a list/tuple of them.
        results = results if isinstance(results, (list, tuple)) else [results]
        # Channels-last logits; presumably one entry per image given the assert above.
        out_list = [res['out'].permute(0, 2, 3, 1).cpu().numpy() for res in results]

    # Sum crop logits back into image space and divide by per-pixel coverage.
    # Assumes seg_net preserves each crop's spatial size.
    total_logits = []
    for i, (img_h, img_w) in enumerate(hw_list):
        logit_sum = np.zeros((img_h, img_w, num_classes), np.float32)
        # Coverage is identical for every class, so one channel is enough;
        # it broadcasts over the class axis in the division below.
        coverage = np.zeros((img_h, img_w, 1), np.float32)
        windows = ((y, x) for y in height_starts_list[i] for x in width_starts_list[i])
        for index, (y, x) in enumerate(windows):
            logit_sum[y:y + crop_h, x:x + crop_w] += out_list[i][index]
            coverage[y:y + crop_h, x:x + crop_w] += 1
        logit_sum /= coverage
        total_logits.append(logit_sum)

    for i, meta in enumerate(DCHelper.tolist(data_dict['meta'])):
        # Index 1 trims rows (height) and index 0 trims columns (width).
        border_w, border_h = meta['border_wh'][0], meta['border_wh'][1]
        # cv2.resize interprets dsize as (width, height).
        total_logits[i] = cv2.resize(
            total_logits[i][:border_h, :border_w],
            tuple(meta['ori_img_size']),
            interpolation=cv2.INTER_CUBIC)
    return total_logits
def _crop_predict(self, data_dict, crop_size, crop_stride_ratio):
    """Predict full-image logits by averaging overlapping sliding-window crops.

    Each image is cut into crops at the offsets returned by
    ``self._decide_intersection``. All crops are run through ``self.seg_net``
    in chunks of at most 64. The per-crop logits are summed back into image
    space and divided by how many crops covered each pixel. The result is then
    trimmed to ``meta['border_wh']`` and resized with ``cv2.resize`` to
    ``meta['ori_img_size']``.

    Args:
        data_dict: dict whose 'img' and 'meta' entries are unpacked with
            ``DCHelper.tolist``; each image is indexed as (C, H, W).
        crop_size: crop extent as (width, height). Index 0 is applied to the
            width axis and index 1 to the height axis.
        crop_stride_ratio: forwarded unchanged to ``self._decide_intersection``.

    Returns:
        list with one float32 numpy array per image (averaged logits after
        trimming and resizing); an empty list when there are no images.
    """
    num_classes = self.configer.get('data', 'num_classes')
    crop_w, crop_h = crop_size[0], crop_size[1]
    max_crops_per_pass = 64  # caps how many crops go through seg_net at once

    # Cut every image into crops, remembering offsets and per-image crop counts.
    split_batch = []
    hw_list, height_starts_list, width_starts_list, crop_counts = [], [], [], []
    for image in DCHelper.tolist(data_dict['img']):
        img_h, img_w = image.size()[1:]
        hw_list.append([img_h, img_w])
        height_starts = self._decide_intersection(img_h, crop_h, crop_stride_ratio)
        width_starts = self._decide_intersection(img_w, crop_w, crop_stride_ratio)
        height_starts_list.append(height_starts)
        width_starts_list.append(width_starts)
        # Slice the tensor directly. This avoids a device -> numpy -> device round
        # trip and the old `squeeze(0)`, which dropped the channel axis when C == 1.
        crops = [image[:, y:y + crop_h, x:x + crop_w]
                 for y in height_starts for x in width_starts]
        crop_counts.append(len(crops))
        split_batch.extend(torch.stack(crops).to(self.device))

    if not split_batch:
        return []

    # Run the network over the crops, at most max_crops_per_pass per call.
    chunked = len(split_batch) > max_crops_per_pass
    crop_outputs = []
    with torch.no_grad():
        # Step over the full length. The previous `len(split_batch) - 1` bound
        # skipped the last crop whenever len % 64 == 1.
        for start in range(0, len(split_batch), max_crops_per_pass):
            if chunked:
                torch.cuda.empty_cache()
            results = self.seg_net(dict(img=DCHelper.todc(
                split_batch[start:start + max_crops_per_pass],
                stack=True, samples_per_gpu=True)))
            # seg_net may return a single output dict or a list/tuple of them.
            if not isinstance(results, (list, tuple)):
                results = [results]
            crop_outputs.extend(
                res['out'].detach().cpu().permute(0, 2, 3, 1) for res in results)
            del results

    # Channels-last logits for every crop, in crop order. Split them back per
    # image by crop count; this works even when images have different crop
    # counts, and raises if the counts do not add up.
    all_logits = torch.cat(crop_outputs)
    out_list = [part.numpy() for part in torch.split(all_logits, crop_counts)]

    # Sum crop logits back into image space and divide by per-pixel coverage.
    # Assumes seg_net preserves each crop's spatial size.
    total_logits = []
    for (img_h, img_w), height_starts, width_starts, crop_logits in zip(
            hw_list, height_starts_list, width_starts_list, out_list):
        logit_sum = np.zeros((img_h, img_w, num_classes), np.float32)
        # Coverage is identical for every class, so one channel is enough;
        # it broadcasts over the class axis in the division below.
        coverage = np.zeros((img_h, img_w, 1), np.float32)
        windows = ((y, x) for y in height_starts for x in width_starts)
        for index, (y, x) in enumerate(windows):
            logit_sum[y:y + crop_h, x:x + crop_w] += crop_logits[index]
            coverage[y:y + crop_h, x:x + crop_w] += 1
        logit_sum /= coverage
        total_logits.append(logit_sum)

    for i, meta in enumerate(DCHelper.tolist(data_dict['meta'])):
        # NOTE(review): get_blob stores the scaled size under 'border_hw' ([h, w]).
        # This method reads 'border_wh' ([w, h], per the indexing below). Confirm
        # which key matches the tensor being predicted when scale != 1.
        border_w, border_h = meta['border_wh'][0], meta['border_wh'][1]
        # cv2.resize interprets dsize as (width, height).
        total_logits[i] = cv2.resize(
            total_logits[i][:border_h, :border_w],
            tuple(meta['ori_img_size']),
            interpolation=cv2.INTER_CUBIC)
    return total_logits