def __init__(self,
                 preprocessor=None,
                 backbone=None,
                 encoder=None,
                 decoder=None,
                 loss=None,
                 label_convertor=None,
                 train_cfg=None,
                 test_cfg=None,
                 max_seq_len=40,
                 pretrained=None,
                 init_cfg=None):
        """Build the label convertor, sub-modules and loss from configs.

        Args:
            preprocessor: Optional config passed to ``build_preprocessor``
                (e.g. TPS). No preprocessor is built when ``None``.
            backbone: Config passed to ``build_backbone``. Required.
            encoder: Optional config passed to ``build_encoder``. No
                encoder is built when ``None``.
            decoder: Config passed to ``build_decoder``. Required.
            loss: Config passed to ``build_loss``. Required.
            label_convertor: Config passed to ``build_convertor``. Required.
            train_cfg: Stored unchanged on ``self.train_cfg``.
            test_cfg: Stored unchanged on ``self.test_cfg``.
            max_seq_len: Injected as ``max_seq_len`` into the label
                convertor and decoder configs and stored on the instance.
            pretrained: Deprecated. When given, a warning is emitted and
                ``self.init_cfg`` is replaced by a ``Pretrained`` entry
                pointing at this checkpoint.
            init_cfg: Forwarded to the parent constructor.

        Raises:
            AssertionError: If ``label_convertor``, ``backbone``,
                ``decoder`` or ``loss`` is ``None``.
        """

        super().__init__(init_cfg=init_cfg)

        # Label convertor (str2tensor, tensor2str)
        assert label_convertor is not None
        # Work on a shallow copy so the keys injected below do not leak
        # into the caller's config object.
        label_convertor = label_convertor.copy()
        label_convertor.update(max_seq_len=max_seq_len)
        self.label_convertor = build_convertor(label_convertor)

        # Preprocessor module, e.g., TPS
        self.preprocessor = None
        if preprocessor is not None:
            self.preprocessor = build_preprocessor(preprocessor)

        # Backbone
        assert backbone is not None
        self.backbone = build_backbone(backbone)

        # Encoder module
        self.encoder = None
        if encoder is not None:
            self.encoder = build_encoder(encoder)

        # Decoder module: its config is completed with values taken from
        # the label convertor built above.
        assert decoder is not None
        decoder = decoder.copy()
        decoder.update(
            num_classes=self.label_convertor.num_classes(),
            start_idx=self.label_convertor.start_idx,
            padding_idx=self.label_convertor.padding_idx,
            max_seq_len=max_seq_len)
        self.decoder = build_decoder(decoder)

        # Loss: the convertor's padding index is passed as ``ignore_index``.
        assert loss is not None
        loss = loss.copy()
        loss.update(ignore_index=self.label_convertor.padding_idx)
        self.loss = build_loss(loss)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.max_seq_len = max_seq_len

        if pretrained is not None:
            # Adjacent literals are used instead of a backslash continuation
            # inside the string, which embedded the source indentation in
            # the emitted message.
            warnings.warn('DeprecationWarning: pretrained is a deprecated '
                          'key, please consider using init_cfg')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
