def test_single_gpu_test_kie_novisual(cfg_file):
    """Smoke-test ``single_gpu_test`` in KIE mode on two dataloaders.

    The config's last test-pipeline step is extended with extra meta keys,
    then the test is run once with ``empty_img=True`` and once with the
    dataloader helper's defaults.

    Args:
        cfg_file (str): Config path, relative to the parent directory of the
            directory containing this file.
    """
    this_dir = os.path.dirname(__file__)
    curr_dir = os.path.abspath(os.path.dirname(this_dir))
    cfg = Config.fromfile(os.path.join(curr_dir, cfg_file))

    # Make the final pipeline step also collect these meta keys.
    final_step = cfg.data.test.pipeline[-1]
    extra_meta_keys = ['img_norm_cfg', 'ori_filename', 'img_shape']
    final_step['meta_keys'] = tuple(
        list(final_step['meta_keys']) + extra_meta_keys)

    with tempfile.TemporaryDirectory() as tmpdirname:
        out_dir = osp.join(tmpdirname, 'tmp')
        # First pass requests empty images, second pass uses the defaults.
        for loader_kwargs in ({'empty_img': True}, {}):
            model, data_loader = gene_sdmgr_model_dataloader(
                cfg, out_dir, curr_dir, **loader_kwargs)
            results = single_gpu_test(
                model, data_loader, out_dir=out_dir, is_kie=True)
            assert check_argument.is_type_list(results, dict)
def get_boundary(self, edges, scores, text_comps, img_metas, rescale):
    """Decode text boundaries from graph edges and text components.

    Args:
        edges (ndarray): Edge array of shape N * 2; each row is a pair of
            text component indices forming one graph edge. May be None, in
            which case no boundary is decoded.
        scores (ndarray): The edge score array.
        text_comps (ndarray): The text components.
        img_metas (list[dict]): The image meta infos; ``scale_factor`` of
            the first entry is read when ``rescale`` is True.
        rescale (bool): Whether to rescale boundaries to the original image
            resolution.

    Returns:
        dict: A dict holding the boundaries under ``boundary_result``.
    """
    assert check_argument.is_type_list(img_metas, dict)
    assert isinstance(rescale, bool)

    if edges is None:
        boundaries = []
    else:
        boundaries = decode(
            decoding_type='drrg',
            edges=edges,
            scores=scores,
            text_comps=text_comps,
            link_thr=self.link_thr)

    if rescale:
        # Undo the head's downsampling and the test-time image resize.
        scale = 1.0 / self.downsample_ratio / img_metas[0]['scale_factor']
        boundaries = self.resize_boundary(boundaries, scale)

    return dict(boundary_result=boundaries)
def get_boundary(self, score_maps, img_metas, rescale):
    """Decode text boundaries from the predicted score maps.

    Args:
        score_maps (Tensor): The text score map; it is squeezed before
            decoding.
        img_metas (list[dict]): The image meta infos. ``filename`` of the
            first entry is always read; ``scale_factor`` is read when
            ``rescale`` is True.
        rescale (bool): Rescale boundaries to the original image resolution
            if True, otherwise keep the score map resolution.

    Returns:
        dict: A dict with the boundaries under ``boundary_result`` and the
        image file name under ``filename``.
    """
    assert check_argument.is_type_list(img_metas, dict)
    assert isinstance(rescale, bool)

    squeezed_maps = score_maps.squeeze()
    boundaries = decode(
        decoding_type=self.decoding_type,
        preds=squeezed_maps,
        text_repr_type=self.text_repr_type)

    if rescale:
        # Undo the head's downsampling and the test-time image resize.
        scale = 1.0 / self.downsample_ratio / img_metas[0]['scale_factor']
        boundaries = self.resize_boundary(boundaries, scale)

    return dict(
        boundary_result=boundaries, filename=img_metas[0]['filename'])
def trace_boundary(char_boxes):
    """Trace the boundary points of one text from its char boxes.

    Args:
        char_boxes (list[ndarray]): The char boxes for one text. Each
            element is a 4x2 ndarray.

    Returns:
        ndarray: The boundary point set of size nx2, cast to int.
    """
    assert check_argument.is_type_list(char_boxes, np.ndarray)

    # Rows 0-1 of every box, walking the boxes from first to last.
    top_points = [box[0:2] for box in char_boxes]
    # Rows 2-3 of every box, walking the boxes from last to first.
    bottom_points = [box[[2, 3], :] for box in reversed(char_boxes)]

    return np.concatenate(top_points + bottom_points).astype(int)
