Beispiel #1
0
    def __init__(self, opt, lstm_layer=2):
        """Build a stack of ``lstm_layer`` BidirectionalLSTM layers plus a selective decoder.

        The stack maps 512-wide features to 512-wide features: the first
        layer reads 512 features, the last layer emits 512 features, and
        every intermediate boundary uses ``opt.hidden_size``.

        Args:
            opt: options object; only ``opt.hidden_size`` is read here.
            lstm_layer: number of BidirectionalLSTM layers in ``self.block``.
        """
        super().__init__()
        feature_size = 512  # width of the features entering and leaving the stack
        last = lstm_layer - 1
        self.block = nn.Sequential()
        for i in range(lstm_layer):
            # Input and output widths are chosen independently so that a
            # single-layer stack (which is both first AND last) still maps
            # 512 -> 512 instead of 512 -> hidden_size.
            in_size = feature_size if i == 0 else opt.hidden_size
            out_size = feature_size if i == last else opt.hidden_size
            self.block.add_module(
                str(i),
                BidirectionalLSTM(in_size, opt.hidden_size, out_size))

        self.sd = Selective_Decoder()
    def __init__(self, opt):
        """Joint CTC + attention recognizer: backbone -> optional BiLSTM -> two heads.

        Builds ``CTC_Prediction`` (linear, ``opt.ctc_num_class`` outputs) and
        ``Attn_Prediction`` (attention decoder, ``opt.num_class`` outputs) on
        top of the same sequence features.

        Note: for ``opt.FeatureExtraction == 'CRNN'`` this writes
        ``opt.output_channel = 512`` back onto the caller's options object.

        Raises:
            Exception: unknown FeatureExtraction.
        """
        super(Model, self).__init__()
        self.opt = opt
        self.stages = {'Feat': opt.FeatureExtraction, 'Seq': opt.SequenceModeling, 'Pred': "CTC-Attn"}

        # --- feature extraction -------------------------------------------
        backbone = opt.FeatureExtraction
        if backbone == 'CRNN':
            self.FeatureExtraction = CRNN()
            # Force 512 so FeatureExtraction_output below is 512 for CRNN.
            opt.output_channel = 512
        elif backbone == 'ResNet':
            self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        else:
            raise Exception('No FeatureExtraction module specified')
        feat_dim = opt.output_channel  # int(imgH/16-1) * 512
        self.FeatureExtraction_output = feat_dim
        # Pools the last axis to size 1 (original note: (imgH/16-1) -> 1).
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

        # --- sequence modeling --------------------------------------------
        if opt.SequenceModeling != 'BiLSTM':
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = feat_dim
        else:
            hidden = opt.hidden_size
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(feat_dim, hidden, hidden),
                BidirectionalLSTM(hidden, hidden, hidden))
            self.SequenceModeling_output = hidden

        # --- prediction heads (both always built) -------------------------
        seq_dim = self.SequenceModeling_output
        self.CTC_Prediction = nn.Linear(seq_dim, opt.ctc_num_class)
        self.Attn_Prediction = Attention(seq_dim, opt.hidden_size, opt.num_class)
    def __init__(self, opt):
        """Assemble the four-stage recognizer Trans -> Feat -> Seq -> Pred from ``opt``.

        Supported choices:
            Transformation: 'TPS' (anything else: no transformation).
            FeatureExtraction: 'VGG' | 'RCNN' | 'ResNet' | 'DenseNet'.
            SequenceModeling: 'BiLSTM' (anything else: no sequence model).
            Prediction: 'CTC' | 'Attn'.

        Raises:
            Exception: unknown FeatureExtraction or Prediction.
        """
        super(Model, self).__init__()
        self.opt = opt
        self.stages = {
            'Trans': opt.Transformation,
            'Feat': opt.FeatureExtraction,
            'Seq': opt.SequenceModeling,
            'Pred': opt.Prediction
        }

        # 1. Transformation (optional TPS rectification).
        if opt.Transformation != 'TPS':
            print('No Transformation module specified')
        else:
            img_size = (opt.imgH, opt.imgW)
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt.num_fiducial,
                I_size=img_size,
                I_r_size=img_size,
                I_channel_num=opt.input_channel)

        # 2. Feature extraction backbone.
        name = opt.FeatureExtraction
        if name == 'VGG':
            backbone_cls = VGG_FeatureExtractor
        elif name == 'RCNN':
            backbone_cls = RCNN_FeatureExtractor
        elif name == 'ResNet':
            backbone_cls = ResNet_FeatureExtractor
        elif name == 'DenseNet':
            backbone_cls = DenseNet
        else:
            raise Exception('No FeatureExtraction module specified')
        self.FeatureExtraction = backbone_cls(opt.input_channel, opt.output_channel)

        # DenseNet's output width is fixed at 768; the others use opt.output_channel.
        self.FeatureExtraction_output = 768 if name == 'DenseNet' else opt.output_channel

        # 3. Sequence modeling (two stacked BiLSTMs, or pass-through).
        if opt.SequenceModeling == 'BiLSTM':
            hidden = opt.hidden_size
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, hidden, hidden),
                BidirectionalLSTM(hidden, hidden, hidden))
            self.SequenceModeling_output = hidden
        else:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = self.FeatureExtraction_output

        # 4. Prediction head.
        seq_dim = self.SequenceModeling_output
        if opt.Prediction == 'CTC':
            self.Prediction = nn.Linear(seq_dim, opt.num_class)
        elif opt.Prediction == 'Attn':
            self.Prediction = Attention(seq_dim, opt.hidden_size, opt.num_class)
        else:
            raise Exception('Prediction is neither CTC or Attn')
