Exemplo n.º 1
0
    def __init__(self,
                 num_chars=92,
                 visual_dim=64,
                 fusion_dim=1024,
                 node_input=32,
                 node_embed=256,
                 edge_input=5,
                 edge_embed=256,
                 num_gnn=2,
                 num_classes=26,
                 loss=dict(type='SDMGRLoss'),
                 bidirectional=False,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(type='Normal',
                               override=dict(name='edge_embed'),
                               mean=0,
                               std=0.01)):
        """Assemble the fusion, embedding, recurrent, graph and output layers.

        Args:
            num_chars (int): Number of entries in the character embedding
                table; index 0 is used as the padding index.
            visual_dim (int): First input size handed to the fusion ``Block``.
            fusion_dim (int): Third argument of the fusion ``Block``.
            node_input (int): Character embedding size, which is also the
                LSTM input size.
            node_embed (int): Node feature size. It is the LSTM hidden size
                (halved when ``bidirectional``), the node width of every
                ``GNNLayer`` and the input size of the node classifier.
            edge_input (int): Input size of the edge embedding layer.
            edge_embed (int): Output size of the edge embedding layer, the
                edge width of every ``GNNLayer`` and the input size of the
                edge classifier.
            num_gnn (int): Number of ``GNNLayer`` instances to stack.
            num_classes (int): Output size of the node classifier.
            loss (dict): Config handed to ``build_loss``.
            bidirectional (bool): Whether the LSTM runs in both directions.
            train_cfg: Accepted but not used by this constructor.
            test_cfg: Accepted but not used by this constructor.
            init_cfg (dict): Forwarded to the parent constructor.
        """
        super().__init__(init_cfg=init_cfg)

        self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
        self.node_embed = nn.Embedding(
            num_embeddings=num_chars, embedding_dim=node_input, padding_idx=0)

        # nn.LSTM emits 2 * hidden_size features when bidirectional, so each
        # direction only gets half of the node width in that case.
        if bidirectional:
            rnn_hidden = node_embed // 2
        else:
            rnn_hidden = node_embed
        self.rnn = nn.LSTM(
            node_input,
            rnn_hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional)

        self.edge_embed = nn.Linear(
            in_features=edge_input, out_features=edge_embed)

        self.gnn_layers = nn.ModuleList()
        for _ in range(num_gnn):
            self.gnn_layers.append(GNNLayer(node_embed, edge_embed))

        # One classifier per graph element; the edge classifier is binary.
        self.node_cls = nn.Linear(
            in_features=node_embed, out_features=num_classes)
        self.edge_cls = nn.Linear(in_features=edge_embed, out_features=2)
        self.loss = build_loss(loss)
Exemplo n.º 2
0
    def __init__(self,
                 in_channels,
                 decoding_type='textsnake',
                 text_repr_type='poly',
                 loss=dict(type='TextSnakeLoss'),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(type='Normal',
                               override=dict(name='out_conv'),
                               mean=0,
                               std=0.01)):
        """Store the head settings and create the 1x1 output convolution.

        Args:
            in_channels (int): Channel count of the incoming feature map.
            decoding_type (str): Stored as ``self.decoding_type``.
            text_repr_type (str): Stored as ``self.text_repr_type``.
            loss (dict): Config handed to ``build_loss``.
            train_cfg: Stored as ``self.train_cfg``.
            test_cfg: Stored as ``self.test_cfg``.
            init_cfg (dict): Forwarded to the parent constructor.
        """
        super().__init__(init_cfg=init_cfg)

        assert isinstance(in_channels, int)

        # Plain configuration attributes.
        self.decoding_type = decoding_type
        self.text_repr_type = text_repr_type
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # Fixed geometry of this head: 5 output maps, no downsampling.
        self.in_channels = in_channels
        self.out_channels = 5
        self.downsample_ratio = 1.0

        self.loss_module = build_loss(loss)

        # Point-wise projection from the input features to the output maps.
        self.out_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0)