def __init__(
        self,
        in_channels,
        out_channels,
        text_repr_type='poly',  # 'poly' or 'quad'
        downsample_ratio=0.25,
        loss=dict(type='PANLoss'),
        train_cfg=None,
        test_cfg=None):
    """Build the head: loss module, decoding type and output conv.

    Args:
        in_channels (list[int]): Channel counts of the input feature maps;
            their sum is the input width of the output convolution.
        out_channels (int): Number of output channels.
        text_repr_type (str): Either 'poly' or 'quad'.
        downsample_ratio (float): Value within [0, 1].
        loss (dict): Loss config; its ``type`` must be 'PANLoss' or
            'PSELoss'.
        train_cfg: Stored on the instance as-is.
        test_cfg: Stored on the instance as-is.

    Raises:
        NotImplementedError: If ``loss['type']`` is not supported.
    """
    super().__init__()

    assert check_argument.is_type_list(in_channels, int)
    assert isinstance(out_channels, int)
    assert text_repr_type in ('poly', 'quad')
    assert 0 <= downsample_ratio <= 1

    self.loss_module = build_loss(loss)

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.text_repr_type = text_repr_type
    self.train_cfg = train_cfg
    self.test_cfg = test_cfg
    self.downsample_ratio = downsample_ratio

    # The decoding scheme follows from the configured loss.
    decoding_by_loss = {'PANLoss': 'pan', 'PSELoss': 'pse'}
    loss_type = loss['type']
    if loss_type not in decoding_by_loss:
        raise NotImplementedError(f'unsupported loss type {loss_type}.')
    self.decoding_type = decoding_by_loss[loss_type]

    total_in_channels = np.sum(np.array(in_channels))
    self.out_conv = nn.Conv2d(
        in_channels=total_in_channels,
        out_channels=out_channels,
        kernel_size=1)
    self.init_weights()
def test_single_gpu_test_kie(cfg_file):
    """Smoke-test ``single_gpu_test`` in KIE mode.

    Args:
        cfg_file (str): Config path, relative to the parent directory of the
            directory containing this file.
    """
    this_dir = os.path.dirname(__file__)
    curr_dir = os.path.abspath(os.path.dirname(this_dir))
    cfg = Config.fromfile(os.path.join(curr_dir, cfg_file))

    with tempfile.TemporaryDirectory() as tmpdirname:
        out_dir = osp.join(tmpdirname, 'tmp')
        model, data_loader = gene_sdmgr_model_dataloader(
            cfg, out_dir, curr_dir)
        results = single_gpu_test(
            model, data_loader, out_dir=out_dir, is_kie=True)
        # Every per-sample result must be a dict.
        assert check_argument.is_type_list(results, dict)
def test_single_gpu_test_det(cfg_file):
    """Smoke-test ``single_gpu_test`` with a detection model on toy data.

    Args:
        cfg_file (str): Config path, relative to the parent directory of the
            directory containing this file.
    """
    this_dir = os.path.dirname(__file__)
    curr_dir = os.path.abspath(os.path.dirname(this_dir))
    cfg = Config.fromfile(os.path.join(curr_dir, cfg_file))

    model = build_model(cfg)
    data_loader = generate_sample_dataloader(
        cfg,
        curr_dir,
        'data/toy_dataset/imgs',
        'data/toy_dataset/instances_test.json')

    with tempfile.TemporaryDirectory() as tmpdirname:
        out_dir = osp.join(tmpdirname, 'tmp')
        results = single_gpu_test(model, data_loader, out_dir=out_dir)
        # Every per-sample result must be a dict.
        assert check_argument.is_type_list(results, dict)
def generate_kernels(self,
                     resize_shape,
                     pad_shape,
                     char_boxes,
                     char_inds,
                     shrink_ratio=0.5,
                     binary=True):
    """Draw shrunk char polygons into a kernel mask for one shrink ratio.

    Args:
        resize_shape (tuple(int, int)): Image size (height, width) after
            resizing.
        pad_shape (tuple(int, int)): Image size (height, width) after
            padding; this is the shape of the returned mask.
        char_boxes (list[list[float]]): The list of char polygons.
        char_inds (list[int]): List of char indexes, used as fill values
            when ``binary`` is False.
        shrink_ratio (float): The shrink ratio of kernel.
        binary (bool): If True, every char is filled with 1; otherwise each
            char is filled with its entry of ``char_inds``.

    Returns:
        ndarray: The int32 kernel mask of shape ``pad_shape``.
    """
    assert isinstance(resize_shape, tuple)
    assert isinstance(pad_shape, tuple)
    assert check_argument.is_2dlist(char_boxes)
    assert check_argument.is_type_list(char_inds, int)
    assert isinstance(shrink_ratio, float)
    assert isinstance(binary, bool)

    kernel_mask = np.zeros(pad_shape, dtype=np.int32)
    # Columns past the resized width, within the resized height, are
    # marked with the pad value.
    kernel_mask[:resize_shape[0], resize_shape[1]:] = self.pad_val

    for box_idx, box in enumerate(char_boxes):
        if self.box_type == 'char_rects':
            shrunk_poly = self.shrink_char_rect(box, shrink_ratio)
        elif self.box_type == 'char_quads':
            shrunk_poly = self.shrink_char_quad(box, shrink_ratio)

        if binary:
            fill_value = 1
        else:
            fill_value = char_inds[box_idx]
        cv2.fillConvexPoly(kernel_mask, shrunk_poly.astype(np.int32),
                           (fill_value))

    return kernel_mask
