Ejemplo n.º 1
0
def test_single_gpu_test_kie_novisual(cfg_file):
    """Check KIE inference through ``single_gpu_test`` with an SDMGR model.

    The meta keys of the last test-pipeline transform are extended with
    ``img_norm_cfg``, ``ori_filename`` and ``img_shape``. The model is then
    evaluated twice: first on a dataloader built with ``empty_img=True``,
    then on the default dataloader. Each run must return a list of dicts.
    """
    base_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    cfg = Config.fromfile(os.path.join(base_dir, cfg_file))

    test_pipeline = cfg.data.test.pipeline
    extended_keys = [
        *test_pipeline[-1]['meta_keys'], 'img_norm_cfg', 'ori_filename',
        'img_shape'
    ]
    test_pipeline[-1]['meta_keys'] = tuple(extended_keys)

    with tempfile.TemporaryDirectory() as tmpdirname:
        out_dir = osp.join(tmpdirname, 'tmp')
        # First pass builds the loader with an empty image, second pass uses
        # the default loader.
        for loader_kwargs in ({'empty_img': True}, {}):
            model, data_loader = gene_sdmgr_model_dataloader(
                cfg, out_dir, base_dir, **loader_kwargs)
            outputs = single_gpu_test(model,
                                      data_loader,
                                      out_dir=out_dir,
                                      is_kie=True)
            assert check_argument.is_type_list(outputs, dict)
Ejemplo n.º 2
0
    def get_boundary(self, edges, scores, text_comps, img_metas, rescale):
        """Decode DRRG graph predictions into text boundaries.

        Args:
            edges (ndarray | None): Edge array of shape N * 2; each row is a
                pair of text component indices forming one graph edge. When
                None, no decoding is done and the boundary list is empty.
            scores (ndarray): The edge score array.
            text_comps (ndarray): The text components.
            img_metas (list[dict]): The image meta infos. Only the first
                entry's ``scale_factor`` is read, and only when rescaling.
            rescale (bool): Whether to map boundaries back to the original
                image resolution.

        Returns:
            dict: A dict whose ``boundary_result`` holds the boundaries.
        """
        assert check_argument.is_type_list(img_metas, dict)
        assert isinstance(rescale, bool)

        if edges is None:
            boundaries = []
        else:
            boundaries = decode(decoding_type='drrg',
                                edges=edges,
                                scores=scores,
                                text_comps=text_comps,
                                link_thr=self.link_thr)

        if rescale:
            factor = 1.0 / self.downsample_ratio / img_metas[0]['scale_factor']
            boundaries = self.resize_boundary(boundaries, factor)

        return dict(boundary_result=boundaries)
Ejemplo n.º 3
0
    def get_boundary(self, score_maps, img_metas, rescale):
        """Decode a text score map into text boundaries.

        Args:
            score_maps (Tensor): The text score map. Size-1 dimensions are
                squeezed out before decoding.
            img_metas (list[dict]): The image meta infos. Only the first
                entry is read: ``filename`` always, and ``scale_factor``
                when rescaling.
            rescale (bool): If True, map boundaries back to the original
                image resolution; otherwise keep the score-map resolution.

        Returns:
            dict: ``boundary_result`` with the decoded boundaries and
            ``filename`` copied from the first image meta.
        """
        assert check_argument.is_type_list(img_metas, dict)
        assert isinstance(rescale, bool)

        boundaries = decode(decoding_type=self.decoding_type,
                            preds=score_maps.squeeze(),
                            text_repr_type=self.text_repr_type)

        first_meta = img_metas[0]
        if rescale:
            factor = 1.0 / self.downsample_ratio / first_meta['scale_factor']
            boundaries = self.resize_boundary(boundaries, factor)

        return dict(boundary_result=boundaries,
                    filename=first_meta['filename'])
Ejemplo n.º 4
0
def trace_boundary(char_boxes):
    """Trace the boundary points of one text from its char boxes.

    The walk goes along rows 0-1 of every box from the first char to the
    last (the top edge, left to right), then along rows 2-3 of every box
    from the last char back to the first (the bottom edge, right to left).

    Args:
        char_boxes (list[ndarray]): The char boxes of one text; each element
            is a 4x2 ndarray.

    Returns:
        ndarray: The boundary points, cast to ``int``, with size nx2.
    """
    assert check_argument.is_type_list(char_boxes, np.ndarray)

    # Top edge, from top left to top right.
    top_edge = [box[0:2] for box in char_boxes]
    # Bottom edge, from bottom right to bottom left: last box first.
    bottom_edge = [box[[2, 3], :] for box in reversed(char_boxes)]

    return np.concatenate(top_edge + bottom_edge).astype(int)