Exemplo n.º 3
0
    def __init__(
        self,
        in_channels,
        scales,
        fourier_degree=5,
        num_sample=50,
        num_reconstr_points=50,
        decoding_type='fcenet',
        loss=dict(type='FCELoss'),
        score_thr=0.3,
        nms_thr=0.1,
        alpha=1.0,
        beta=1.0,
        text_repr_type='poly',
        train_cfg=None,
        test_cfg=None,
        init_cfg=dict(
            type='Normal',
            mean=0,
            std=0.01,
            override=[dict(name='out_conv_cls'),
                      dict(name='out_conv_reg')])):
        """Store the head settings and create both 3x3 output convolutions.

        Args:
            in_channels (int): Channel count of the incoming feature map.
            scales: Stored as ``self.scales``.
            fourier_degree (int): Stored on the head, injected into the loss
                config and used to size the regression output.
            num_sample (int): Stored as ``self.sample_num`` and injected into
                the loss config.
            num_reconstr_points (int): Stored as
                ``self.num_reconstr_points``.
            decoding_type (str): Stored as ``self.decoding_type``.
            loss (dict): Base config for ``build_loss``. It is not modified;
                ``fourier_degree`` and ``num_sample`` are added to a copy.
            score_thr (float): Stored as ``self.score_thr``.
            nms_thr (float): Stored as ``self.nms_thr``.
            alpha (float): Stored as ``self.alpha``.
            beta (float): Stored as ``self.beta``.
            text_repr_type (str): Stored as ``self.text_repr_type``.
            train_cfg: Stored as ``self.train_cfg``.
            test_cfg: Stored as ``self.test_cfg``.
            init_cfg (dict): Forwarded to the parent constructor.
        """
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, int)

        self.downsample_ratio = 1.0
        self.in_channels = in_channels
        self.scales = scales
        self.fourier_degree = fourier_degree
        self.sample_num = num_sample
        self.num_reconstr_points = num_reconstr_points

        # Work on a copy: writing into ``loss`` itself would leak these keys
        # into the shared default argument and into the caller's config.
        loss_cfg = loss.copy()
        loss_cfg.update(fourier_degree=fourier_degree, num_sample=num_sample)

        self.decoding_type = decoding_type
        self.loss_module = build_loss(loss_cfg)
        self.score_thr = score_thr
        self.nms_thr = nms_thr
        self.alpha = alpha
        self.beta = beta
        self.text_repr_type = text_repr_type
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.out_channels_cls = 4
        # Regression width grows with the Fourier degree.
        self.out_channels_reg = (2 * self.fourier_degree + 1) * 2

        self.out_conv_cls = nn.Conv2d(
            self.in_channels,
            self.out_channels_cls,
            kernel_size=3,
            stride=1,
            padding=1)
        self.out_conv_reg = nn.Conv2d(
            self.in_channels,
            self.out_channels_reg,
            kernel_size=3,
            stride=1,
            padding=1)
    def __init__(self,
                 preprocessor=None,
                 backbone=None,
                 encoder=None,
                 decoder=None,
                 loss=None,
                 label_convertor=None,
                 train_cfg=None,
                 test_cfg=None,
                 max_seq_len=40,
                 pretrained=None,
                 init_cfg=None):
        """Build the recognizer components from their config dicts.

        The supplied config dicts are not modified; values derived from the
        label convertor are merged into copies before building.

        Args:
            preprocessor (dict, optional): Config for ``build_preprocessor``;
                ``self.preprocessor`` stays ``None`` when omitted.
            backbone (dict): Config for ``build_backbone``. Required.
            encoder (dict, optional): Config for ``build_encoder``;
                ``self.encoder`` stays ``None`` when omitted.
            decoder (dict): Config for ``build_decoder``. Required. It is
                extended with ``num_classes``, ``start_idx``, ``padding_idx``
                and ``max_seq_len``.
            loss (dict): Config for ``build_loss``. Required. It is extended
                with ``ignore_index`` set to the convertor's padding index.
            label_convertor (dict): Config for ``build_convertor``. Required.
                It is extended with ``max_seq_len``.
            train_cfg: Stored as ``self.train_cfg``.
            test_cfg: Stored as ``self.test_cfg``.
            max_seq_len (int): Shared by the label convertor and the decoder.
            pretrained (str, optional): Deprecated. When given, a warning is
                issued and ``self.init_cfg`` is replaced by a ``Pretrained``
                config pointing at this checkpoint.
            init_cfg (dict, optional): Forwarded to the parent constructor.
        """
        super().__init__(init_cfg=init_cfg)

        # Label convertor (str2tensor, tensor2str)
        assert label_convertor is not None
        convertor_cfg = label_convertor.copy()
        convertor_cfg.update(max_seq_len=max_seq_len)
        self.label_convertor = build_convertor(convertor_cfg)

        # Preprocessor module, e.g., TPS
        self.preprocessor = None
        if preprocessor is not None:
            self.preprocessor = build_preprocessor(preprocessor)

        # Backbone
        assert backbone is not None
        self.backbone = build_backbone(backbone)

        # Encoder module
        self.encoder = None
        if encoder is not None:
            self.encoder = build_encoder(encoder)

        # Decoder module: vocabulary details come from the label convertor.
        assert decoder is not None
        decoder_cfg = decoder.copy()
        decoder_cfg.update(num_classes=self.label_convertor.num_classes())
        decoder_cfg.update(start_idx=self.label_convertor.start_idx)
        decoder_cfg.update(padding_idx=self.label_convertor.padding_idx)
        decoder_cfg.update(max_seq_len=max_seq_len)
        self.decoder = build_decoder(decoder_cfg)

        # Loss: the padding index is passed as the index to ignore.
        assert loss is not None
        loss_cfg = loss.copy()
        loss_cfg.update(ignore_index=self.label_convertor.padding_idx)
        self.loss = build_loss(loss_cfg)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.max_seq_len = max_seq_len

        if pretrained is not None:
            # Adjacent literals instead of a backslash continuation, which
            # would embed the source indentation in the message.
            warnings.warn('DeprecationWarning: pretrained is a deprecated '
                          'key, please consider using init_cfg')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