def __init__(self,
             in_channels,
             out_channels,
             downsample_ratio=0.25,
             loss=dict(type='PANLoss'),
             postprocessor=dict(
                 type='PANPostprocessor', text_repr_type='poly'),
             train_cfg=None,
             test_cfg=None,
             init_cfg=dict(
                 type='Normal',
                 mean=0,
                 std=0.01,
                 override=dict(name='out_conv')),
             **kwargs):
    """Build the head: loss/postprocessor mixin and the output conv.

    Args:
        in_channels (list[int]): Channel counts of the input feature maps;
            their sum is the input width of the output convolution.
        out_channels (int): Number of output channels.
        downsample_ratio (float): Value within [0, 1].
        loss (dict): Loss config, forwarded to ``HeadMixin.__init__``.
        postprocessor (dict): Postprocessor config, forwarded to
            ``HeadMixin.__init__``. It is never modified in place.
        train_cfg: Stored on the instance as-is.
        test_cfg: Stored on the instance as-is.
        init_cfg (dict): Forwarded to ``BaseModule.__init__``.
        **kwargs: Only the deprecated keys ``text_repr_type`` and
            ``decoding_type`` are read; truthy values override the
            matching postprocessor entries and emit a ``UserWarning``.
    """
    old_keys = ['text_repr_type', 'decoding_type']
    overridden = False
    for key in old_keys:
        value = kwargs.get(key, None)
        if not value:
            continue
        if not overridden:
            # Copy before the first write: ``postprocessor`` may be the
            # shared default dict of this signature (or the caller's own
            # dict), and mutating it would leak the override into every
            # later construction.
            postprocessor = postprocessor.copy()
            overridden = True
        postprocessor[key] = value
        warnings.warn(
            f'{key} is deprecated, please specify '
            'it in postprocessor config dict. See '
            'https://github.com/open-mmlab/mmocr/pull/640'
            ' for details.', UserWarning)

    BaseModule.__init__(self, init_cfg=init_cfg)
    HeadMixin.__init__(self, loss, postprocessor)

    assert check_argument.is_type_list(in_channels, int)
    assert isinstance(out_channels, int)
    assert 0 <= downsample_ratio <= 1

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.downsample_ratio = downsample_ratio
    self.train_cfg = train_cfg
    self.test_cfg = test_cfg

    # Cast to a plain int so the conv stores a Python int, not np.int64.
    self.out_conv = nn.Conv2d(
        in_channels=int(np.sum(np.array(in_channels))),
        out_channels=out_channels,
        kernel_size=1)
def sort_points(points):
    """Sort arbitrary 2-D points in clockwise order.

    Reference: https://stackoverflow.com/a/6989383.

    Args:
        points (list[ndarray] or ndarray or list[list]): A list of unsorted
            boundary points.

    Returns:
        list[ndarray]: A list of points sorted in clockwise order.
    """
    assert is_type_list(points, np.ndarray) or isinstance(points, np.ndarray) \
        or is_2dlist(points)

    points = np.array(points)
    center = np.mean(points, axis=0)

    def cmp(a, b):
        oa = a - center
        ob = b - center

        # Points on opposite sides of the vertical line through the center
        # are ordered by side alone.
        if oa[0] >= 0 and ob[0] < 0:
            return 1
        if oa[0] < 0 and ob[0] >= 0:
            return -1

        # z-component of the 2-D cross product, written out explicitly:
        # np.cross on 2-element vectors is deprecated since NumPy 2.0.
        prod = oa[0] * ob[1] - oa[1] * ob[0]
        if prod > 0:
            return 1
        if prod < 0:
            return -1

        # Identical points compare equal so the ordering stays consistent.
        if np.array_equal(oa, ob):
            return 0

        # a, b are on the same line from the center
        return 1 if (oa**2).sum() < (ob**2).sum() else -1

    return sorted(points, key=functools.cmp_to_key(cmp))