Beispiel #4
0
 def __init__(self, opt):
     """Assemble the four-stage recognizer Trans -> Feat -> Seq -> Pred from ``opt``.

     Stages:
         1. Transformation: 'TPS' spatial transformer, or none.
         2. FeatureExtraction: 'VGG' | 'RCNN' | 'ResNet' CNN backbone.
         3. SequenceModeling: two stacked 'BiLSTM' layers, or none.
         4. Prediction: 'CTC' (linear) or 'Attn' (attention decoder).

     Raises:
         Exception: unknown FeatureExtraction or Prediction.
     """
     super(Model, self).__init__()
     self.opt = opt
     self.stages = {
         'Trans': opt.Transformation,
         'Feat': opt.FeatureExtraction,
         'Seq': opt.SequenceModeling,
         'Pred': opt.Prediction
     }

     # 1. Spatial transformation: TPS with opt.num_fiducial fiducial points.
     if opt.Transformation == 'TPS':
         img_size = (opt.imgH, opt.imgW)
         self.Transformation = TPS_SpatialTransformerNetwork(
             F=opt.num_fiducial,
             I_size=img_size,
             I_r_size=img_size,
             I_channel_num=opt.input_channel)
     else:
         print('No Transformation module specified')

     # 2. Feature extraction. opt.input_channel is the input image depth;
     #    opt.output_channel is the feature width (512 by default per the
     #    original notes). Original note for VGG: output is about
     #    (batch, 512, 1, 24), i.e. 512 features over a 24-step sequence.
     name = opt.FeatureExtraction
     if name == 'VGG':
         backbone_cls = VGG_FeatureExtractor
     elif name == 'RCNN':
         backbone_cls = RCNN_FeatureExtractor
     elif name == 'ResNet':
         backbone_cls = ResNet_FeatureExtractor
     else:
         raise Exception('No FeatureExtraction module specified')
     self.FeatureExtraction = backbone_cls(opt.input_channel, opt.output_channel)
     self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
     # Pools the last axis to size 1 (original note: (imgH/16-1) -> 1).
     self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

     # 3. Sequence modeling. Original example sizes: sequence feature
     #    width 512, hidden 256, output 256.
     if opt.SequenceModeling == 'BiLSTM':
         hidden = opt.hidden_size
         self.SequenceModeling = nn.Sequential(
             BidirectionalLSTM(self.FeatureExtraction_output, hidden, hidden),
             BidirectionalLSTM(hidden, hidden, hidden))
         self.SequenceModeling_output = hidden
     else:
         print('No SequenceModeling module specified')
         self.SequenceModeling_output = self.FeatureExtraction_output

     # 4. Prediction head.
     seq_dim = self.SequenceModeling_output
     if opt.Prediction == 'CTC':
         self.Prediction = nn.Linear(seq_dim, opt.num_class)
     elif opt.Prediction == 'Attn':
         self.Prediction = Attention(seq_dim, opt.hidden_size, opt.num_class)
     else:
         raise Exception('Prediction is neither CTC or Attn')
 def __init__(self, FeatureExtraction='ResNet', PAD=False, Prediction='CTC',
              SequenceModeling='BiLSTM', Transformation='TPS', batch_max_length=25,
              batch_size=192, character='0123456789abcdefghijklmnopqrstuvwxyz',
              hidden_size=256, image_folder='data\\OCR-data-labelled\\OCR-data\\validate\\val',
              imgH=32, imgW=100, input_channel=1, num_class=37, num_fiducial=20, num_gpu=1,
              output_channel=512, rgb=False, saved_model='saved_models/TPS-ResNet-BiLSTM-CTC-Seed1111/best_accuracy.pth',
              sensitive=False, workers=4):
     """Assemble the Trans -> Feat -> Seq -> Pred recognizer from keyword options.

     Only Transformation, FeatureExtraction, SequenceModeling, Prediction,
     imgH, imgW, input_channel, output_channel, hidden_size, num_class and
     num_fiducial are used by this constructor. The remaining parameters
     (PAD, batch_max_length, batch_size, character, image_folder, num_gpu,
     rgb, saved_model, sensitive, workers) are accepted but never read here.
     Presumably they exist so the signature mirrors a command-line option
     set; confirm with callers before removing any of them.

     Raises:
         Exception: unknown FeatureExtraction or Prediction.
     """
     super(Model, self).__init__()
     self.stages = {
         'Trans': Transformation,
         'Feat': FeatureExtraction,
         'Seq': SequenceModeling,
         'Pred': Prediction
     }

     # Transformation (optional TPS rectification).
     if Transformation != 'TPS':
         print('No Transformation module specified')
     else:
         size = (imgH, imgW)
         self.Transformation = TPS_SpatialTransformerNetwork(
             F=num_fiducial, I_size=size, I_r_size=size,
             I_channel_num=input_channel)

     # Feature extraction backbone.
     if FeatureExtraction == 'VGG':
         extractor_cls = VGG_FeatureExtractor
     elif FeatureExtraction == 'RCNN':
         extractor_cls = RCNN_FeatureExtractor
     elif FeatureExtraction == 'ResNet':
         extractor_cls = ResNet_FeatureExtractor
     else:
         raise Exception('No FeatureExtraction module specified')
     self.FeatureExtraction = extractor_cls(input_channel, output_channel)
     self.FeatureExtraction_output = output_channel  # int(imgH/16-1) * 512
     # Pools the last axis to size 1 (original note: (imgH/16-1) -> 1).
     self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

     # Sequence modeling (two stacked BiLSTMs, or pass-through).
     if SequenceModeling == 'BiLSTM':
         self.SequenceModeling = nn.Sequential(
             BidirectionalLSTM(output_channel, hidden_size, hidden_size),
             BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
         self.SequenceModeling_output = hidden_size
     else:
         print('No SequenceModeling module specified')
         self.SequenceModeling_output = output_channel

     # Prediction head.
     seq_dim = self.SequenceModeling_output
     if Prediction == 'CTC':
         self.Prediction = nn.Linear(seq_dim, num_class)
     elif Prediction == 'Attn':
         self.Prediction = Attention(seq_dim, hidden_size, num_class)
     else:
         raise Exception('Prediction is neither CTC or Attn')
Beispiel #6
0
 def __init__(self, cfg):
     """Assemble the Trans -> Feat -> Seq -> Pred OCR model from ``cfg``.

     Module choices and sizes come from ``cfg.OCR``; the input image size
     comes from ``cfg.BASE.IMG_H`` / ``cfg.BASE.IMG_W``.

     Raises:
         Exception: unknown FEATURE_EXTRACTION or PREDICTION.
     """
     super(OCR_MODEL, self).__init__()
     self.cfg = cfg
     ocr = cfg.OCR
     self.stages = {
         'Trans': ocr.TRANSFORMATION,
         'Feat': ocr.FEATURE_EXTRACTION,
         'Seq': ocr.SEQUENCE_MODELING,
         'Pred': ocr.PREDICTION
     }

     # Transformation (optional TPS rectification).
     if ocr.TRANSFORMATION != 'TPS':
         print('No Transformation module specified')
     else:
         img_size = (cfg.BASE.IMG_H, cfg.BASE.IMG_W)
         self.Transformation = TPS_SpatialTransformerNetwork(
             F=ocr.NUM_FIDUCIAL,
             I_size=img_size,
             I_r_size=img_size,
             I_channel_num=ocr.INPUT_CHANNEL)

     # Feature extraction backbone.
     backbone = ocr.FEATURE_EXTRACTION
     if backbone == 'VGG':
         extractor_cls = VGG_FeatureExtractor
     elif backbone == 'RCNN':
         extractor_cls = RCNN_FeatureExtractor
     elif backbone == 'ResNet':
         extractor_cls = ResNet_FeatureExtractor
     else:
         raise Exception('No FeatureExtraction module specified')
     self.FeatureExtraction = extractor_cls(ocr.INPUT_CHANNEL, ocr.OUTPUT_CHANNEL)
     self.FeatureExtraction_output = ocr.OUTPUT_CHANNEL  # int(imgH/16-1) * 512
     # Pools the last axis to size 1 (original note: (imgH/16-1) -> 1).
     self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

     # Sequence modeling (two stacked BiLSTMs, or pass-through).
     if ocr.SEQUENCE_MODELING == 'BiLSTM':
         hidden = ocr.HIDDEN_SIZE
         self.SequenceModeling = nn.Sequential(
             BidirectionalLSTM(self.FeatureExtraction_output, hidden, hidden),
             BidirectionalLSTM(hidden, hidden, hidden))
         self.SequenceModeling_output = hidden
     else:
         print('No SequenceModeling module specified')
         self.SequenceModeling_output = self.FeatureExtraction_output

     # Prediction head.
     seq_dim = self.SequenceModeling_output
     if ocr.PREDICTION == 'CTC':
         self.Prediction = nn.Linear(seq_dim, ocr.NUM_CLASS)
     elif ocr.PREDICTION == 'Attn':
         self.Prediction = Attention(seq_dim, ocr.HIDDEN_SIZE, ocr.NUM_CLASS)
     else:
         raise Exception('Prediction is neither CTC or Attn')
Beispiel #7
0
 def __init__(self, opt):
     """Fixed TPS -> ResNet -> 2x BiLSTM -> Attention recognizer configured by dict ``opt``.

     Keys read: "fiducial_num", "img_size", "img_r_size", "img_channel",
     "h_dim", "out_dim".
     """
     super(Model, self).__init__()
     self.opt = opt
     self.transform = TPS_SpatialTransformerNetwork(
         opt["fiducial_num"], opt["img_size"], opt["img_r_size"], opt["img_channel"])
     self.features_extractor = ResNet_FeatureExtractor(opt["img_channel"])
     self.GAP = nn.AdaptiveAvgPool2d((None, 1))
     hidden = opt["h_dim"]
     # The first BiLSTM's input width is hard-coded to 512; presumably this
     # matches ResNet_FeatureExtractor's default output channels — confirm.
     self.SequenceModeling = nn.Sequential(
         BidirectionalLSTM(512, hidden, hidden),
         BidirectionalLSTM(hidden, hidden, hidden))
     self.Prediction = Attention(hidden, hidden, opt["out_dim"])
Beispiel #8
0
    def __init__(self, opt):
        """Assemble the Trans -> Feat -> Seq -> Pred recognizer from ``opt``.

        Supported choices:
            Transformation: 'TPS' (anything else: no transformation).
            FeatureExtraction: 'VGG' | 'RCNN' | 'ResNet' | 'ResNet_PyTorch'
                (torchvision resnet34 with ImageNet weights, avgpool and fc
                removed; frozen when ``opt.freeze_FeatureExtraction``).
            SequenceModeling: 'BiLSTM' (anything else: no sequence model).
            Prediction: 'CTC' | 'Attn'.

        Raises:
            Exception: unknown FeatureExtraction or Prediction.
        """
        super(Model, self).__init__()
        self.opt = opt
        self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                       'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

        """ Transformation """
        if opt.Transformation == 'TPS':
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
        else:
            print('No Transformation module specified')

        """ FeatureExtraction """
        if opt.FeatureExtraction == 'VGG':
            self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'RCNN':
            self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'ResNet':
            self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'ResNet_PyTorch':
            backbone = models.resnet34(pretrained=True)
            # Optionally freeze all pretrained weights of the backbone.
            if opt.freeze_FeatureExtraction:
                print("Freezing the parameters of the resnet !")
                for param in backbone.parameters():
                    param.requires_grad = False
            # Drop the trailing avgpool and fc layers so the conv feature map is returned.
            self.FeatureExtraction = nn.Sequential(*list(backbone.children())[:-2])
        else:
            raise Exception('No FeatureExtraction module specified')

        if opt.FeatureExtraction == 'ResNet_PyTorch':
            # resnet34's last conv stage always emits 512 channels, whatever
            # opt.output_channel says; the first BiLSTM must match that.
            self.FeatureExtraction_output = 512
        else:
            self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

        """ Sequence modeling"""
        if opt.SequenceModeling == 'BiLSTM':
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
                BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
            self.SequenceModeling_output = opt.hidden_size
        else:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = self.FeatureExtraction_output

        """ Prediction """
        if opt.Prediction == 'CTC':
            self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
        elif opt.Prediction == 'Attn':
            self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
        else:
            raise Exception('Prediction is neither CTC or Attn')
Beispiel #9
0
    def __init__(self, opt):
        """Assemble the Trans -> Feat -> Seq -> Pred recognizer from a dict-like ``opt``.

        Status and error messages are emitted in Japanese.

        Raises:
            Exception: unknown 'FeatureExtraction' or 'Prediction'.
        """
        super(Model, self).__init__()
        self.opt = opt
        self.stages = {'Trans': opt['Transformation'], 'Feat': opt['FeatureExtraction'],
                       'Seq': opt['SequenceModeling'], 'Pred': opt['Prediction']}

        # Transformation (optional TPS rectification).
        if opt['Transformation'] != 'TPS':
            # "No transformation module specified"
            print('変換モジュールが指定されていません')
        else:
            img_size = (opt['imgH'], opt['imgW'])
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt['num_fiducial'], I_size=img_size, I_r_size=img_size,
                I_channel_num=opt['input_channel'])

        # Feature extraction backbone.
        backbone = opt['FeatureExtraction']
        if backbone == 'VGG':
            extractor_cls = VGG_FeatureExtractor
        elif backbone == 'RCNN':
            extractor_cls = RCNN_FeatureExtractor
        elif backbone == 'ResNet':
            extractor_cls = ResNet_FeatureExtractor
        else:
            # "No FeatureExtraction module specified"
            raise Exception('FeatureExtractionモジュールが指定されていません')
        self.FeatureExtraction = extractor_cls(opt['input_channel'], opt['output_channel'])
        self.FeatureExtraction_output = opt['output_channel']  # int(imgH/16-1) * 512
        # Pools the last axis to size 1 (original note: (imgH/16-1) -> 1).
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

        # Sequence modeling (two stacked BiLSTMs, or pass-through).
        if opt['SequenceModeling'] == 'BiLSTM':
            hidden = opt['hidden_size']
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, hidden, hidden),
                BidirectionalLSTM(hidden, hidden, hidden))
            self.SequenceModeling_output = hidden
        else:
            # "No SequenceModeling module specified"
            print('SequenceModelingモジュールが指定されていません')
            self.SequenceModeling_output = self.FeatureExtraction_output

        # Prediction head.
        seq_dim = self.SequenceModeling_output
        if opt['Prediction'] == 'CTC':
            self.Prediction = nn.Linear(seq_dim, opt['num_class'])
        elif opt['Prediction'] == 'Attn':
            self.Prediction = Attention(seq_dim, opt['hidden_size'], opt['num_class'])
        else:
            # "Prediction was neither CTC nor Attn."
            raise Exception('予測がCTCでもAttnでもありませんでした。')