Ejemplo n.º 2
0
    def __init__(self,
                 preprocessor=None,
                 backbone=None,
                 encoder=None,
                 decoder=None,
                 loss=None,
                 label_convertor=None,
                 train_cfg=None,
                 test_cfg=None,
                 max_seq_len=40,
                 pretrained=None):
        """Build the label convertor, sub-modules and loss from configs.

        Args:
            preprocessor: Optional config passed to ``build_preprocessor``
                (e.g. TPS). No preprocessor is built when ``None``.
            backbone: Config passed to ``build_backbone``. Required.
            encoder: Optional config passed to ``build_encoder``. No
                encoder is built when ``None``.
            decoder: Config passed to ``build_decoder``. Required.
            loss: Config passed to ``build_loss``. Required.
            label_convertor: Config passed to ``build_convertor``. Required.
            train_cfg: Stored unchanged on ``self.train_cfg``.
            test_cfg: Stored unchanged on ``self.test_cfg``.
            max_seq_len: Injected as ``max_seq_len`` into the label
                convertor and decoder configs and stored on the instance.
            pretrained: Forwarded to ``self.init_weights``.

        Raises:
            AssertionError: If ``label_convertor``, ``backbone``,
                ``decoder`` or ``loss`` is ``None``.
        """
        super().__init__()

        # Label convertor (str2tensor, tensor2str)
        assert label_convertor is not None
        # Work on a shallow copy so the keys injected below do not leak
        # into the caller's config object.
        label_convertor = label_convertor.copy()
        label_convertor.update(max_seq_len=max_seq_len)
        self.label_convertor = build_convertor(label_convertor)

        # Preprocessor module, e.g., TPS
        self.preprocessor = None
        if preprocessor is not None:
            self.preprocessor = build_preprocessor(preprocessor)

        # Backbone
        assert backbone is not None
        self.backbone = build_backbone(backbone)

        # Encoder module
        self.encoder = None
        if encoder is not None:
            self.encoder = build_encoder(encoder)

        # Decoder module: its config is completed with values taken from
        # the label convertor built above.
        assert decoder is not None
        decoder = decoder.copy()
        decoder.update(
            num_classes=self.label_convertor.num_classes(),
            start_idx=self.label_convertor.start_idx,
            padding_idx=self.label_convertor.padding_idx,
            max_seq_len=max_seq_len)
        self.decoder = build_decoder(decoder)

        # Loss: the convertor's padding index is passed as ``ignore_index``.
        assert loss is not None
        loss = loss.copy()
        loss.update(ignore_index=self.label_convertor.padding_idx)
        self.loss = build_loss(loss)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.max_seq_len = max_seq_len
        self.init_weights(pretrained=pretrained)