def bitmasks2tensor(self, bitmasks, target_sz):
    """Convert per-image BitmapMasks into per-level batched tensors.

    Args:
        bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
            for one img.
        target_sz (tuple(int, int)): The target tensor of size
            :math:`(H, W)`.

    Returns:
        list[Tensor]: The list of kernel tensors. Each element stands for
        one kernel level and stacks that level's masks of all images.
    """
    assert check_argument.is_type_list(bitmasks, BitmapMasks)
    assert isinstance(target_sz, tuple)

    num_levels = len(bitmasks[0])

    def pad_to_target(array):
        """Zero-pad one h x w mask on the right/bottom up to target_sz."""
        mask = torch.from_numpy(array)
        height, width = mask.shape[0], mask.shape[1]
        # F.pad order: left, right, top, bottom
        padding = [0, target_sz[1] - width, 0, target_sz[0] - height]
        return F.pad(mask, padding, mode='constant', value=0)

    return [
        torch.stack(
            [pad_to_target(per_img.masks[level]) for per_img in bitmasks])
        for level in range(num_levels)
    ]
def forward(self, preds, downsample_ratio, gt_text_mask,
            gt_center_region_mask, gt_mask, gt_top_height_map,
            gt_bot_height_map, gt_sin_map, gt_cos_map):
    """Compute Drrg loss.

    Args:
        preds (tuple): A pair ``(pred_maps, gcn_data)``. ``pred_maps`` is
            the prediction map with shape :math:`(N, C_{out}, H, W)`;
            channels 0-5 are read as text region, center region, sin, cos,
            top height and bottom height. ``gcn_data`` is passed unchanged
            to ``self.gcn_loss`` (presumably the GCN predictions together
            with their labels -- confirm against ``gcn_loss``).
        downsample_ratio (float): The downsample ratio.
        gt_text_mask (list[BitmapMasks]): Text mask.
        gt_center_region_mask (list[BitmapMasks]): Center region mask.
        gt_mask (list[BitmapMasks]): Effective mask.
        gt_top_height_map (list[BitmapMasks]): Top height map.
        gt_bot_height_map (list[BitmapMasks]): Bottom height map.
        gt_sin_map (list[BitmapMasks]): Sinusoid map.
        gt_cos_map (list[BitmapMasks]): Cosine map.

    Returns:
        dict: A loss dict with ``loss_text``, ``loss_center``,
        ``loss_height``, ``loss_sin``, ``loss_cos``, and ``loss_gcn``.
    """
    assert isinstance(preds, tuple)
    assert isinstance(downsample_ratio, float)

    assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
    assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
    assert check_argument.is_type_list(gt_mask, BitmapMasks)
    assert check_argument.is_type_list(gt_top_height_map, BitmapMasks)
    assert check_argument.is_type_list(gt_bot_height_map, BitmapMasks)
    assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
    assert check_argument.is_type_list(gt_cos_map, BitmapMasks)

    pred_maps, gcn_data = preds
    pred_text_region = pred_maps[:, 0, :, :]
    pred_center_region = pred_maps[:, 1, :, :]
    pred_sin_map = pred_maps[:, 2, :, :]
    pred_cos_map = pred_maps[:, 3, :, :]
    pred_top_height_map = pred_maps[:, 4, :, :]
    pred_bot_height_map = pred_maps[:, 5, :, :]
    feature_sz = pred_maps.size()
    device = pred_maps.device

    # bitmask 2 tensor
    mapping = {
        'gt_text_mask': gt_text_mask,
        'gt_center_region_mask': gt_center_region_mask,
        'gt_mask': gt_mask,
        'gt_top_height_map': gt_top_height_map,
        'gt_bot_height_map': gt_bot_height_map,
        'gt_sin_map': gt_sin_map,
        'gt_cos_map': gt_cos_map
    }
    # A ratio within 1e-2 of 1.0 is treated as "no rescaling needed".
    keep_resolution = abs(downsample_ratio - 1.0) < 1e-2
    height_keys = ('gt_top_height_map', 'gt_bot_height_map')
    gt = {}
    for key, bitmasks in mapping.items():
        if not keep_resolution:
            bitmasks = [item.rescale(downsample_ratio) for item in bitmasks]
        tensors = self.bitmasks2tensor(bitmasks, feature_sz[2:])
        if not keep_resolution and key in height_keys:
            # Height values shrink together with the map resolution.
            tensors = [item * downsample_ratio for item in tensors]
        gt[key] = [item.to(device) for item in tensors]

    # Normalize (sin, cos) to unit length; 1e-8 avoids division by zero.
    scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
    pred_sin_map = pred_sin_map * scale
    pred_cos_map = pred_cos_map * scale

    loss_text = self.balance_bce_loss(
        torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
        gt['gt_mask'][0])

    text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
    negative_text_mask = ((1 - gt['gt_text_mask'][0]) *
                          gt['gt_mask'][0]).float()
    loss_center_map = F.binary_cross_entropy(
        torch.sigmoid(pred_center_region),
        gt['gt_center_region_mask'][0].float(),
        reduction='none')
    if int(text_mask.sum()) > 0:
        loss_center_positive = torch.sum(
            loss_center_map * text_mask) / torch.sum(text_mask)
    else:
        loss_center_positive = torch.tensor(0.0, device=device)
    # Guard the negative term like the positive one: with no valid
    # non-text pixel the unguarded 0 / 0 would turn the loss into NaN.
    if int(negative_text_mask.sum()) > 0:
        loss_center_negative = torch.sum(
            loss_center_map *
            negative_text_mask) / torch.sum(negative_text_mask)
    else:
        loss_center_negative = torch.tensor(0.0, device=device)
    loss_center = loss_center_positive + 0.5 * loss_center_negative

    center_mask = (gt['gt_center_region_mask'][0] *
                   gt['gt_mask'][0]).float()
    if int(center_mask.sum()) > 0:
        map_sz = pred_top_height_map.size()
        ones = torch.ones(map_sz, dtype=torch.float, device=device)
        # Heights are regressed as a ratio to the ground truth (target 1).
        loss_top = F.smooth_l1_loss(
            pred_top_height_map / (gt['gt_top_height_map'][0] + 1e-2),
            ones,
            reduction='none')
        loss_bot = F.smooth_l1_loss(
            pred_bot_height_map / (gt['gt_bot_height_map'][0] + 1e-2),
            ones,
            reduction='none')
        gt_height = (
            gt['gt_top_height_map'][0] + gt['gt_bot_height_map'][0])
        loss_height = torch.sum(
            (torch.log(gt_height + 1) *
             (loss_top + loss_bot)) * center_mask) / torch.sum(center_mask)

        loss_sin = torch.sum(
            F.smooth_l1_loss(
                pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
            center_mask) / torch.sum(center_mask)
        loss_cos = torch.sum(
            F.smooth_l1_loss(
                pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
            center_mask) / torch.sum(center_mask)
    else:
        loss_height = torch.tensor(0.0, device=device)
        loss_sin = torch.tensor(0.0, device=device)
        loss_cos = torch.tensor(0.0, device=device)

    loss_gcn = self.gcn_loss(gcn_data)

    results = dict(
        loss_text=loss_text,
        loss_center=loss_center,
        loss_height=loss_height,
        loss_sin=loss_sin,
        loss_cos=loss_cos,
        loss_gcn=loss_gcn)

    return results
def forward(self, score_maps, downsample_ratio, gt_kernels, gt_mask):
    """Compute PSENet loss.

    Args:
        score_maps (tensor): The output tensor with size of Nx6xHxW.
            Channel 0 is the text prediction, the rest are kernels.
        downsample_ratio (float): The downsample ratio between score_maps
            and the input img.
        gt_kernels (list[BitmapMasks]): The kernel list with each element
            being the text kernel mask for one img.
        gt_mask (list[BitmapMasks]): The effective mask list with each
            element being the effective mask for one img.

    Returns:
        dict: The losses under ``loss_text`` and ``loss_kernel``.
    """
    assert check_argument.is_type_list(gt_kernels, BitmapMasks)
    assert check_argument.is_type_list(gt_mask, BitmapMasks)
    assert isinstance(downsample_ratio, float)

    text_logits = score_maps[:, 0, :, :]
    kernel_logits = score_maps[:, 1:, :, :]
    target_hw = score_maps.size()[2:]

    def to_level_tensors(bitmasks):
        """Rescale, pad to the feature size and move to the map's device."""
        rescaled = [item.rescale(downsample_ratio) for item in bitmasks]
        padded = self.bitmasks2tensor(rescaled, target_hw)
        return [item.to(score_maps.device) for item in padded]

    gt_kernels = to_level_tensors(gt_kernels)
    gt_mask = to_level_tensors(gt_mask)

    # compute text loss
    text_sample_mask = self.ohem_batch(text_logits.detach(), gt_kernels[0],
                                       gt_mask[0])
    text_loss = self.alpha * self.dice_loss_with_logits(
        text_logits, gt_kernels[0], text_sample_mask)

    # compute kernel loss
    sample_type = self.kernel_sample_type
    if sample_type == 'hard':
        kernel_sample_mask = (gt_kernels[0] > 0.5).float() * (
            gt_mask[0].float())
    elif sample_type == 'adaptive':
        kernel_sample_mask = (text_logits > 0).float() * (
            gt_mask[0].float())
    else:
        raise NotImplementedError

    num_kernel = kernel_logits.shape[1]
    assert num_kernel == len(gt_kernels) - 1
    per_kernel_losses = [
        self.dice_loss_with_logits(kernel_logits[:, idx, :, :],
                                   gt_kernels[1 + idx], kernel_sample_mask)
        for idx in range(num_kernel)
    ]
    kernel_loss = (1 - self.alpha) * sum(per_kernel_losses) / len(
        per_kernel_losses)

    unreduced = [text_loss, kernel_loss]
    if self.reduction == 'mean':
        reduced = [item.mean() for item in unreduced]
    elif self.reduction == 'sum':
        reduced = [item.sum() for item in unreduced]
    else:
        raise NotImplementedError

    return dict(loss_text=reduced[0], loss_kernel=reduced[1])
def forward(self, pred_maps, downsample_ratio, gt_text_mask,
            gt_center_region_mask, gt_mask, gt_radius_map, gt_sin_map,
            gt_cos_map):
    """Compute the text, center, radius, sin and cos losses.

    Args:
        pred_maps (Tensor): The prediction map of shape
            :math:`(N, 5, H, W)`, where each dimension is the map of
            "text_region", "center_region", "sin_map", "cos_map", and
            "radius_map" respectively.
        downsample_ratio (float): Downsample ratio.
        gt_text_mask (list[BitmapMasks]): Gold text masks.
        gt_center_region_mask (list[BitmapMasks]): Gold center region
            masks.
        gt_mask (list[BitmapMasks]): Gold general masks.
        gt_radius_map (list[BitmapMasks]): Gold radius maps.
        gt_sin_map (list[BitmapMasks]): Gold sin maps.
        gt_cos_map (list[BitmapMasks]): Gold cos maps.

    Returns:
        dict: A loss dict with ``loss_text``, ``loss_center``,
        ``loss_radius``, ``loss_sin`` and ``loss_cos``.
    """
    assert isinstance(downsample_ratio, float)
    assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
    assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
    assert check_argument.is_type_list(gt_mask, BitmapMasks)
    assert check_argument.is_type_list(gt_radius_map, BitmapMasks)
    assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
    assert check_argument.is_type_list(gt_cos_map, BitmapMasks)

    device = pred_maps.device
    target_hw = pred_maps.size()[2:]
    text_logits = pred_maps[:, 0, :, :]
    center_logits = pred_maps[:, 1, :, :]
    sin_pred = pred_maps[:, 2, :, :]
    cos_pred = pred_maps[:, 3, :, :]
    radius_pred = pred_maps[:, 4, :, :]

    # bitmask 2 tensor; a ratio within 1e-2 of 1.0 skips the rescaling.
    keep_resolution = abs(downsample_ratio - 1.0) < 1e-2
    named_masks = (
        ('gt_text_mask', gt_text_mask),
        ('gt_center_region_mask', gt_center_region_mask),
        ('gt_mask', gt_mask),
        ('gt_radius_map', gt_radius_map),
        ('gt_sin_map', gt_sin_map),
        ('gt_cos_map', gt_cos_map),
    )
    gt = {}
    for key, bitmasks in named_masks:
        if not keep_resolution:
            bitmasks = [item.rescale(downsample_ratio) for item in bitmasks]
        tensors = self.bitmasks2tensor(bitmasks, target_hw)
        if key == 'gt_radius_map' and not keep_resolution:
            # Radius values shrink together with the map resolution.
            tensors = [item * downsample_ratio for item in tensors]
        gt[key] = [item.to(device) for item in tensors]

    # Normalize (sin, cos) to unit length; 1e-8 avoids division by zero.
    inv_norm = torch.sqrt(1.0 / (sin_pred**2 + cos_pred**2 + 1e-8))
    sin_pred = sin_pred * inv_norm
    cos_pred = cos_pred * inv_norm

    def masked_mean(values, mask):
        """Average ``values`` over the pixels selected by ``mask``."""
        return torch.sum(values * mask) / torch.sum(mask)

    loss_text = self.balanced_bce_loss(
        torch.sigmoid(text_logits), gt['gt_text_mask'][0],
        gt['gt_mask'][0])

    text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
    center_bce = F.binary_cross_entropy(
        torch.sigmoid(center_logits),
        gt['gt_center_region_mask'][0].float(),
        reduction='none')
    if int(text_mask.sum()) > 0:
        loss_center = masked_mean(center_bce, text_mask)
    else:
        loss_center = torch.tensor(0.0, device=device)

    center_mask = (gt['gt_center_region_mask'][0] *
                   gt['gt_mask'][0]).float()
    if int(center_mask.sum()) > 0:
        ones = torch.ones(
            radius_pred.size(), dtype=torch.float, device=device)
        # Radius is regressed as a ratio to the ground truth (target 1).
        radius_term = F.smooth_l1_loss(
            radius_pred / (gt['gt_radius_map'][0] + 1e-2),
            ones,
            reduction='none')
        sin_term = F.smooth_l1_loss(
            sin_pred, gt['gt_sin_map'][0], reduction='none')
        cos_term = F.smooth_l1_loss(
            cos_pred, gt['gt_cos_map'][0], reduction='none')
        loss_radius = masked_mean(radius_term, center_mask)
        loss_sin = masked_mean(sin_term, center_mask)
        loss_cos = masked_mean(cos_term, center_mask)
    else:
        loss_radius = torch.tensor(0.0, device=device)
        loss_sin = torch.tensor(0.0, device=device)
        loss_cos = torch.tensor(0.0, device=device)

    return dict(
        loss_text=loss_text,
        loss_center=loss_center,
        loss_radius=loss_radius,
        loss_sin=loss_sin,
        loss_cos=loss_cos)
def forward(self, preds, downsample_ratio, gt_kernels, gt_mask):
    """Compute PANet loss.

    Args:
        preds (tensor): The output tensor with size of Nx6xHxW. Channel 0
            is the text prediction, channel 1 the kernel prediction and
            the remaining channels the instance embedding.
        downsample_ratio (float): The downsample ratio between preds
            and the input img.
        gt_kernels (list[BitmapMasks]): The kernel list with each element
            being the text kernel mask for one img.
        gt_mask (list[BitmapMasks]): The effective mask list
            with each element being the effective mask for one img.

    Returns:
        dict: The loss dictionary with ``loss_text``, ``loss_kernel``,
        ``loss_aggregation`` and ``loss_discrimination``.
    """
    assert check_argument.is_type_list(gt_kernels, BitmapMasks)
    assert check_argument.is_type_list(gt_mask, BitmapMasks)
    assert isinstance(downsample_ratio, float)

    text_logits = preds[:, 0, :, :]
    kernel_logits = preds[:, 1, :, :]
    embeddings = preds[:, 2:, :, :]
    target_hw = preds.size()[2:]

    def to_level_tensors(bitmasks):
        """Rescale, pad to the feature size and move to preds' device."""
        rescaled = [item.rescale(downsample_ratio) for item in bitmasks]
        padded = self.bitmasks2tensor(rescaled, target_hw)
        return [item.to(preds.device) for item in padded]

    kernel_targets = to_level_tensors(gt_kernels)
    valid_masks = to_level_tensors(gt_mask)

    text_target = kernel_targets[0]
    kernel_target = kernel_targets[1]
    loss_aggrs, loss_discrs = self.aggregation_discrimination_loss(
        text_target, kernel_target, embeddings)

    # compute text loss
    text_sample_mask = self.ohem_batch(text_logits.detach(), text_target,
                                       valid_masks[0])
    loss_texts = self.dice_loss_with_logits(text_logits, text_target,
                                            text_sample_mask)

    # compute kernel loss
    kernel_sample_mask = (text_target > 0.5).float() * (
        valid_masks[0].float())
    loss_kernels = self.dice_loss_with_logits(kernel_logits, kernel_target,
                                              kernel_sample_mask)

    unreduced = [loss_texts, loss_kernels, loss_aggrs, loss_discrs]
    if self.reduction == 'mean':
        reduced = [item.mean() for item in unreduced]
    elif self.reduction == 'sum':
        reduced = [item.sum() for item in unreduced]
    else:
        raise NotImplementedError

    weights = [1, self.alpha, self.beta, self.beta]
    weighted = [loss * weight for loss, weight in zip(reduced, weights)]

    return dict(
        loss_text=weighted[0],
        loss_kernel=weighted[1],
        loss_aggregation=weighted[2],
        loss_discrimination=weighted[3])
def forward(self, pred_maps, downsample_ratio, gt_text_mask,
            gt_center_region_mask, gt_mask, gt_radius_map, gt_sin_map,
            gt_cos_map):
    """Compute the text, center, radius, sin and cos losses.

    Args:
        pred_maps (Tensor): Prediction maps; channels 0-4 along dim 1 are
            read as text region, center region, sin, cos and radius.
        downsample_ratio (float): Downsample ratio applied to the ground
            truth masks unless it is within 1e-2 of 1.0.
        gt_text_mask (list[BitmapMasks]): Text masks.
        gt_center_region_mask (list[BitmapMasks]): Center region masks.
        gt_mask (list[BitmapMasks]): Effective masks.
        gt_radius_map (list[BitmapMasks]): Radius maps.
        gt_sin_map (list[BitmapMasks]): Sin maps.
        gt_cos_map (list[BitmapMasks]): Cos maps.

    Returns:
        dict: A loss dict with ``loss_text``, ``loss_center``,
        ``loss_radius``, ``loss_sin`` and ``loss_cos``.
    """
    assert isinstance(downsample_ratio, float)
    assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
    assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
    assert check_argument.is_type_list(gt_mask, BitmapMasks)
    assert check_argument.is_type_list(gt_radius_map, BitmapMasks)
    assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
    assert check_argument.is_type_list(gt_cos_map, BitmapMasks)

    device = pred_maps.device
    target_hw = pred_maps.size()[2:]
    text_logits = pred_maps[:, 0, :, :]
    center_logits = pred_maps[:, 1, :, :]
    sin_pred = pred_maps[:, 2, :, :]
    cos_pred = pred_maps[:, 3, :, :]
    radius_pred = pred_maps[:, 4, :, :]

    # bitmask 2 tensor
    same_resolution = abs(downsample_ratio - 1.0) < 1e-2
    gt = {}
    for key, bitmasks in (
        ('gt_text_mask', gt_text_mask),
        ('gt_center_region_mask', gt_center_region_mask),
        ('gt_mask', gt_mask),
        ('gt_radius_map', gt_radius_map),
        ('gt_sin_map', gt_sin_map),
        ('gt_cos_map', gt_cos_map),
    ):
        if same_resolution:
            tensors = self.bitmasks2tensor(bitmasks, target_hw)
        else:
            rescaled = [item.rescale(downsample_ratio) for item in bitmasks]
            tensors = self.bitmasks2tensor(rescaled, target_hw)
            if key == 'gt_radius_map':
                # Radius values shrink together with the map resolution.
                tensors = [item * downsample_ratio for item in tensors]
        gt[key] = [item.to(device) for item in tensors]

    # Normalize (sin, cos) to unit length; 1e-8 avoids division by zero.
    inv_norm = torch.sqrt(1.0 / (sin_pred**2 + cos_pred**2 + 1e-8))
    sin_pred = sin_pred * inv_norm
    cos_pred = cos_pred * inv_norm

    loss_text = self.balanced_bce_loss(
        torch.sigmoid(text_logits), gt['gt_text_mask'][0],
        gt['gt_mask'][0])

    text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
    center_bce = F.binary_cross_entropy(
        torch.sigmoid(center_logits),
        gt['gt_center_region_mask'][0].float(),
        reduction='none')
    if int(text_mask.sum()) > 0:
        loss_center = torch.sum(center_bce * text_mask) / torch.sum(text_mask)
    else:
        loss_center = torch.tensor(0.0, device=device)

    center_mask = (gt['gt_center_region_mask'][0] *
                   gt['gt_mask'][0]).float()
    if int(center_mask.sum()) <= 0:
        # No valid center pixel: the geometry terms contribute nothing.
        loss_radius = torch.tensor(0.0, device=device)
        loss_sin = torch.tensor(0.0, device=device)
        loss_cos = torch.tensor(0.0, device=device)
    else:
        center_total = torch.sum(center_mask)
        ones = torch.ones(
            radius_pred.size(), dtype=torch.float, device=device)
        # Radius is regressed as a ratio to the ground truth (target 1).
        radius_term = F.smooth_l1_loss(
            radius_pred / (gt['gt_radius_map'][0] + 1e-2),
            ones,
            reduction='none')
        sin_term = F.smooth_l1_loss(
            sin_pred, gt['gt_sin_map'][0], reduction='none')
        cos_term = F.smooth_l1_loss(
            cos_pred, gt['gt_cos_map'][0], reduction='none')
        loss_radius = torch.sum(radius_term * center_mask) / center_total
        loss_sin = torch.sum(sin_term * center_mask) / center_total
        loss_cos = torch.sum(cos_term * center_mask) / center_total

    return dict(
        loss_text=loss_text,
        loss_center=loss_center,
        loss_radius=loss_radius,
        loss_sin=loss_sin,
        loss_cos=loss_cos)