Beispiel #10
0
    def __init__(self, opt, num_class):
        """Fixed TPS -> ResNet -> 2x BiLSTM -> Attention recognizer.

        Args:
            opt: options namespace; num_fiducial, imgH, imgW, input_channel,
                output_channel, hidden_size and num_class are read.
            num_class: passed to Attention as an extra fourth argument.
        """
        super(Model, self).__init__()
        self.opt = opt

        # Transformation: TPS rectification at the input resolution.
        img_size = (opt.imgH, opt.imgW)
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial, I_size=img_size, I_r_size=img_size,
            I_channel_num=opt.input_channel)

        # Feature extraction: ResNet backbone.
        feat_dim = opt.output_channel  # int(imgH/16-1) * 512
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, feat_dim)
        self.FeatureExtraction_output = feat_dim
        # Pools the last axis to size 1 (original note: (imgH/16-1) -> 1).
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

        # Sequence modeling: two stacked BiLSTMs.
        hidden = opt.hidden_size
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(feat_dim, hidden, hidden),
            BidirectionalLSTM(hidden, hidden, hidden))
        self.SequenceModeling_output = hidden

        # Prediction: attention decoder.
        # NOTE(review): Attention receives BOTH opt.num_class and num_class
        # (four positional arguments). If this project's Attention takes
        # (input_size, hidden_size, num_classes), the call raises TypeError —
        # confirm which class count is intended.
        self.Prediction = Attention(hidden, hidden, opt.num_class, num_class)