Exemplo n.º 5
0
    def __init__(self,
                 preprocessor=None,
                 backbone=None,
                 neck=None,
                 head=None,
                 loss=None,
                 label_convertor=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        """Build the recognizer components from their config dicts.

        The supplied config dicts are not modified; ``num_classes`` is merged
        into a copy of ``head`` before building.

        Args:
            preprocessor (dict, optional): Config for ``build_preprocessor``;
                ``self.preprocessor`` stays ``None`` when omitted.
            backbone (dict): Config for ``build_backbone``. Required.
            neck (dict): Config for ``build_neck``. Required.
            head (dict): Config for ``build_head``. Required. It is extended
                with ``num_classes`` taken from the label convertor.
            loss (dict): Config for ``build_loss``. Required.
            label_convertor (dict): Config for ``build_convertor``. Required.
            train_cfg: Stored as ``self.train_cfg``.
            test_cfg: Stored as ``self.test_cfg``.
            pretrained (str, optional): Deprecated. When given, a warning is
                issued and ``self.init_cfg`` is replaced by a ``Pretrained``
                config pointing at this checkpoint.
            init_cfg (dict, optional): Forwarded to the parent constructor.
        """
        super().__init__(init_cfg=init_cfg)

        # Label_convertor
        assert label_convertor is not None
        self.label_convertor = build_convertor(label_convertor)

        # Preprocessor module, e.g., TPS
        self.preprocessor = None
        if preprocessor is not None:
            self.preprocessor = build_preprocessor(preprocessor)

        # Backbone
        assert backbone is not None
        self.backbone = build_backbone(backbone)

        # Neck
        assert neck is not None
        self.neck = build_neck(neck)

        # Head: the class count comes from the label convertor.
        assert head is not None
        head_cfg = head.copy()
        head_cfg.update(num_classes=self.label_convertor.num_classes())
        self.head = build_head(head_cfg)

        # Loss
        assert loss is not None
        self.loss = build_loss(loss)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        if pretrained is not None:
            # Adjacent literals instead of a backslash continuation, which
            # would embed the source indentation in the message.
            warnings.warn('DeprecationWarning: pretrained is a deprecated '
                          'key, please consider using init_cfg')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
Exemplo n.º 6
0
    def __init__(self,
                 encoder,
                 decoder,
                 loss,
                 label_convertor,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        """Build the label convertor, encoder, decoder and loss.

        The supplied ``decoder`` and ``loss`` dicts are not modified; the
        label count is merged into copies before building.

        Args:
            encoder (dict): Config for ``build_encoder``.
            decoder (dict): Config for ``build_decoder``; extended with
                ``num_labels`` from the label convertor.
            loss (dict): Config for ``build_loss``; extended with
                ``num_labels`` from the label convertor.
            label_convertor (dict): Config for ``build_convertor``.
            train_cfg: Accepted but not used by this constructor.
            test_cfg: Accepted but not used by this constructor.
            init_cfg (dict, optional): Forwarded to the parent constructor.
        """
        super().__init__(init_cfg=init_cfg)
        self.label_convertor = build_convertor(label_convertor)

        self.encoder = build_encoder(encoder)

        # Copy before injecting the label count so the caller's configs stay
        # untouched and can be reused.
        decoder_cfg = decoder.copy()
        decoder_cfg.update(num_labels=self.label_convertor.num_labels)
        self.decoder = build_decoder(decoder_cfg)

        loss_cfg = loss.copy()
        loss_cfg.update(num_labels=self.label_convertor.num_labels)
        self.loss = build_loss(loss_cfg)
Exemplo n.º 7
0
    def __init__(self,
                 in_channels,
                 with_bias=False,
                 decoding_type='db',
                 text_repr_type='poly',
                 downsample_ratio=1.0,
                 loss=dict(type='DBLoss'),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=[
                     dict(type='Kaiming', layer='Conv'),
                     dict(type='Constant',
                          layer='BatchNorm',
                          val=1.,
                          bias=1e-4)
                 ]):
        """Store the head settings and create the two prediction branches.

        Args:
            in_channels (int): Channel count of the incoming feature map.
            with_bias (bool): Whether the first convolution of the binarize
                branch has a bias term.
            decoding_type (str): Stored as ``self.decoding_type``.
            text_repr_type (str): Stored as ``self.text_repr_type``.
            downsample_ratio (float): Stored as ``self.downsample_ratio``.
            loss (dict): Config handed to ``build_loss``.
            train_cfg: Stored as ``self.train_cfg``.
            test_cfg: Stored as ``self.test_cfg``.
            init_cfg (list[dict]): Forwarded to the parent constructor.
        """
        super().__init__(init_cfg=init_cfg)

        assert isinstance(in_channels, int)

        self.in_channels = in_channels
        self.decoding_type = decoding_type
        self.text_repr_type = text_repr_type
        self.downsample_ratio = downsample_ratio
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.loss_module = build_loss(loss)

        # The branch works at a quarter of the input width; the two stride-2
        # transposed convolutions sit after the 3x3 convolution.
        quarter = in_channels // 4
        binarize_layers = [
            nn.Conv2d(
                in_channels, quarter, kernel_size=3, padding=1,
                bias=with_bias),
            nn.BatchNorm2d(quarter),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(quarter, quarter, kernel_size=2, stride=2),
            nn.BatchNorm2d(quarter),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(quarter, 1, kernel_size=2, stride=2),
            nn.Sigmoid(),
        ]
        self.binarize = Sequential(*binarize_layers)

        # The threshold branch is produced by a helper defined on the class.
        self.threshold = self._init_thr(in_channels)
Exemplo n.º 8
0
    def __init__(
        self,
        in_channels,
        out_channels,
        text_repr_type='poly',  # 'poly' or 'quad'
        downsample_ratio=0.25,
        loss=dict(type='PANLoss'),
        train_cfg=None,
        test_cfg=None,
        init_cfg=dict(
            type='Normal', mean=0, std=0.01, override=dict(name='out_conv'))):
        """Store the head settings and create the 1x1 output convolution.

        Args:
            in_channels (list[int]): Channel counts of the incoming feature
                maps; their sum is the input width of the output convolution.
            out_channels (int): Output width of the output convolution.
            text_repr_type (str): Either ``'poly'`` or ``'quad'``.
            downsample_ratio (float): Value in ``[0, 1]``, stored as
                ``self.downsample_ratio``.
            loss (dict): Config handed to ``build_loss``. Its ``type`` also
                selects the decoding type: ``'PANLoss'`` gives ``'pan'`` and
                ``'PSELoss'`` gives ``'pse'``.
            train_cfg: Stored as ``self.train_cfg``.
            test_cfg: Stored as ``self.test_cfg``.
            init_cfg (dict): Forwarded to the parent constructor.

        Raises:
            NotImplementedError: If ``loss['type']`` is neither ``'PANLoss'``
                nor ``'PSELoss'``.
        """
        super().__init__(init_cfg=init_cfg)

        assert check_argument.is_type_list(in_channels, int)
        assert isinstance(out_channels, int)
        assert text_repr_type in ['poly', 'quad']
        assert 0 <= downsample_ratio <= 1

        self.loss_module = build_loss(loss)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.text_repr_type = text_repr_type
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.downsample_ratio = downsample_ratio

        # Named ``loss_type`` so the builtin ``type`` is not shadowed.
        loss_type = loss['type']
        if loss_type == 'PANLoss':
            self.decoding_type = 'pan'
        elif loss_type == 'PSELoss':
            self.decoding_type = 'pse'
        else:
            raise NotImplementedError(f'unsupported loss type {loss_type}.')

        # The builtin sum yields a plain int rather than a NumPy scalar for
        # the layer's ``in_channels``.
        self.out_conv = nn.Conv2d(
            in_channels=sum(in_channels),
            out_channels=out_channels,
            kernel_size=1)
Exemplo n.º 9
0
    def __init__(self,
                 in_channels,
                 k_at_hops=(8, 4),
                 num_adjacent_linkages=3,
                 node_geo_feat_len=120,
                 pooling_scale=1.0,
                 pooling_output_size=(4, 3),
                 nms_thr=0.3,
                 min_width=8.0,
                 max_width=24.0,
                 comp_shrink_ratio=1.03,
                 comp_ratio=0.4,
                 comp_score_thr=0.3,
                 text_region_thr=0.2,
                 center_region_thr=0.2,
                 center_region_area_thr=50,
                 local_graph_thr=0.7,
                 loss=dict(type='DRRGLoss'),
                 postprocessor=dict(type='DRRGPostprocessor', link_thr=0.85),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(type='Normal',
                               override=dict(name='out_conv'),
                               mean=0,
                               std=0.01),
                 **kwargs):
        """Store the head settings and build the convolution, graphs and GCN.

        Args:
            in_channels (int): Channel count of the incoming feature map.
            k_at_hops (tuple): Forwarded to both graph builders.
            num_adjacent_linkages (int): Forwarded to both graph builders.
            node_geo_feat_len (int): Forwarded to both graph builders and
                added to the GCN input width.
            pooling_scale (float): Forwarded to both graph builders.
            pooling_output_size (tuple): Forwarded to both graph builders;
                the product of its two entries scales the GCN input width.
            nms_thr (float): Forwarded to ``ProposalLocalGraphs``.
            min_width (float): Forwarded to ``ProposalLocalGraphs``.
            max_width (float): Forwarded to ``ProposalLocalGraphs``.
            comp_shrink_ratio (float): Forwarded to ``ProposalLocalGraphs``.
            comp_ratio (float): Forwarded to ``ProposalLocalGraphs``.
            comp_score_thr (float): Forwarded to ``ProposalLocalGraphs``.
            text_region_thr (float): Forwarded to ``ProposalLocalGraphs``.
            center_region_thr (float): Forwarded to ``ProposalLocalGraphs``.
            center_region_area_thr (int): Forwarded to
                ``ProposalLocalGraphs``.
            local_graph_thr (float): Forwarded to ``LocalGraphs``.
            loss (dict): Config handed to ``HeadMixin`` and ``build_loss``.
            postprocessor (dict): Config handed to ``HeadMixin``. It is never
                modified; deprecated overrides are applied to a copy.
            train_cfg: Stored as ``self.train_cfg``.
            test_cfg: Stored as ``self.test_cfg``.
            init_cfg (dict): Forwarded to ``BaseModule``.
            **kwargs: Only the deprecated keys ``text_repr_type``,
                ``decoding_type`` and ``link_thr`` are read. Each one that is
                not ``None`` is moved into the postprocessor config and
                triggers a ``UserWarning``.
        """
        old_keys = ['text_repr_type', 'decoding_type', 'link_thr']
        # ``is not None`` rather than truthiness, so that an explicit falsy
        # override such as ``link_thr=0.0`` is not silently dropped.
        overrides = {
            key: kwargs[key]
            for key in old_keys if kwargs.get(key) is not None
        }
        if overrides:
            # Copy first: writing into ``postprocessor`` itself would change
            # the shared default argument (and the caller's config) for
            # every instance created afterwards.
            postprocessor = postprocessor.copy()
            for key, value in overrides.items():
                postprocessor[key] = value
                warnings.warn(
                    f'{key} is deprecated, please specify '
                    'it in postprocessor config dict. See '
                    'https://github.com/open-mmlab/mmocr/pull/640'
                    ' for details.', UserWarning)
        BaseModule.__init__(self, init_cfg=init_cfg)
        HeadMixin.__init__(self, loss, postprocessor)

        assert isinstance(in_channels, int)
        assert isinstance(k_at_hops, tuple)
        assert isinstance(num_adjacent_linkages, int)
        assert isinstance(node_geo_feat_len, int)
        assert isinstance(pooling_scale, float)
        assert isinstance(pooling_output_size, tuple)
        assert isinstance(comp_shrink_ratio, float)
        assert isinstance(nms_thr, float)
        assert isinstance(min_width, float)
        assert isinstance(max_width, float)
        assert isinstance(comp_ratio, float)
        assert isinstance(comp_score_thr, float)
        assert isinstance(text_region_thr, float)
        assert isinstance(center_region_thr, float)
        assert isinstance(center_region_area_thr, int)
        assert isinstance(local_graph_thr, float)

        self.in_channels = in_channels
        self.out_channels = 6
        self.downsample_ratio = 1.0
        self.k_at_hops = k_at_hops
        self.num_adjacent_linkages = num_adjacent_linkages
        self.node_geo_feat_len = node_geo_feat_len
        self.pooling_scale = pooling_scale
        self.pooling_output_size = pooling_output_size
        self.comp_shrink_ratio = comp_shrink_ratio
        self.nms_thr = nms_thr
        self.min_width = min_width
        self.max_width = max_width
        self.comp_ratio = comp_ratio
        self.comp_score_thr = comp_score_thr
        self.text_region_thr = text_region_thr
        self.center_region_thr = center_region_thr
        self.center_region_area_thr = center_region_area_thr
        self.local_graph_thr = local_graph_thr
        # NOTE(review): HeadMixin.__init__ is also handed ``loss``; if it
        # already sets ``self.loss_module`` this rebuild is redundant.
        # Confirm before removing.
        self.loss_module = build_loss(loss)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.out_conv = nn.Conv2d(in_channels=self.in_channels,
                                  out_channels=self.out_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)

        self.graph_train = LocalGraphs(
            self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
            self.pooling_scale, self.pooling_output_size, self.local_graph_thr)

        self.graph_test = ProposalLocalGraphs(
            self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
            self.pooling_scale, self.pooling_output_size, self.nms_thr,
            self.min_width, self.max_width, self.comp_shrink_ratio,
            self.comp_ratio, self.comp_score_thr, self.text_region_thr,
            self.center_region_thr, self.center_region_area_thr)

        # GCN input width: one (input + output channels) vector per pooled
        # cell, plus the geometric feature length.
        pool_w, pool_h = self.pooling_output_size
        node_feat_len = (pool_w * pool_h) * (
            self.in_channels + self.out_channels) + self.node_geo_feat_len
        self.gcn = GCN(node_feat_len)
Exemplo n.º 10
0
    def __init__(self,
                 in_channels,
                 k_at_hops=(8, 4),
                 num_adjacent_linkages=3,
                 node_geo_feat_len=120,
                 pooling_scale=1.0,
                 pooling_output_size=(4, 3),
                 nms_thr=0.3,
                 min_width=8.0,
                 max_width=24.0,
                 comp_shrink_ratio=1.03,
                 comp_ratio=0.4,
                 comp_score_thr=0.3,
                 text_region_thr=0.2,
                 center_region_thr=0.2,
                 center_region_area_thr=50,
                 local_graph_thr=0.7,
                 link_thr=0.85,
                 loss=dict(type='DRRGLoss'),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(type='Normal',
                               override=dict(name='out_conv'),
                               mean=0,
                               std=0.01)):
        """Store the head settings and build the convolution, graphs and GCN.

        Args:
            in_channels (int): Channel count of the incoming feature map.
            k_at_hops (tuple): Forwarded to both graph builders.
            num_adjacent_linkages (int): Forwarded to both graph builders.
            node_geo_feat_len (int): Forwarded to both graph builders and
                added to the GCN input width.
            pooling_scale (float): Forwarded to both graph builders.
            pooling_output_size (tuple): Forwarded to both graph builders;
                the product of its two entries scales the GCN input width.
            nms_thr (float): Forwarded to ``ProposalLocalGraphs``.
            min_width (float): Forwarded to ``ProposalLocalGraphs``.
            max_width (float): Forwarded to ``ProposalLocalGraphs``.
            comp_shrink_ratio (float): Forwarded to ``ProposalLocalGraphs``.
            comp_ratio (float): Forwarded to ``ProposalLocalGraphs``.
            comp_score_thr (float): Forwarded to ``ProposalLocalGraphs``.
            text_region_thr (float): Forwarded to ``ProposalLocalGraphs``.
            center_region_thr (float): Forwarded to ``ProposalLocalGraphs``.
            center_region_area_thr (int): Forwarded to
                ``ProposalLocalGraphs``.
            local_graph_thr (float): Forwarded to ``LocalGraphs``.
            link_thr (float): Stored as ``self.link_thr``.
            loss (dict): Config handed to ``build_loss``.
            train_cfg: Stored as ``self.train_cfg``.
            test_cfg: Stored as ``self.test_cfg``.
            init_cfg (dict): Forwarded to the parent constructor.
        """
        super().__init__(init_cfg=init_cfg)

        # Type checks, grouped by the expected type.
        for int_arg in (in_channels, num_adjacent_linkages,
                        node_geo_feat_len, center_region_area_thr):
            assert isinstance(int_arg, int)
        for tuple_arg in (k_at_hops, pooling_output_size):
            assert isinstance(tuple_arg, tuple)
        for float_arg in (pooling_scale, comp_shrink_ratio, nms_thr,
                          min_width, max_width, comp_ratio, comp_score_thr,
                          text_region_thr, center_region_thr,
                          local_graph_thr, link_thr):
            assert isinstance(float_arg, float)

        # Fixed geometry of this head.
        self.in_channels = in_channels
        self.out_channels = 6
        self.downsample_ratio = 1.0

        # Settings shared by both graph builders.
        self.k_at_hops = k_at_hops
        self.num_adjacent_linkages = num_adjacent_linkages
        self.node_geo_feat_len = node_geo_feat_len
        self.pooling_scale = pooling_scale
        self.pooling_output_size = pooling_output_size

        # Settings used by the test-time graph builder only.
        self.nms_thr = nms_thr
        self.min_width = min_width
        self.max_width = max_width
        self.comp_shrink_ratio = comp_shrink_ratio
        self.comp_ratio = comp_ratio
        self.comp_score_thr = comp_score_thr
        self.text_region_thr = text_region_thr
        self.center_region_thr = center_region_thr
        self.center_region_area_thr = center_region_area_thr

        # Remaining thresholds and configs.
        self.local_graph_thr = local_graph_thr
        self.link_thr = link_thr
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.loss_module = build_loss(loss)

        self.out_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0)

        self.graph_train = LocalGraphs(
            k_at_hops, num_adjacent_linkages, node_geo_feat_len,
            pooling_scale, pooling_output_size, local_graph_thr)

        self.graph_test = ProposalLocalGraphs(
            k_at_hops, num_adjacent_linkages, node_geo_feat_len,
            pooling_scale, pooling_output_size, nms_thr, min_width,
            max_width, comp_shrink_ratio, comp_ratio, comp_score_thr,
            text_region_thr, center_region_thr, center_region_area_thr)

        # GCN input width: one (input + output channels) vector per pooled
        # cell, plus the geometric feature length.
        pool_w, pool_h = pooling_output_size
        pooled_cells = pool_w * pool_h
        channels_per_cell = in_channels + self.out_channels
        self.gcn = GCN(pooled_cells * channels_per_cell + node_geo_feat_len)
Exemplo n.º 11
0
    def __init__(self, loss, postprocessor):
        """Build the loss module and the postprocessor from their configs.

        Args:
            loss (dict): Config handed to ``build_loss``.
            postprocessor (dict): Config handed to ``build_postprocessor``.
        """
        # Both configs must be dicts before anything is built.
        for cfg in (loss, postprocessor):
            assert isinstance(cfg, dict)

        self.loss_module = build_loss(loss)
        self.postprocessor = build_postprocessor(postprocessor)