Ejemplo n.º 5
0
    def __init__(
            self,
            in_channels,
            out_channels,
            text_repr_type='poly',  # 'poly' or 'quad'
            downsample_ratio=0.25,
            loss=dict(type='PANLoss'),
            train_cfg=None,
            test_cfg=None):
        """Build the loss module and the 1x1 output convolution.

        Args:
            in_channels (list[int]): Channels of each input feature map. The
                output conv takes their sum as its input channels.
            out_channels (int): Output channels of the 1x1 conv.
            text_repr_type (str): ``'poly'`` or ``'quad'``.
            downsample_ratio (float): A ratio in ``[0, 1]``.
            loss (dict): Loss config passed to ``build_loss``. Its ``type``
                selects the decoding type: ``'PANLoss'`` -> ``'pan'`` and
                ``'PSELoss'`` -> ``'pse'``.
            train_cfg (dict, optional): Stored unchanged.
            test_cfg (dict, optional): Stored unchanged.

        Raises:
            NotImplementedError: If ``loss['type']`` is neither
                ``'PANLoss'`` nor ``'PSELoss'``.
        """
        super().__init__()

        assert check_argument.is_type_list(in_channels, int)
        assert isinstance(out_channels, int)
        assert text_repr_type in ['poly', 'quad']
        assert 0 <= downsample_ratio <= 1

        self.loss_module = build_loss(loss)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.text_repr_type = text_repr_type
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.downsample_ratio = downsample_ratio

        # Map the loss type to the post-processing decoder. The local is
        # named ``loss_type`` so it does not shadow the builtin ``type``.
        decoding_types = {'PANLoss': 'pan', 'PSELoss': 'pse'}
        loss_type = loss['type']
        if loss_type not in decoding_types:
            raise NotImplementedError(f'unsupported loss type {loss_type}.')
        self.decoding_type = decoding_types[loss_type]

        # in_channels is asserted to be a list of ints, so the builtin sum
        # gives a plain int channel count (no numpy round-trip needed).
        self.out_conv = nn.Conv2d(in_channels=sum(in_channels),
                                  out_channels=out_channels,
                                  kernel_size=1)
        self.init_weights()
Ejemplo n.º 6
0
def test_single_gpu_test_kie(cfg_file):
    """Run ``single_gpu_test`` in KIE mode on an SDMGR model.

    ``cfg_file`` is resolved relative to the directory two levels above this
    test file. The test passes when inference returns a list of dicts.
    """
    base_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    cfg = Config.fromfile(os.path.join(base_dir, cfg_file))

    with tempfile.TemporaryDirectory() as tmpdirname:
        out_dir = osp.join(tmpdirname, 'tmp')
        model, data_loader = gene_sdmgr_model_dataloader(
            cfg, out_dir, base_dir)
        outputs = single_gpu_test(model,
                                  data_loader,
                                  out_dir=out_dir,
                                  is_kie=True)
        assert check_argument.is_type_list(outputs, dict)
Ejemplo n.º 7
0
def test_single_gpu_test_det(cfg_file):
    """Run ``single_gpu_test`` on a detection model over the toy dataset.

    The model is built from ``cfg_file`` and evaluated on the images and
    annotations under ``data/toy_dataset``. The test passes when inference
    returns a list of dicts.
    """
    base_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    cfg = Config.fromfile(os.path.join(base_dir, cfg_file))

    model = build_model(cfg)
    data_loader = generate_sample_dataloader(
        cfg, base_dir, 'data/toy_dataset/imgs',
        'data/toy_dataset/instances_test.json')

    with tempfile.TemporaryDirectory() as tmpdirname:
        outputs = single_gpu_test(model,
                                  data_loader,
                                  out_dir=osp.join(tmpdirname, 'tmp'))
        assert check_argument.is_type_list(outputs, dict)