Beispiel #11
0
    def __init__(self, imgW, imgH, num_class, batch_max_length, *,
                 num_fiducial=20, input_channel=1, output_channel=512,
                 hidden_size=256):
        """Fixed TPS -> ResNet -> 2x BiLSTM -> Attention recognizer.

        Args:
            imgW: input image width.
            imgH: input image height.
            num_class: number of output classes of the attention decoder.
            batch_max_length: maximum label length; stored on the instance.
            num_fiducial: TPS fiducial-point count (previously fixed at 20).
            input_channel: input image channels (previously fixed at 1).
            output_channel: backbone output channels (previously fixed at 512).
            hidden_size: BiLSTM / attention hidden width (previously fixed at 256).

        The new sizing parameters are keyword-only with the old constants as
        defaults, so existing positional calls behave exactly as before.
        """
        super(Model, self).__init__()
        self.batch_max_length = batch_max_length

        self.stages = {
            'Trans': "TPS",
            'Feat': "ResNet",
            'Seq': "BiLSTM",
            'Pred': "Attn"
        }
        """ Transformation """
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=num_fiducial,
            I_size=(imgH, imgW),
            I_r_size=(imgH, imgW),
            I_channel_num=input_channel)
        """ FeatureExtraction """
        self.FeatureExtraction = ResNet_FeatureExtractor(
            input_channel, output_channel)

        self.FeatureExtraction_output = output_channel  # int(imgH/16-1) * 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
            (None, 1))  # Transform final (imgH/16-1) -> 1
        """ Sequence modeling"""
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, hidden_size,
                              hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
        self.SequenceModeling_output = hidden_size
        """ Prediction """
        self.Prediction = Attention(self.SequenceModeling_output, hidden_size,
                                    num_class)
 def __init__(self, opt):
     """Assemble the Trans -> Feat -> Seq -> Pred recognizer, with an optional GCN stage.

     Supported choices:
         Transformation: 'TPS' (anything else: no transformation).
         FeatureExtraction: 'VGG' | 'RCNN' | 'ResNet'.
         SequenceModeling: 'BiLSTM', 'GCN-BiLSTM' (GraphConvolution followed
             by two BiLSTMs), or anything else (no sequence model).
         Prediction: 'CTC' | 'Attn'.

     Raises:
         Exception: unknown FeatureExtraction or Prediction, or
             'GCN-BiLSTM' requested without ``opt.output_channel_GCN``.
     """
     super(Model, self).__init__()
     self.opt = opt
     self.stages = {
         'Trans': opt.Transformation,
         'Feat': opt.FeatureExtraction,
         'Seq': opt.SequenceModeling,
         'Pred': opt.Prediction
     }
     """ Transformation """
     if opt.Transformation == 'TPS':
         self.Transformation = TPS_SpatialTransformerNetwork(
             F=opt.num_fiducial,
             I_size=(opt.imgH, opt.imgW),
             I_r_size=(opt.imgH, opt.imgW),
             I_channel_num=opt.input_channel)
     else:
         print('No Transformation module specified')
     """ FeatureExtraction """
     # Feature-sequence length per backbone (presumably for the default image
     # size — confirm); consumed by the GCN-BiLSTM branch below. Validate the
     # name first so an unknown backbone raises the intended error rather
     # than a bare KeyError from the dict lookup.
     map_backbon_to_len_sequence = {'VGG': 24, 'RCNN': 26, 'ResNet': 26}
     if opt.FeatureExtraction not in map_backbon_to_len_sequence:
         raise Exception('No FeatureExtraction module specified')
     self.len_sequence = map_backbon_to_len_sequence[opt.FeatureExtraction]
     if opt.FeatureExtraction == 'VGG':
         self.FeatureExtraction = VGG_FeatureExtractor(
             opt.input_channel, opt.output_channel)
     elif opt.FeatureExtraction == 'RCNN':
         self.FeatureExtraction = RCNN_FeatureExtractor(
             opt.input_channel, opt.output_channel)
     else:  # 'ResNet' (validated above)
         self.FeatureExtraction = ResNet_FeatureExtractor(
             opt.input_channel, opt.output_channel)
     self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
     # Only the GCN-BiLSTM branch needs this option; read it leniently so
     # configs without it still work for the other sequence models.
     self.GraphConvolution_output = getattr(opt, 'output_channel_GCN', None)
     self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
         (None, 1))  # Transform final (imgH/16-1) -> 1
     """ Sequence modeling"""
     if opt.SequenceModeling == 'BiLSTM':
         self.SequenceModeling = nn.Sequential(
             BidirectionalLSTM(self.FeatureExtraction_output,
                               opt.hidden_size, opt.hidden_size),
             BidirectionalLSTM(opt.hidden_size, opt.hidden_size,
                               opt.hidden_size))
         self.SequenceModeling_output = opt.hidden_size
     elif opt.SequenceModeling == 'GCN-BiLSTM':
         if self.GraphConvolution_output is None:
             raise Exception(
                 'SequenceModeling GCN-BiLSTM requires opt.output_channel_GCN')
         # GraphConvolution is constructed with opt.batch_size; presumably
         # inputs must use that batch size — confirm.
         self.SequenceModeling = nn.Sequential(
             GraphConvolution(opt.batch_size,
                              self.len_sequence,
                              self.FeatureExtraction_output,
                              self.GraphConvolution_output,
                              bias=False,
                              scale_factor=5),
             BidirectionalLSTM(self.GraphConvolution_output,
                               opt.hidden_size, opt.hidden_size),
             BidirectionalLSTM(opt.hidden_size, opt.hidden_size,
                               opt.hidden_size))
         self.SequenceModeling_output = opt.hidden_size
     else:
         print('No SequenceModeling module specified')
         self.SequenceModeling_output = self.FeatureExtraction_output
     """ Prediction """
     if opt.Prediction == 'CTC':
         self.Prediction = nn.Linear(self.SequenceModeling_output,
                                     opt.num_class)
     elif opt.Prediction == 'Attn':
         self.Prediction = Attention(self.SequenceModeling_output,
                                     opt.hidden_size, opt.num_class)
     else:
         raise Exception('Prediction is neither CTC or Attn')