Ejemplo n.º 3
0
    def __init__(self,
                 preprocessor=None,
                 backbone=None,
                 neck=None,
                 head=None,
                 loss=None,
                 label_convertor=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        """Build the label convertor, sub-modules and loss from configs.

        Args:
            preprocessor: Optional config passed to ``build_preprocessor``
                (e.g. TPS). No preprocessor is built when ``None``.
            backbone: Config passed to ``build_backbone``. Required.
            neck: Config passed to ``build_neck``. Required.
            head: Config passed to ``build_head``. Required.
            loss: Config passed to ``build_loss``. Required.
            label_convertor: Config passed to ``build_convertor``. Required.
            train_cfg: Stored unchanged on ``self.train_cfg``.
            test_cfg: Stored unchanged on ``self.test_cfg``.
            pretrained: Deprecated. When given, a warning is emitted and
                ``self.init_cfg`` is replaced by a ``Pretrained`` entry
                pointing at this checkpoint.
            init_cfg: Forwarded to the parent constructor.

        Raises:
            AssertionError: If ``label_convertor``, ``backbone``, ``neck``,
                ``head`` or ``loss`` is ``None``.
        """
        super().__init__(init_cfg=init_cfg)

        # Label_convertor
        assert label_convertor is not None
        self.label_convertor = build_convertor(label_convertor)

        # Preprocessor module, e.g., TPS
        self.preprocessor = None
        if preprocessor is not None:
            self.preprocessor = build_preprocessor(preprocessor)

        # Backbone
        assert backbone is not None
        self.backbone = build_backbone(backbone)

        # Neck
        assert neck is not None
        self.neck = build_neck(neck)

        # Head
        assert head is not None
        # Work on a shallow copy so the injected ``num_classes`` does not
        # leak into the caller's config object.
        head = head.copy()
        head.update(num_classes=self.label_convertor.num_classes())
        self.head = build_head(head)

        # Loss
        assert loss is not None
        self.loss = build_loss(loss)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        if pretrained is not None:
            # Adjacent literals are used instead of a backslash continuation
            # inside the string, which embedded the source indentation in
            # the emitted message.
            warnings.warn('DeprecationWarning: pretrained is a deprecated '
                          'key, please consider using init_cfg')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
Ejemplo n.º 4
0
    def __init__(self,
                 encoder,
                 decoder,
                 loss,
                 label_convertor,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        """Build the label convertor, encoder, decoder and loss from configs.

        Args:
            encoder: Config passed to ``build_encoder``.
            decoder: Config passed to ``build_decoder`` after ``num_labels``
                from the label convertor has been added.
            loss: Config passed to ``build_loss`` after ``num_labels`` from
                the label convertor has been added.
            label_convertor: Config passed to ``build_convertor``.
            train_cfg: Accepted but not used by this constructor.
            test_cfg: Accepted but not used by this constructor.
            init_cfg: Forwarded to the parent constructor.
        """
        super().__init__(init_cfg=init_cfg)
        self.label_convertor = build_convertor(label_convertor)

        self.encoder = build_encoder(encoder)

        # Shallow copies keep the injected ``num_labels`` out of the
        # caller's config objects.
        decoder = decoder.copy()
        decoder.update(num_labels=self.label_convertor.num_labels)
        self.decoder = build_decoder(decoder)

        loss = loss.copy()
        loss.update(num_labels=self.label_convertor.num_labels)
        self.loss = build_loss(loss)
Ejemplo n.º 5
0
    def __init__(self,
                 label_convertor=None,
                 attn_shrink_ratio=0.5,
                 seg_shrink_ratio=0.25,
                 box_type='char_rects',
                 pad_val=255):
        """Validate the arguments and store them on the instance.

        Args:
            label_convertor: Config passed to ``build_convertor``. Required.
            attn_shrink_ratio: Float strictly between 0 and 1.
            seg_shrink_ratio: Float strictly between 0 and 1.
            box_type: Either ``'char_rects'`` or ``'char_quads'``.
            pad_val: Stored unchanged on ``self.pad_val``.

        Raises:
            AssertionError: If any of the checks above fails.
        """

        shrink_ratios = (attn_shrink_ratio, seg_shrink_ratio)
        # Type checks run for both ratios before any range comparison.
        assert all(isinstance(ratio, float) for ratio in shrink_ratios)
        assert all(0. < ratio < 1.0 for ratio in shrink_ratios)
        assert label_convertor is not None
        assert box_type in ('char_rects', 'char_quads')

        self.attn_shrink_ratio, self.seg_shrink_ratio = shrink_ratios
        self.label_convertor = build_convertor(label_convertor)
        self.pad_val = pad_val
        self.box_type = box_type
Ejemplo n.º 6
0
    def __init__(self,
                 preprocessor=None,
                 backbone=None,
                 neck=None,
                 head=None,
                 loss=None,
                 label_convertor=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Build the label convertor, sub-modules and loss from configs.

        Args:
            preprocessor: Optional config passed to ``build_preprocessor``
                (e.g. TPS). No preprocessor is built when ``None``.
            backbone: Config passed to ``build_backbone``. Required.
            neck: Config passed to ``build_neck``. Required.
            head: Config passed to ``build_head``. Required.
            loss: Config passed to ``build_loss``. Required.
            label_convertor: Config passed to ``build_convertor``. Required.
            train_cfg: Stored unchanged on ``self.train_cfg``.
            test_cfg: Stored unchanged on ``self.test_cfg``.
            pretrained: Forwarded to ``self.init_weights``.

        Raises:
            AssertionError: If ``label_convertor``, ``backbone``, ``neck``,
                ``head`` or ``loss`` is ``None``.
        """
        super().__init__()

        # Label_convertor
        assert label_convertor is not None
        self.label_convertor = build_convertor(label_convertor)

        # Preprocessor module, e.g., TPS
        self.preprocessor = None
        if preprocessor is not None:
            self.preprocessor = build_preprocessor(preprocessor)

        # Backbone
        assert backbone is not None
        self.backbone = build_backbone(backbone)

        # Neck
        assert neck is not None
        self.neck = build_neck(neck)

        # Head
        assert head is not None
        # Work on a shallow copy so the injected ``num_classes`` does not
        # leak into the caller's config object.
        head = head.copy()
        head.update(num_classes=self.label_convertor.num_classes())
        self.head = build_head(head)

        # Loss
        assert loss is not None
        self.loss = build_loss(loss)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)
Ejemplo n.º 7
0
 def __init__(self, label_convertor, max_len):
     """Build the label convertor and remember the maximum length.

     Args:
         label_convertor: Config passed to ``build_convertor``.
         max_len: Stored unchanged on ``self.max_len``.
     """
     convertor = build_convertor(label_convertor)
     self.label_convertor = convertor
     self.max_len = max_len