Ejemplo n.º 8
0
    def generate_kernels(self,
                         resize_shape,
                         pad_shape,
                         char_boxes,
                         char_inds,
                         shrink_ratio=0.5,
                         binary=True):
        """Generate char instance kernels for one shrink ratio.

        Args:
            resize_shape (tuple(int, int)): Image size (height, width)
                after resizing.
            pad_shape (tuple(int, int)): Image size (height, width)
                after padding.
            char_boxes (list[list[float]]): The list of char polygons.
            char_inds (list[int]): List of char indexes. Used as fill values
                when ``binary`` is False.
            shrink_ratio (float): The shrink ratio of kernel.
            binary (bool): If True, every char polygon is filled with 1.
                Otherwise polygon ``i`` is filled with ``char_inds[i]``.

        Returns:
            char_kernel (ndarray): An int32 mask of shape ``pad_shape``.

        Raises:
            NotImplementedError: If ``char_boxes`` is non-empty and
                ``self.box_type`` is neither ``'char_rects'`` nor
                ``'char_quads'``.
        """
        assert isinstance(resize_shape, tuple)
        assert isinstance(pad_shape, tuple)
        assert check_argument.is_2dlist(char_boxes)
        assert check_argument.is_type_list(char_inds, int)
        assert isinstance(shrink_ratio, float)
        assert isinstance(binary, bool)

        char_kernel = np.zeros(pad_shape, dtype=np.int32)
        # Mark the columns right of the resized width as padding. Only rows
        # above resize_shape[0] are marked; any bottom padding stays 0.
        # NOTE(review): presumably inputs are only ever padded on the right
        # (resize height == pad height); confirm before relying on this.
        char_kernel[:resize_shape[0], resize_shape[1]:] = self.pad_val

        for i, char_box in enumerate(char_boxes):
            if self.box_type == 'char_rects':
                poly = self.shrink_char_rect(char_box, shrink_ratio)
            elif self.box_type == 'char_quads':
                poly = self.shrink_char_quad(char_box, shrink_ratio)
            else:
                # Fail with a clear message instead of an UnboundLocalError
                # on ``poly`` below.
                raise NotImplementedError(
                    f'unsupported box_type {self.box_type}')

            fill_value = 1 if binary else char_inds[i]
            cv2.fillConvexPoly(char_kernel, poly.astype(np.int32), fill_value)

        return char_kernel
Ejemplo n.º 9
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 downsample_ratio=0.25,
                 loss=dict(type='PANLoss'),
                 postprocessor=dict(
                     type='PANPostprocessor', text_repr_type='poly'),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal',
                     mean=0,
                     std=0.01,
                     override=dict(name='out_conv')),
                 **kwargs):
        """Build the head's loss/postprocessor and 1x1 output convolution.

        Args:
            in_channels (list[int]): Channels of each input feature map. The
                output conv takes their sum as its input channels.
            out_channels (int): Output channels of the 1x1 conv.
            downsample_ratio (float): A ratio in ``[0, 1]``.
            loss (dict): Loss config forwarded to ``HeadMixin``.
            postprocessor (dict): Postprocessor config forwarded to
                ``HeadMixin``. It is never modified in place.
            train_cfg (dict, optional): Stored unchanged.
            test_cfg (dict, optional): Stored unchanged.
            init_cfg (dict): Initialization config forwarded to
                ``BaseModule``.
            **kwargs: Deprecated ``text_repr_type`` / ``decoding_type``.
                Truthy values are copied into the postprocessor config with a
                ``UserWarning``.
        """
        old_keys = ['text_repr_type', 'decoding_type']
        deprecated = {
            key: kwargs.get(key)
            for key in old_keys if kwargs.get(key, None)
        }
        if deprecated:
            # Merge into a copy: ``postprocessor`` defaults to a dict shared
            # by every call, so writing into it would leak these values into
            # later instances (and would also mutate a caller's config).
            postprocessor = {**postprocessor, **deprecated}
            for key in deprecated:
                warnings.warn(
                    f'{key} is deprecated, please specify '
                    'it in postprocessor config dict. See '
                    'https://github.com/open-mmlab/mmocr/pull/640'
                    ' for details.', UserWarning)

        BaseModule.__init__(self, init_cfg=init_cfg)
        HeadMixin.__init__(self, loss, postprocessor)

        assert check_argument.is_type_list(in_channels, int)
        assert isinstance(out_channels, int)

        assert 0 <= downsample_ratio <= 1

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample_ratio = downsample_ratio
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.out_conv = nn.Conv2d(
            in_channels=np.sum(np.array(in_channels)),
            out_channels=out_channels,
            kernel_size=1)
Ejemplo n.º 10
0
def sort_points(points):
    """Sort arbitrary points in clockwise order. Reference:
    https://stackoverflow.com/a/6989383.

    Points are compared around their centroid:

    * a point whose x offset from the centroid is negative comes before one
      whose offset is non-negative;
    * otherwise the sign of the 2D cross product of the two offsets decides;
    * points on the same line through the centroid are ordered farthest
      first.

    Args:
        points (list[ndarray] or ndarray or list[list]): A list of unsorted
            boundary points.

    Returns:
        list[ndarray]: A list of points sorted in clockwise order.
    """

    assert is_type_list(points, np.ndarray) or isinstance(points, np.ndarray) \
        or is_2dlist(points)

    points = np.array(points)
    center = np.mean(points, axis=0)

    def cmp(a, b):
        oa = a - center
        ob = b - center

        # Some corner cases: left-of-center points come first.
        if oa[0] >= 0 and ob[0] < 0:
            return 1
        if oa[0] < 0 and ob[0] >= 0:
            return -1

        # z-component of the 2D cross product oa x ob. It is written out
        # because np.cross on 2-element vectors is deprecated in NumPy 2.0;
        # NumPy computes exactly this expression for 2D inputs.
        prod = oa[0] * ob[1] - oa[1] * ob[0]
        if prod > 0:
            return 1
        if prod < 0:
            return -1

        # a, b are on the same line from the center: farther point first.
        return 1 if (oa**2).sum() < (ob**2).sum() else -1

    return sorted(points, key=functools.cmp_to_key(cmp))