Beispiel #13
0
 def __init__(self, opt):
     """Assemble a configurable text recognizer Trans -> Feat -> Seq -> Pred from ``opt``.

     Supported choices:
         Transformation: 'TPS' (anything else: no transformation).
         FeatureExtraction: 'VGG' | 'RCNN' | 'ResNet' | 'AsterRes' | 'ResnetFpn'.
         SequenceModeling: 'BiLSTM' | 'Bert' | 'SRN' (anything else: none).
         Prediction: 'CTC' | 'Attn' | 'Bert_pred' | 'SRN'.

     Raises:
         Exception: unknown FeatureExtraction or Prediction, or a 'CTC' /
             'Attn' head combined with a sequence model ('Bert' / 'SRN')
             whose output width is not determined in this constructor.
     """
     super(Model, self).__init__()
     self.opt = opt
     self.stages = {
         'Trans': opt.Transformation,
         'Feat': opt.FeatureExtraction,
         'Seq': opt.SequenceModeling,
         'Pred': opt.Prediction
     }
     """ Transformation """
     if opt.Transformation == 'TPS':
         self.Transformation = TPS_SpatialTransformerNetwork(
             F=opt.num_fiducial,
             I_size=(opt.imgH, opt.imgW),
             I_r_size=(opt.imgH, opt.imgW),
             I_channel_num=opt.input_channel)
     else:
         print('No Transformation module specified')
     """ FeatureExtraction """
     if opt.FeatureExtraction == 'VGG':
         self.FeatureExtraction = VGG_FeatureExtractor(
             opt.input_channel, opt.output_channel)
     elif opt.FeatureExtraction == 'RCNN':
         self.FeatureExtraction = RCNN_FeatureExtractor(
             opt.input_channel, opt.output_channel)
     elif opt.FeatureExtraction == 'ResNet':
         self.FeatureExtraction = ResNet_FeatureExtractor(
             opt.input_channel, opt.output_channel)
         # Only the ResNet backbone gets a pooling layer here; presumably the
         # forward pass uses it only for ResNet — confirm.
         self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
             (None, 1))  # Transform final (imgH/16-1) -> 1
     elif opt.FeatureExtraction == 'AsterRes':
         self.FeatureExtraction = ResNet_ASTER2(opt.input_channel,
                                                opt.output_channel)
     elif opt.FeatureExtraction == 'ResnetFpn':
         self.FeatureExtraction = ResNet_FPN()
     else:
         raise Exception('No FeatureExtraction module specified')
     self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
     """ Sequence modeling"""
     # seq_out stays None for 'Bert' and 'SRN': their output width is not
     # determined here, so SequenceModeling_output is not set for them.
     seq_out = None
     if opt.SequenceModeling == 'BiLSTM':
         self.SequenceModeling = nn.Sequential(
             BidirectionalLSTM(self.FeatureExtraction_output,
                               opt.hidden_size, opt.hidden_size),
             BidirectionalLSTM(opt.hidden_size, opt.hidden_size,
                               opt.hidden_size))
         seq_out = opt.hidden_size
     elif opt.SequenceModeling == 'Bert':
         cfg = Config()
         cfg.dim = opt.output_channel
         cfg.dim_c = opt.output_channel  # reduced width (to cut computation); equal to dim here
         cfg.p_dim = opt.position_dim  # feature-sequence length of one image after CNN encoding
         cfg.max_vocab_size = opt.batch_max_length + 1  # max characters per image, +1 for EOS
         cfg.len_alphabet = opt.alphabet_size  # number of character classes
         self.SequenceModeling = Bert_Ocr(cfg)
     elif opt.SequenceModeling == 'SRN':
         self.SequenceModeling = Transforme_Encoder()
     else:
         print('No SequenceModeling module specified')
         seq_out = self.FeatureExtraction_output
     if seq_out is not None:
         self.SequenceModeling_output = seq_out
     """ Prediction """
     # Fail with a clear message instead of an AttributeError on the
     # missing SequenceModeling_output.
     if opt.Prediction in ('CTC', 'Attn') and seq_out is None:
         raise Exception(
             'Prediction %s needs a known SequenceModeling output size, '
             'which SequenceModeling %s does not provide'
             % (opt.Prediction, opt.SequenceModeling))
     if opt.Prediction == 'CTC':
         self.Prediction = nn.Linear(seq_out, opt.num_class)
     elif opt.Prediction == 'Attn':
         self.Prediction = Attention(seq_out, opt.hidden_size, opt.num_class)
     elif opt.Prediction == 'Bert_pred':
         pass  # no separate Prediction module is built for 'Bert_pred'
     elif opt.Prediction == 'SRN':
         self.Prediction = SRN_Decoder(n_position=opt.n_position)
     else:
         raise Exception('Prediction must be one of CTC, Attn, Bert_pred or SRN')
    def __init__(self, opt):
        """Assemble the Trans -> Feat -> Seq -> Pred recognizer from ``opt``.

        Supported choices:
            Transformation: 'TPS' (anything else: no transformation).
            FeatureExtraction: 'VGG' | 'RCNN' | 'ResNet' | 'SEResNet' | any
                name containing 'SEResNetXt'. The last one also writes
                ``opt.output_channel = 2048`` back onto the caller's options.
            SequenceModeling: 'BiLSTM' | 'BiGRU' (anything else: none).
            Prediction: 'CTC' (opt.num_class + 1 outputs), 'Attn'
                (opt.num_class + 2 outputs), or 'CTC_Attn' (both heads, stored
                as Prediction_ctc / Prediction_attn).

        Raises:
            Exception: unknown FeatureExtraction or Prediction.
        """
        super(Model, self).__init__()
        self.opt = opt
        self.stages = {
            'Trans': opt.Transformation,
            'Feat': opt.FeatureExtraction,
            'Seq': opt.SequenceModeling,
            'Pred': opt.Prediction
        }

        # Transformation (optional TPS rectification).
        if opt.Transformation != 'TPS':
            print('No Transformation module specified')
        else:
            img_size = (opt.imgH, opt.imgW)
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt.num_fiducial,
                I_size=img_size,
                I_r_size=img_size,
                I_channel_num=opt.input_channel)

        # Feature extraction backbone.
        name = opt.FeatureExtraction
        if name == 'VGG':
            self.FeatureExtraction = VGG_FeatureExtractor(
                opt.input_channel, opt.output_channel)
        elif name == 'RCNN':
            self.FeatureExtraction = RCNN_FeatureExtractor(
                opt.input_channel, opt.output_channel)
        elif name in ('ResNet', 'SEResNet'):
            self.FeatureExtraction = ResNet_FeatureExtractor(
                opt.input_channel, opt.output_channel, opt)
        elif 'SEResNetXt' in name:
            # SE-ResNeXt variants use 2048 output channels; recorded on opt too.
            opt.output_channel = 2048
            self.FeatureExtraction = ResNet_FeatureExtractor(
                opt.input_channel, opt.output_channel, opt)
        else:
            raise Exception('No FeatureExtraction module specified')
        self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
        # Pools the last axis to size 1 (original note: (imgH/16-1) -> 1).
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

        # Optional dropout; None when opt.dropout <= 0.
        self.dropout = nn.Dropout(opt.dropout) if opt.dropout > 0 else None

        # Sequence modeling: two stacked bidirectional RNNs, or pass-through.
        if opt.SequenceModeling == 'BiLSTM':
            rnn_cls = BidirectionalLSTM
        elif opt.SequenceModeling == 'BiGRU':
            rnn_cls = BidirectionalGRU
        else:
            rnn_cls = None
        if rnn_cls is None:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = self.FeatureExtraction_output
        else:
            hidden = opt.hidden_size
            self.SequenceModeling = nn.Sequential(
                rnn_cls(self.FeatureExtraction_output, hidden, hidden,
                        opt.rnnlayers, opt.rnndropout),
                rnn_cls(hidden, hidden, hidden, opt.rnnlayers,
                        opt.rnndropout))
            self.SequenceModeling_output = hidden

        # Prediction head(s).
        seq_dim = self.SequenceModeling_output
        head = opt.Prediction
        if head == 'CTC':
            self.Prediction = nn.Linear(seq_dim, opt.num_class + 1)
        elif head == 'Attn':
            self.Prediction = Attention(opt, seq_dim, opt.hidden_size,
                                        opt.num_class + 2)
        elif head == 'CTC_Attn':
            self.Prediction_ctc = nn.Linear(seq_dim, opt.num_class + 1)
            self.Prediction_attn = Attention(opt, seq_dim, opt.hidden_size,
                                             opt.num_class + 2)
        else:
            raise Exception('Prediction is neither CTC or Attn')
 def __init__(self,
              imgH,
              imgW,
              input_channel,
              output_channel,
              hidden_size,
              num_class,
              batch_max_length,
              Transformation="None",
              FeatureExtraction="VGG",
              SequenceModeling="BiLSTM",
              Prediction="CTC",
              F=20):
     """Assemble the four recognition stages: Trans -> Feat -> Seq -> Pred.

     Args:
         imgH, imgW: input image size, passed to TPS as both I_size and
             I_r_size.
         input_channel: number of channels of the input image.
         output_channel: channel width of the feature extractor output; also
             used as the sequence-model input width.
         hidden_size: BiLSTM hidden width and Attention hidden width.
         num_class: output size of the prediction head.
         batch_max_length: accepted for interface compatibility; not used in
             this method.
         Transformation: 'TPS' to add a TPS rectifier; anything else skips it.
         FeatureExtraction: 'VGG', 'RCNN' or 'ResNet'.
         SequenceModeling: 'BiLSTM' for two stacked BiLSTMs; anything else
             skips the stage.
         Prediction: 'CTC' (linear head) or 'Attn' (attention decoder).
         F: number of TPS fiducial points.

     Raises:
         Exception: if FeatureExtraction or Prediction is not recognised.
     """
     super(Model, self).__init__()
     self.stages = dict(Trans=Transformation,
                        Feat=FeatureExtraction,
                        Seq=SequenceModeling,
                        Pred=Prediction)

     # --- Transformation: optional TPS rectification ---
     if Transformation != 'TPS':
         print('No Transformation module specified')
     else:
         self.Transformation = TPS_SpatialTransformerNetwork(
             F=F,
             I_size=(imgH, imgW),
             I_r_size=(imgH, imgW),
             I_channel_num=input_channel)

     # --- FeatureExtraction: pick the backbone class, then build it ---
     if FeatureExtraction == 'VGG':
         extractor_cls = VGG_FeatureExtractor
     elif FeatureExtraction == 'RCNN':
         extractor_cls = RCNN_FeatureExtractor
     elif FeatureExtraction == 'ResNet':
         extractor_cls = ResNet_FeatureExtractor
     else:
         raise Exception('No FeatureExtraction module specified')
     self.FeatureExtraction = extractor_cls(input_channel, output_channel)
     self.FeatureExtraction_output = output_channel  # int(imgH/16-1) * 512
     # Transform final (imgH/16-1) -> 1; `None` keeps the other spatial size.
     self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

     # --- Sequence modeling ---
     if SequenceModeling != 'BiLSTM':
         print('No SequenceModeling module specified')
         self.SequenceModeling_output = self.FeatureExtraction_output
     else:
         self.SequenceModeling = nn.Sequential(
             BidirectionalLSTM(self.FeatureExtraction_output, hidden_size,
                               hidden_size),
             BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
         self.SequenceModeling_output = hidden_size

     # --- Prediction head, sized from whatever the previous stage emits ---
     seq_width = self.SequenceModeling_output
     if Prediction == 'CTC':
         self.Prediction = nn.Linear(seq_width, num_class)
     elif Prediction == 'Attn':
         self.Prediction = Attention(seq_width, hidden_size, num_class)
     else:
         raise Exception('Prediction is neither CTC or Attn')
Beispiel #16
0
 def __init__(self, opt):
     """Build a recognizer with separate CTC and attention sequence branches.

     Both branches are always constructed; ``opt.SequenceModeling`` and
     ``opt.Prediction`` are only recorded in ``self.stages`` and do not select
     modules in this method.

     Args:
         opt: options namespace; reads Transformation, FeatureExtraction,
             SequenceModeling, Prediction, num_fiducial, imgH, imgW,
             input_channel, output_channel, output_channel_GCN, hidden_size,
             num_class_ctc and num_class_attn.

     Raises:
         Exception: if opt.FeatureExtraction is not 'VGG', 'RCNN' or 'ResNet'.
     """
     super(Model, self).__init__()
     self.opt = opt
     self.stages = {
         'Trans': opt.Transformation,
         'Feat': opt.FeatureExtraction,
         'Seq': opt.SequenceModeling,
         'Pred': opt.Prediction
     }

     # --- Transformation ---
     if opt.Transformation == 'TPS':
         self.Transformation = TPS_SpatialTransformerNetwork(
             F=opt.num_fiducial,
             I_size=(opt.imgH, opt.imgW),
             I_r_size=(opt.imgH, opt.imgW),
             I_channel_num=opt.input_channel)
     else:
         print('No Transformation module specified')

     # --- FeatureExtraction ---
     # Per-backbone sequence length stored as self.len_sequence; the values
     # are configuration constants and are not derived in this method.
     backbone_len_sequence = {'VGG': 24, 'RCNN': 26, 'ResNet': 26}
     if opt.FeatureExtraction == 'VGG':
         self.FeatureExtraction = VGG_FeatureExtractor(
             opt.input_channel, opt.output_channel)
     elif opt.FeatureExtraction == 'RCNN':
         self.FeatureExtraction = RCNN_FeatureExtractor(
             opt.input_channel, opt.output_channel)
     elif opt.FeatureExtraction == 'ResNet':
         self.FeatureExtraction = ResNet_FeatureExtractor(
             opt.input_channel, opt.output_channel)
     else:
         raise Exception('No FeatureExtraction module specified')
     # Looked up only after the backbone name has been validated, so an
     # unknown backbone reaches the descriptive error above instead of a
     # bare KeyError.
     self.len_sequence = backbone_len_sequence[opt.FeatureExtraction]

     self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
     # Not used by the modules built below; kept as a public attribute.
     self.GraphConvolution_output = opt.output_channel_GCN
     self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
         (None, 1))  # Transform final (imgH/16-1) -> 1

     # --- Sequence modeling + prediction ---
     # CTC branch: a single weight-dropped BiLSTM that is also given the CTC
     # class count (numclass=opt.num_class_ctc). Presumably it projects to
     # CTC classes internally -- confirm in WeightDropBiLSTM.
     self.SequenceModeling_CTC = nn.Sequential(
         WeightDropBiLSTM(self.FeatureExtraction_output,
                          opt.hidden_size,
                          opt.hidden_size,
                          numclass=opt.num_class_ctc,
                          dropouti=0.05,
                          wdrop=0.2,
                          dropouto=0.05),
     )
     # Attention branch: two stacked BiLSTMs feeding the attention decoder.
     self.SequenceModeling_Attn = nn.Sequential(
         BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size,
                           opt.hidden_size),
         BidirectionalLSTM(opt.hidden_size, opt.hidden_size,
                           opt.hidden_size),
     )
     self.Attention = Attention(opt.hidden_size, opt.hidden_size,
                                opt.num_class_attn)
Beispiel #17
0
 def __init__(self, opt):
     """Build the recognizer, optionally with a LanguageTransformer stage.

     Args:
         opt: options namespace; reads Transformation, FeatureExtraction,
             SequenceModeling, Prediction, num_fiducial, imgH, imgW,
             input_channel, output_channel, hidden_size, num_class and, for
             the Transformer stage, d_model, nhead, num_encoder_layers,
             num_decoder_layers, dim_feedforward, max_seq_length,
             pos_dropout and trans_dropout.

     Raises:
         Exception: if opt.FeatureExtraction or opt.Prediction is unknown.
         ValueError: if SequenceModeling 'Transformer' is combined with a
             'CTC' or 'Attn' prediction head.
     """
     super(Model, self).__init__()
     self.opt = opt
     self.stages = {
         'Trans': opt.Transformation,
         'Feat': opt.FeatureExtraction,
         'Seq': opt.SequenceModeling,
         'Pred': opt.Prediction
     }

     # --- Transformation ---
     if opt.Transformation == 'TPS':
         self.Transformation = TPS_SpatialTransformerNetwork(
             F=opt.num_fiducial,
             I_size=(opt.imgH, opt.imgW),
             I_r_size=(opt.imgH, opt.imgW),
             I_channel_num=opt.input_channel)
     else:
         print('No Transformation module specified')

     # --- FeatureExtraction ---
     if opt.FeatureExtraction == 'VGG':
         self.FeatureExtraction = VGG_FeatureExtractor(
             opt.input_channel, opt.output_channel)
     elif opt.FeatureExtraction == 'RCNN':
         self.FeatureExtraction = RCNN_FeatureExtractor(
             opt.input_channel, opt.output_channel)
     elif opt.FeatureExtraction == 'ResNet':
         self.FeatureExtraction = ResNet_FeatureExtractor(
             opt.input_channel, opt.output_channel)
     else:
         raise Exception('No FeatureExtraction module specified')
     self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
     self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
         (None, 1))  # Transform final (imgH/16-1) -> 1

     # --- Sequence modeling ---
     if opt.SequenceModeling == 'BiLSTM':
         self.SequenceModeling = nn.Sequential(
             BidirectionalLSTM(self.FeatureExtraction_output,
                               opt.hidden_size, opt.hidden_size),
             BidirectionalLSTM(opt.hidden_size, opt.hidden_size,
                               opt.hidden_size))
         self.SequenceModeling_output = opt.hidden_size
     elif opt.SequenceModeling == 'Transformer':
         # This branch defines no SequenceModeling_output, so a CTC/Attn head
         # cannot be sized from it. Previously this failed later with an
         # opaque AttributeError; fail early with an explicit message instead.
         if opt.Prediction in ('CTC', 'Attn'):
             raise ValueError(
                 "SequenceModeling 'Transformer' cannot be combined with "
                 "Prediction '{}'; use Prediction 'None'".format(
                     opt.Prediction))
         self.SequenceModeling = LanguageTransformer(
             vocab_size=opt.num_class,
             input_size=self.FeatureExtraction_output,
             d_model=opt.d_model,
             nhead=opt.nhead,
             num_encoder_layers=opt.num_encoder_layers,
             num_decoder_layers=opt.num_decoder_layers,
             dim_feedforward=opt.dim_feedforward,
             max_seq_length=opt.max_seq_length,
             pos_dropout=opt.pos_dropout,
             trans_dropout=opt.trans_dropout,
         )
     else:
         print('No SequenceModeling module specified')
         self.SequenceModeling_output = self.FeatureExtraction_output

     # --- Prediction ---
     if opt.Prediction == 'CTC':
         self.Prediction = nn.Linear(self.SequenceModeling_output,
                                     opt.num_class)
     elif opt.Prediction == 'Attn':
         self.Prediction = Attention(self.SequenceModeling_output,
                                     opt.hidden_size, opt.num_class)
     elif opt.Prediction == 'None':
         self.Prediction = noop
     else:
         raise Exception('Prediction should be in [CTC | Attn | None]')