Ejemplo n.º 11
0
    def bitmasks2tensor(self, bitmasks, target_sz):
        """Stack per-image BitmapMasks into one padded tensor per mask level.

        Args:
            bitmasks (list[BitmapMasks]): One BitmapMasks per image. The
                number of levels is taken from the first entry.
            target_sz (tuple(int, int)): The target size :math:`(H, W)`.
                Every mask is zero-padded on the right and bottom to it.

        Returns:
            list[Tensor]: One tensor per mask level, stacking that level's
            mask from every image along a new leading dimension.
        """
        assert check_argument.is_type_list(bitmasks, BitmapMasks)
        assert isinstance(target_sz, tuple)

        def pad_to_target(mask_array):
            mask = torch.from_numpy(mask_array)
            # F.pad order for the last two dims: left, right, top, bottom.
            padding = [
                0, target_sz[1] - mask.shape[1], 0,
                target_sz[0] - mask.shape[0]
            ]
            return F.pad(mask, padding, mode='constant', value=0)

        num_levels = len(bitmasks[0])
        return [
            torch.stack([
                pad_to_target(img_masks.masks[level])
                for img_masks in bitmasks
            ]) for level in range(num_levels)
        ]
Ejemplo n.º 12
0
    def forward(self, preds, downsample_ratio, gt_text_mask,
                gt_center_region_mask, gt_mask, gt_top_height_map,
                gt_bot_height_map, gt_sin_map, gt_cos_map):
        """Compute Drrg loss.

        Args:
            preds (tuple): ``(pred_maps, gcn_data)``. ``pred_maps`` has shape
                :math:`(N, C_{out}, H, W)`. Channels 0-5 are text region,
                center region, sin, cos, top height and bottom height.
                ``gcn_data`` is passed unchanged to ``self.gcn_loss``.
                NOTE(review): presumably it holds the GCN prediction
                :math:`(N, 2)` and the ground-truth labels :math:`(N, 8)`;
                confirm.
            downsample_ratio (float): The downsample ratio. Within 1e-2 of
                1.0, ground truths are used without rescaling.
            gt_text_mask (list[BitmapMasks]): Text mask.
            gt_center_region_mask (list[BitmapMasks]): Center region mask.
            gt_mask (list[BitmapMasks]): Effective mask.
            gt_top_height_map (list[BitmapMasks]): Top height map. Its
                values are multiplied by ``downsample_ratio`` when rescaled.
            gt_bot_height_map (list[BitmapMasks]): Bottom height map. Its
                values are multiplied by ``downsample_ratio`` when rescaled.
            gt_sin_map (list[BitmapMasks]): Sinusoid map.
            gt_cos_map (list[BitmapMasks]): Cosine map.

        Returns:
            dict:  A loss dict with ``loss_text``, ``loss_center``,
            ``loss_height``, ``loss_sin``, ``loss_cos``, and ``loss_gcn``.
            Masked averages over an empty mask contribute 0 instead of NaN.
        """
        assert isinstance(preds, tuple)
        assert isinstance(downsample_ratio, float)
        assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
        assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert check_argument.is_type_list(gt_top_height_map, BitmapMasks)
        assert check_argument.is_type_list(gt_bot_height_map, BitmapMasks)
        assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
        assert check_argument.is_type_list(gt_cos_map, BitmapMasks)

        pred_maps, gcn_data = preds
        pred_text_region = pred_maps[:, 0, :, :]
        pred_center_region = pred_maps[:, 1, :, :]
        pred_sin_map = pred_maps[:, 2, :, :]
        pred_cos_map = pred_maps[:, 3, :, :]
        pred_top_height_map = pred_maps[:, 4, :, :]
        pred_bot_height_map = pred_maps[:, 5, :, :]
        feature_sz = pred_maps.size()
        device = pred_maps.device

        # bitmask 2 tensor
        mapping = {
            'gt_text_mask': gt_text_mask,
            'gt_center_region_mask': gt_center_region_mask,
            'gt_mask': gt_mask,
            'gt_top_height_map': gt_top_height_map,
            'gt_bot_height_map': gt_bot_height_map,
            'gt_sin_map': gt_sin_map,
            'gt_cos_map': gt_cos_map
        }
        gt = {}
        for key, value in mapping.items():
            gt[key] = value
            if abs(downsample_ratio - 1.0) < 1e-2:
                gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
            else:
                gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
                gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
                # Height values scale together with the map resolution.
                if key in ['gt_top_height_map', 'gt_bot_height_map']:
                    gt[key] = [item * downsample_ratio for item in gt[key]]
            gt[key] = [item.to(device) for item in gt[key]]

        # Normalize the predicted (sin, cos) pair to (approximately) unit
        # length before regressing it.
        scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
        pred_sin_map = pred_sin_map * scale
        pred_cos_map = pred_cos_map * scale

        loss_text = self.balance_bce_loss(torch.sigmoid(pred_text_region),
                                          gt['gt_text_mask'][0],
                                          gt['gt_mask'][0])

        text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
        negative_text_mask = ((1 - gt['gt_text_mask'][0]) *
                              gt['gt_mask'][0]).float()
        loss_center_map = F.binary_cross_entropy(
            torch.sigmoid(pred_center_region),
            gt['gt_center_region_mask'][0].float(),
            reduction='none')
        if int(text_mask.sum()) > 0:
            loss_center_positive = torch.sum(
                loss_center_map * text_mask) / torch.sum(text_mask)
        else:
            loss_center_positive = torch.tensor(0.0, device=device)
        if int(negative_text_mask.sum()) > 0:
            loss_center_negative = torch.sum(
                loss_center_map *
                negative_text_mask) / torch.sum(negative_text_mask)
        else:
            # No effective non-text pixels: dividing by the empty mask sum
            # would yield NaN, so contribute zero, like the positive branch.
            loss_center_negative = torch.tensor(0.0, device=device)
        loss_center = loss_center_positive + 0.5 * loss_center_negative

        center_mask = (gt['gt_center_region_mask'][0] *
                       gt['gt_mask'][0]).float()
        if int(center_mask.sum()) > 0:
            map_sz = pred_top_height_map.size()
            ones = torch.ones(map_sz, dtype=torch.float, device=device)
            loss_top = F.smooth_l1_loss(pred_top_height_map /
                                        (gt['gt_top_height_map'][0] + 1e-2),
                                        ones,
                                        reduction='none')
            loss_bot = F.smooth_l1_loss(pred_bot_height_map /
                                        (gt['gt_bot_height_map'][0] + 1e-2),
                                        ones,
                                        reduction='none')
            gt_height = (gt['gt_top_height_map'][0] +
                         gt['gt_bot_height_map'][0])
            loss_height = torch.sum(
                (torch.log(gt_height + 1) *
                 (loss_top + loss_bot)) * center_mask) / torch.sum(center_mask)

            loss_sin = torch.sum(
                F.smooth_l1_loss(
                    pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
                center_mask) / torch.sum(center_mask)
            loss_cos = torch.sum(
                F.smooth_l1_loss(
                    pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
                center_mask) / torch.sum(center_mask)
        else:
            loss_height = torch.tensor(0.0, device=device)
            loss_sin = torch.tensor(0.0, device=device)
            loss_cos = torch.tensor(0.0, device=device)

        loss_gcn = self.gcn_loss(gcn_data)

        results = dict(loss_text=loss_text,
                       loss_center=loss_center,
                       loss_height=loss_height,
                       loss_sin=loss_sin,
                       loss_cos=loss_cos,
                       loss_gcn=loss_gcn)

        return results
Ejemplo n.º 13
0
    def forward(self, score_maps, downsample_ratio, gt_kernels, gt_mask):
        """Compute PSENet loss.

        Args:
            score_maps (Tensor): Output of shape :math:`(N, C, H, W)` (e.g.
                Nx6xHxW). Channel 0 holds text logits and channels 1 onwards
                hold kernel logits.
            downsample_ratio (float): The downsample ratio between
                ``score_maps`` and the input image. Ground truths are
                rescaled by it.
            gt_kernels (list[BitmapMasks]): The kernel masks of each image.
                Level 0 is the text target, and level ``k + 1`` is the target
                for kernel channel ``k``.
            gt_mask (list[BitmapMasks]): The effective mask of each image.

        Returns:
            dict: ``loss_text`` (weighted by ``self.alpha``) and
            ``loss_kernel`` (mean over kernels, weighted by
            ``1 - self.alpha``), each reduced per ``self.reduction``.

        Raises:
            NotImplementedError: For an unknown ``self.kernel_sample_type``
                or ``self.reduction``.
        """
        assert check_argument.is_type_list(gt_kernels, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert isinstance(downsample_ratio, float)

        text_logits = score_maps[:, 0, :, :]
        kernel_logits = score_maps[:, 1:, :, :]
        target_sz = score_maps.size()[2:]
        device = score_maps.device

        def to_tensors(masks):
            # Rescale, pad to the score-map size, then move to its device.
            rescaled = [item.rescale(downsample_ratio) for item in masks]
            return [
                item.to(device)
                for item in self.bitmasks2tensor(rescaled, target_sz)
            ]

        kernel_targets = to_tensors(gt_kernels)
        mask_targets = to_tensors(gt_mask)

        # Text loss on OHEM-sampled pixels.
        text_sample = self.ohem_batch(text_logits.detach(), kernel_targets[0],
                                      mask_targets[0])
        text_loss = self.alpha * self.dice_loss_with_logits(
            text_logits, kernel_targets[0], text_sample)

        # Kernel loss is only evaluated inside the sampled text area.
        if self.kernel_sample_type == 'hard':
            kernel_sample = (kernel_targets[0] > 0.5).float() * (
                mask_targets[0].float())
        elif self.kernel_sample_type == 'adaptive':
            kernel_sample = (text_logits > 0).float() * (
                mask_targets[0].float())
        else:
            raise NotImplementedError

        num_kernel = kernel_logits.shape[1]
        assert num_kernel == len(kernel_targets) - 1
        per_kernel = [
            self.dice_loss_with_logits(kernel_logits[:, idx, :, :],
                                       kernel_targets[1 + idx], kernel_sample)
            for idx in range(num_kernel)
        ]
        kernel_loss = (1 - self.alpha) * sum(per_kernel) / len(per_kernel)

        if self.reduction == 'mean':
            text_loss, kernel_loss = text_loss.mean(), kernel_loss.mean()
        elif self.reduction == 'sum':
            text_loss, kernel_loss = text_loss.sum(), kernel_loss.sum()
        else:
            raise NotImplementedError
        return dict(loss_text=text_loss, loss_kernel=kernel_loss)
Ejemplo n.º 14
0
    def forward(self, pred_maps, downsample_ratio, gt_text_mask,
                gt_center_region_mask, gt_mask, gt_radius_map, gt_sin_map,
                gt_cos_map):
        """Compute the TextSnake losses.

        Args:
            pred_maps (Tensor): The prediction map of shape
                :math:`(N, 5, H, W)`. Channels 0-4 are "text_region",
                "center_region", "sin_map", "cos_map" and "radius_map".
            downsample_ratio (float): Downsample ratio. Within 1e-2 of 1.0,
                ground truths are used without rescaling.
            gt_text_mask (list[BitmapMasks]): Gold text masks.
            gt_center_region_mask (list[BitmapMasks]): Gold center region
                masks.
            gt_mask (list[BitmapMasks]): Gold general masks.
            gt_radius_map (list[BitmapMasks]): Gold radius maps. Values are
                multiplied by ``downsample_ratio`` whenever rescaled.
            gt_sin_map (list[BitmapMasks]): Gold sin maps.
            gt_cos_map (list[BitmapMasks]): Gold cos maps.

        Returns:
            dict:  A loss dict with ``loss_text``, ``loss_center``,
            ``loss_radius``, ``loss_sin`` and ``loss_cos``. The center loss
            is 0 when no effective text pixel exists, and the radius/sin/cos
            losses are 0 when no effective center pixel exists.
        """
        assert isinstance(downsample_ratio, float)
        for gt_maps in (gt_text_mask, gt_center_region_mask, gt_mask,
                        gt_radius_map, gt_sin_map, gt_cos_map):
            assert check_argument.is_type_list(gt_maps, BitmapMasks)

        text_logits = pred_maps[:, 0, :, :]
        center_logits = pred_maps[:, 1, :, :]
        sin_pred = pred_maps[:, 2, :, :]
        cos_pred = pred_maps[:, 3, :, :]
        radius_pred = pred_maps[:, 4, :, :]
        target_sz = pred_maps.size()[2:]
        device = pred_maps.device
        keep_scale = abs(downsample_ratio - 1.0) < 1e-2

        def to_tensors(masks, scale_values=False):
            # Optionally rescale, pad to the prediction size, move to device.
            if not keep_scale:
                masks = [item.rescale(downsample_ratio) for item in masks]
            levels = self.bitmasks2tensor(masks, target_sz)
            if scale_values and not keep_scale:
                levels = [item * downsample_ratio for item in levels]
            return [item.to(device) for item in levels]

        all_levels = [
            to_tensors(gt_text_mask),
            to_tensors(gt_center_region_mask),
            to_tensors(gt_mask),
            to_tensors(gt_radius_map, scale_values=True),
            to_tensors(gt_sin_map),
            to_tensors(gt_cos_map),
        ]
        text_gt, center_gt, effective_gt, radius_gt, sin_gt, cos_gt = (
            levels[0] for levels in all_levels)

        # Normalize the predicted (sin, cos) pair to (approximately) unit
        # length.
        unit = torch.sqrt(1.0 / (sin_pred**2 + cos_pred**2 + 1e-8))
        sin_pred = sin_pred * unit
        cos_pred = cos_pred * unit

        loss_text = self.balanced_bce_loss(
            torch.sigmoid(text_logits), text_gt, effective_gt)

        text_mask = (text_gt * effective_gt).float()
        center_bce = F.binary_cross_entropy(
            torch.sigmoid(center_logits),
            center_gt.float(),
            reduction='none')
        if int(text_mask.sum()) > 0:
            loss_center = torch.sum(
                center_bce * text_mask) / torch.sum(text_mask)
        else:
            loss_center = torch.tensor(0.0, device=device)

        center_mask = (center_gt * effective_gt).float()
        if int(center_mask.sum()) > 0:

            def masked_mean(loss_map):
                return torch.sum(
                    loss_map * center_mask) / torch.sum(center_mask)

            ones = torch.ones(radius_pred.size(),
                              dtype=torch.float,
                              device=device)
            loss_radius = masked_mean(
                F.smooth_l1_loss(radius_pred / (radius_gt + 1e-2),
                                 ones,
                                 reduction='none'))
            loss_sin = masked_mean(
                F.smooth_l1_loss(sin_pred, sin_gt, reduction='none'))
            loss_cos = masked_mean(
                F.smooth_l1_loss(cos_pred, cos_gt, reduction='none'))
        else:
            loss_radius = torch.tensor(0.0, device=device)
            loss_sin = torch.tensor(0.0, device=device)
            loss_cos = torch.tensor(0.0, device=device)

        return dict(loss_text=loss_text,
                    loss_center=loss_center,
                    loss_radius=loss_radius,
                    loss_sin=loss_sin,
                    loss_cos=loss_cos)
Ejemplo n.º 15
0
    def forward(self, preds, downsample_ratio, gt_kernels, gt_mask):
        """Compute PANet loss.

        Args:
            preds (Tensor): Output of shape :math:`(N, C, H, W)` (e.g.
                Nx6xHxW). Channel 0 holds text logits, channel 1 kernel
                logits, and channels 2 onwards the instance embedding.
            downsample_ratio (float): The downsample ratio between ``preds``
                and the input image. Ground truths are rescaled by it.
            gt_kernels (list[BitmapMasks]): The kernel masks of each image.
                Level 0 is the text target and level 1 the kernel target.
            gt_mask (list[BitmapMasks]): The effective mask of each image.

        Returns:
            dict: ``loss_text``, ``loss_kernel`` (times ``self.alpha``),
            ``loss_aggregation`` and ``loss_discrimination`` (both times
            ``self.beta``), each reduced per ``self.reduction``.

        Raises:
            NotImplementedError: If ``self.reduction`` is neither ``'mean'``
                nor ``'sum'``.
        """
        assert check_argument.is_type_list(gt_kernels, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert isinstance(downsample_ratio, float)

        text_logits = preds[:, 0, :, :]
        kernel_logits = preds[:, 1, :, :]
        embedding = preds[:, 2:, :, :]
        target_sz = preds.size()[2:]

        def to_tensors(masks):
            # Rescale, pad to the prediction size, then move to its device.
            rescaled = [item.rescale(downsample_ratio) for item in masks]
            return [
                item.to(preds.device)
                for item in self.bitmasks2tensor(rescaled, target_sz)
            ]

        kernels = to_tensors(gt_kernels)
        effective = to_tensors(gt_mask)

        loss_aggrs, loss_discrs = self.aggregation_discrimination_loss(
            kernels[0], kernels[1], embedding)

        # Text loss on OHEM-sampled pixels.
        text_sample = self.ohem_batch(text_logits.detach(), kernels[0],
                                      effective[0])
        loss_texts = self.dice_loss_with_logits(text_logits, kernels[0],
                                                text_sample)

        # Kernel loss only inside the effective text area.
        kernel_sample = (kernels[0] > 0.5).float() * (effective[0].float())
        loss_kernels = self.dice_loss_with_logits(kernel_logits, kernels[1],
                                                  kernel_sample)

        raw_losses = (loss_texts, loss_kernels, loss_aggrs, loss_discrs)
        if self.reduction == 'mean':
            reduced = [item.mean() for item in raw_losses]
        elif self.reduction == 'sum':
            reduced = [item.sum() for item in raw_losses]
        else:
            raise NotImplementedError

        weights = (1, self.alpha, self.beta, self.beta)
        loss_text, loss_kernel, loss_aggr, loss_discr = [
            item * weight for item, weight in zip(reduced, weights)
        ]
        return dict(loss_text=loss_text,
                    loss_kernel=loss_kernel,
                    loss_aggregation=loss_aggr,
                    loss_discrimination=loss_discr)
Ejemplo n.º 16
0
    def forward(self, pred_maps, downsample_ratio, gt_text_mask,
                gt_center_region_mask, gt_mask, gt_radius_map, gt_sin_map,
                gt_cos_map):
        """Compute the TextSnake losses.

        Args:
            pred_maps (Tensor): Prediction map. Channels 0-4 along dim 1 are
                read as text region, center region, sin, cos and radius.
            downsample_ratio (float): Downsample ratio. Within 1e-2 of 1.0,
                ground truths are used without rescaling.
            gt_text_mask (list[BitmapMasks]): Text masks.
            gt_center_region_mask (list[BitmapMasks]): Center region masks.
            gt_mask (list[BitmapMasks]): Effective masks.
            gt_radius_map (list[BitmapMasks]): Radius maps. Values are
                multiplied by ``downsample_ratio`` whenever rescaled.
            gt_sin_map (list[BitmapMasks]): Sin maps.
            gt_cos_map (list[BitmapMasks]): Cos maps.

        Returns:
            dict: ``loss_text``, ``loss_center``, ``loss_radius``,
            ``loss_sin`` and ``loss_cos``. The center loss is 0 when no
            effective text pixel exists, and the radius/sin/cos losses are 0
            when no effective center pixel exists.
        """
        assert isinstance(downsample_ratio, float)
        for gt_maps in (gt_text_mask, gt_center_region_mask, gt_mask,
                        gt_radius_map, gt_sin_map, gt_cos_map):
            assert check_argument.is_type_list(gt_maps, BitmapMasks)

        text_logits = pred_maps[:, 0, :, :]
        center_logits = pred_maps[:, 1, :, :]
        sin_pred = pred_maps[:, 2, :, :]
        cos_pred = pred_maps[:, 3, :, :]
        radius_pred = pred_maps[:, 4, :, :]
        target_sz = pred_maps.size()[2:]
        device = pred_maps.device
        keep_scale = abs(downsample_ratio - 1.0) < 1e-2

        def to_tensors(masks, scale_values=False):
            # Optionally rescale, pad to the prediction size, move to device.
            if not keep_scale:
                masks = [item.rescale(downsample_ratio) for item in masks]
            levels = self.bitmasks2tensor(masks, target_sz)
            if scale_values and not keep_scale:
                levels = [item * downsample_ratio for item in levels]
            return [item.to(device) for item in levels]

        all_levels = [
            to_tensors(gt_text_mask),
            to_tensors(gt_center_region_mask),
            to_tensors(gt_mask),
            to_tensors(gt_radius_map, scale_values=True),
            to_tensors(gt_sin_map),
            to_tensors(gt_cos_map),
        ]
        text_gt, center_gt, effective_gt, radius_gt, sin_gt, cos_gt = (
            levels[0] for levels in all_levels)

        # Normalize the predicted (sin, cos) pair to (approximately) unit
        # length.
        unit = torch.sqrt(1.0 / (sin_pred**2 + cos_pred**2 + 1e-8))
        sin_pred = sin_pred * unit
        cos_pred = cos_pred * unit

        loss_text = self.balanced_bce_loss(
            torch.sigmoid(text_logits), text_gt, effective_gt)

        text_mask = (text_gt * effective_gt).float()
        center_bce = F.binary_cross_entropy(
            torch.sigmoid(center_logits),
            center_gt.float(),
            reduction='none')
        if int(text_mask.sum()) > 0:
            loss_center = torch.sum(
                center_bce * text_mask) / torch.sum(text_mask)
        else:
            loss_center = torch.tensor(0.0, device=device)

        center_mask = (center_gt * effective_gt).float()
        if int(center_mask.sum()) > 0:

            def masked_mean(loss_map):
                return torch.sum(
                    loss_map * center_mask) / torch.sum(center_mask)

            ones = torch.ones(radius_pred.size(),
                              dtype=torch.float,
                              device=device)
            loss_radius = masked_mean(
                F.smooth_l1_loss(radius_pred / (radius_gt + 1e-2),
                                 ones,
                                 reduction='none'))
            loss_sin = masked_mean(
                F.smooth_l1_loss(sin_pred, sin_gt, reduction='none'))
            loss_cos = masked_mean(
                F.smooth_l1_loss(cos_pred, cos_gt, reduction='none'))
        else:
            loss_radius = torch.tensor(0.0, device=device)
            loss_sin = torch.tensor(0.0, device=device)
            loss_cos = torch.tensor(0.0, device=device)

        return dict(loss_text=loss_text,
                    loss_center=loss_center,
                    loss_radius=loss_radius,
                    loss_sin=loss_sin,
                    loss_cos=loss_cos)