Beispiel #18
0
    def __init__(self, cfg):
        """Build the recognizer from a nested configuration mapping.

        Args:
            cfg: mapping with sections 'model' (transform, extraction,
                sequence, prediction, input_channel, output_channel,
                hidden_size), 'dataset' (imgH, imgW), 'transform'
                (num_fiducial) and 'training' (num_class).

        Raises:
            Exception: if the extraction or prediction stage is unknown.
        """
        super(Model, self).__init__()
        self.cfg = cfg
        model_cfg = cfg['model']
        self.stages = {'Trans': model_cfg['transform'],
                       'Feat': model_cfg['extraction'],
                       'Seq': model_cfg['sequence'],
                       'Pred': model_cfg['prediction']}

        # --- Transformation ---
        transform = model_cfg['transform']
        if transform != 'TPS':
            print('No Transformation module specified')
        else:
            num_fiducial = cfg['transform']['num_fiducial']
            img_hw = (cfg['dataset']['imgH'], cfg['dataset']['imgW'])
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=num_fiducial,
                I_size=img_hw,
                I_r_size=img_hw,
                I_channel_num=model_cfg['input_channel'])
            print("Transformation moduls : {}".format(transform))

        # --- FeatureExtraction ---
        extraction = model_cfg['extraction']
        if extraction == 'VGG':
            extractor_cls = VGG_FeatureExtractor
        elif extraction == 'RCNN':
            extractor_cls = RCNN_FeatureExtractor
        elif extraction == 'ResNet':
            extractor_cls = ResNet_FeatureExtractor
        else:
            raise Exception('No FeatureExtraction module specified')
        self.FeatureExtraction = extractor_cls(model_cfg['input_channel'],
                                               model_cfg['output_channel'])
        self.FeatureExtraction_output = model_cfg['output_channel']  # int(imgH/16-1) * 512
        # Transform final (imgH/16-1) -> 1; `None` keeps the other spatial size.
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))
        print('Feature extractor : {}'.format(extraction))

        # --- Sequence modeling ---
        sequence = model_cfg['sequence']
        feat_width = self.FeatureExtraction_output
        if sequence == 'BiLSTM':
            hidden = model_cfg['hidden_size']
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(feat_width, hidden, hidden),
                BidirectionalLSTM(hidden, hidden, hidden))
            self.SequenceModeling_output = hidden
        elif sequence == 'Transformer':
            self.SequenceModeling = Transformer(
                d_model=feat_width,
                nhead=2,
                num_encoder_layers=2,
                num_decoder_layers=2,
                dim_feedforward=model_cfg['hidden_size'],
                dropout=0.1,
                activation='relu')
            print('SequenceModeling: Transformer initialized.')
            # The Transformer's output has the same width as its input.
            self.SequenceModeling_output = feat_width
        else:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = feat_width
        print('Sequence modeling : {}'.format(sequence))

        # --- Prediction ---
        prediction = model_cfg['prediction']
        seq_width = self.SequenceModeling_output
        if prediction in ('CTC', 'Transformer'):
            # Both heads are the same plain linear projection onto the classes.
            self.Prediction = nn.Linear(seq_width, cfg['training']['num_class'])
        elif prediction == 'Attn':
            self.Prediction = Attention(seq_width, model_cfg['hidden_size'],
                                        cfg['training']['num_class'])
        else:
            raise Exception('Prediction should be in [CTC | Attn | Transformer]')
        print("Prediction : {}".format(prediction))
Beispiel #19
0
    def __init__(self, opt, SelfSL_layer=False):
        """Build the STR model, or a feature extractor plus a SelfSL head.

        Args:
            opt: options namespace; reads Transformation, FeatureExtraction,
                SequenceModeling, Prediction, num_fiducial, imgH, imgW,
                input_channel, output_channel, hidden_size, num_class and,
                in SelfSL mode, ``self`` (checked for 'RotNet' / 'MoCo').
            SelfSL_layer: falsy for scene-text recognition (builds sequence
                modeling and prediction). Otherwise, it names the layer whose
                output feeds the SelfSL head. Only "CNN" is implemented.

        Raises:
            Exception: if FeatureExtraction is unknown, or Prediction is
                unknown in STR mode.
            ValueError: if a RotNet/MoCo head is requested with a
                SelfSL_layer other than "CNN".
        """
        super(Model, self).__init__()
        self.opt = opt
        self.stages = {
            "Trans": opt.Transformation,
            "Feat": opt.FeatureExtraction,
            "Seq": opt.SequenceModeling,
            "Pred": opt.Prediction,
        }

        # --- Transformation ---
        if opt.Transformation == "TPS":
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt.num_fiducial,
                I_size=(opt.imgH, opt.imgW),
                I_r_size=(opt.imgH, opt.imgW),
                I_channel_num=opt.input_channel,
            )
        else:
            print("No Transformation module specified")

        # --- FeatureExtraction ---
        if opt.FeatureExtraction == "VGG":
            self.FeatureExtraction = VGG_FeatureExtractor(
                opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == "RCNN":
            self.FeatureExtraction = RCNN_FeatureExtractor(
                opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == "ResNet":
            self.FeatureExtraction = ResNet_FeatureExtractor(
                opt.input_channel, opt.output_channel)
        else:
            raise Exception("No FeatureExtraction module specified")
        self.FeatureExtraction_output = opt.output_channel
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
            (None, 1))  # Transform final (imgH/16-1) -> 1

        if not SelfSL_layer:  # for STR
            # --- Sequence modeling ---
            if opt.SequenceModeling == "BiLSTM":
                self.SequenceModeling = nn.Sequential(
                    BidirectionalLSTM(self.FeatureExtraction_output,
                                      opt.hidden_size, opt.hidden_size),
                    BidirectionalLSTM(opt.hidden_size, opt.hidden_size,
                                      opt.hidden_size),
                )
                self.SequenceModeling_output = opt.hidden_size
            else:
                print("No SequenceModeling module specified")
                self.SequenceModeling_output = self.FeatureExtraction_output

            # --- Prediction ---
            if opt.Prediction == "CTC":
                self.Prediction = nn.Linear(self.SequenceModeling_output,
                                            opt.num_class)
            elif opt.Prediction == "Attn":
                self.Prediction = Attention(self.SequenceModeling_output,
                                            opt.hidden_size, opt.num_class)
            else:
                raise Exception("Prediction is neither CTC or Attn")
        else:
            # --- for self-supervised learning (SelfSL) ---
            self.AdaptiveAvgPool_2 = nn.AdaptiveAvgPool2d(
                (None, 1))  # make width -> 1
            if SelfSL_layer == "CNN":
                self.SelfSL_FFN_input = self.FeatureExtraction_output

            if "RotNet" in self.opt.self:
                head_dim = 4  # 4 = [0, 90, 180, 270] degrees
            elif "MoCo" in self.opt.self:
                head_dim = 128  # 128 is used for MoCo paper.
            else:
                head_dim = None  # no SelfSL head requested
            if head_dim is not None:
                # Only "CNN" defines SelfSL_FFN_input; any other layer used to
                # fail here with an opaque AttributeError.
                if SelfSL_layer != "CNN":
                    raise ValueError(
                        "SelfSL_layer {!r} is not supported; only 'CNN' "
                        "is implemented".format(SelfSL_layer))
                self.SelfSL = nn.Linear(self.SelfSL_FFN_input, head_dim)
    def __init__(self, opt):
        """Build the four-stage recognizer, with an optional Seq2Seq head.

        Args:
            opt: options namespace; reads Transformation, FeatureExtraction,
                SequenceModeling, Prediction, num_fiducial, imgH, imgW,
                input_channel, output_channel, hidden_size and num_class.

        Raises:
            Exception: if opt.FeatureExtraction or opt.Prediction names an
                unknown stage.
        """
        super(Model, self).__init__()
        self.opt = opt
        self.stages = {
            'Trans': opt.Transformation,
            'Feat': opt.FeatureExtraction,
            'Seq': opt.SequenceModeling,
            'Pred': opt.Prediction
        }

        # --- Transformation ---
        if opt.Transformation == 'TPS':
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt.num_fiducial,
                I_size=(opt.imgH, opt.imgW),
                I_r_size=(opt.imgH, opt.imgW),
                I_channel_num=opt.input_channel)
        else:
            print('No Transformation module specified')

        # --- FeatureExtraction ---
        if opt.FeatureExtraction == 'VGG':
            self.FeatureExtraction = VGG_FeatureExtractor(
                opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'RCNN':
            self.FeatureExtraction = RCNN_FeatureExtractor(
                opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'ResNet':
            self.FeatureExtraction = ResNet_FeatureExtractor(
                opt.input_channel, opt.output_channel)
        else:
            raise Exception('No FeatureExtraction module specified')
        self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
            (None, 1))  # Transform final (imgH/16-1) -> 1

        # --- Sequence modeling ---
        if opt.SequenceModeling == 'BiLSTM':
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output,
                                  opt.hidden_size, opt.hidden_size),
                BidirectionalLSTM(opt.hidden_size, opt.hidden_size,
                                  opt.hidden_size))
            self.SequenceModeling_output = opt.hidden_size
        else:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = self.FeatureExtraction_output

        # --- Prediction ---
        if opt.Prediction == 'CTC':
            self.Prediction = nn.Linear(self.SequenceModeling_output,
                                        opt.num_class)
        elif opt.Prediction == 'Attn':
            self.Prediction = Attention(self.SequenceModeling_output,
                                        opt.hidden_size, opt.num_class)
        elif opt.Prediction == 'Transformer':
            # Encoder/decoder hyper-parameters are fixed here rather than read
            # from opt; the model width follows the sequence-stage output.
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
            OUTPUT_DIM = opt.num_class
            HID_DIM = self.SequenceModeling_output
            ENC_LAYERS = 3
            DEC_LAYERS = 3
            ENC_HEADS = 8
            DEC_HEADS = 8
            ENC_PF_DIM = 512
            DEC_PF_DIM = 512
            ENC_DROPOUT = 0.1
            DEC_DROPOUT = 0.1

            enc = Encoder(HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM,
                          ENC_DROPOUT, device)

            dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS,
                          DEC_PF_DIM, DEC_DROPOUT, device)

            # Hard-coded target padding index (was TRG.vocab.stoi[TRG.pad_token]).
            # NOTE(review): must match the label converter's pad index -- confirm.
            TRG_PAD_IDX = 2

            # NOTE(review): only this head is moved to `device` here; the other
            # stages are created without an explicit device.
            self.Prediction = Seq2Seq(enc, dec, TRG_PAD_IDX, device).to(device)
            self.Prediction.apply(self.initialize_weights)
            print("use transformer")
        else:
            raise Exception('Prediction should be in [CTC | Attn | Transformer]')