def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {'Feat': opt.FeatureExtraction, 'Seq': opt.SequenceModeling,
                   'Pred': "CTC-Attn"}

    """ FeatureExtraction """
    if opt.FeatureExtraction == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'CRNN':
        self.FeatureExtraction = CRNN()
        opt.output_channel = 512
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    if opt.SequenceModeling == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        self.SequenceModeling_output = opt.hidden_size
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    self.CTC_Prediction = nn.Linear(self.SequenceModeling_output, opt.ctc_num_class)
    self.Attn_Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
def __init__(self, opt):
    super(Model_s, self).__init__()
    self.opt = opt
    self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                   'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

    """ Transformation """
    if opt.Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW),
            I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction + Sequence modeling """
    # model scatter: both stages are handled inside the SCATTER block
    self.Feat_Extraction = SCATTER(opt=opt, input_channel=opt.input_channel,
                                   lstm_layer=opt.LSTM_Layer,
                                   selective_layer=opt.Selective_Layer)

    """ Prediction """
    self.Prediction_atten = nn.ModuleList([
        Attention(1024, opt.hidden_size, opt.num_class_atten)
        for _ in range(opt.Selective_Layer)
    ])
def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                   'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

    """ Transformation """
    if opt.Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW),
            I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    if opt.FeatureExtraction == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'DenseNet':
        self.FeatureExtraction = DenseNet(opt.input_channel, opt.output_channel)
    else:
        raise Exception('No FeatureExtraction module specified')
    if opt.FeatureExtraction == 'DenseNet':
        self.FeatureExtraction_output = 768  # DenseNet output channel is 768
    else:
        self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
    # self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    if opt.SequenceModeling == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        self.SequenceModeling_output = opt.hidden_size
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    if opt.Prediction == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
    elif opt.Prediction == 'Attn':
        self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
    else:
        raise Exception('Prediction is neither CTC nor Attn')
def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                   'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

    """ 1. Spatial transformation stage: Transformation """
    if opt.Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW),
            I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
    else:
        print('No Transformation module specified')

    """ 2. Feature extraction stage: FeatureExtraction """
    if opt.FeatureExtraction == 'VGG':
        # opt.input_channel: depth of the input image, opt.output_channel: 512
        self.FeatureExtraction = VGG_FeatureExtractor(
            opt.input_channel, opt.output_channel)  # out: batch, 512, 1, 24
    elif opt.FeatureExtraction == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
    else:
        raise Exception('No FeatureExtraction module specified')
    # int(imgH/16-1) * 512; opt.output_channel defaults to 512
    # (512 is the feature depth, 24 is the sequence length)
    self.FeatureExtraction_output = opt.output_channel
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ 3. Sequence modeling stage: Sequence modeling """
    if opt.SequenceModeling == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            # input feature size: 512, hidden: 256, output: 256
            BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        self.SequenceModeling_output = opt.hidden_size
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ 4. Prediction stage: Prediction """
    if opt.Prediction == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
    elif opt.Prediction == 'Attn':
        self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
    else:
        raise Exception('Prediction is neither CTC nor Attn')
def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {'Transformation': opt.Transformation,
                   'FeatureExtraction': opt.FeatureExtraction,
                   'SequenceModeling': opt.SequenceModeling,
                   'Prediction': opt.Prediction}

    """ Transformation """
    if opt.Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            opt.ft_config['trans'], F=opt.num_fiducial,
            I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW),
            I_channel_num=opt.input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    if opt.FeatureExtraction == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(
            opt.ft_config['feat'], opt.input_channel, opt.output_channel)
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    if opt.SequenceModeling == 'BiLSTM':
        self.SequenceModeling = BiLSTM(opt.ft_config['seq'],
                                       self.FeatureExtraction_output, opt.hidden_size)
        self.SequenceModeling_output = opt.hidden_size
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    if opt.Prediction == 'CTC':
        self.Prediction = CTC_Prediction(opt.ft_config['pred'],
                                         self.SequenceModeling_output, opt.num_class)
    elif opt.Prediction == 'Attn':
        self.Prediction = Attention(opt.ft_config['pred'], self.SequenceModeling_output,
                                    opt.hidden_size, opt.num_class)
    else:
        raise Exception('Prediction is neither CTC nor Attn')

    self.set_parameter_requires_grad()
    self.optimizers = self.configure_optimizers()
    self.make_statistics_params()
def __init__(self, FeatureExtraction='ResNet', PAD=False, Prediction='CTC',
             SequenceModeling='BiLSTM', Transformation='TPS', batch_max_length=25,
             batch_size=192, character='0123456789abcdefghijklmnopqrstuvwxyz',
             hidden_size=256,
             image_folder='data\\OCR-data-labelled\\OCR-data\\validate\\val',
             imgH=32, imgW=100, input_channel=1, num_class=37, num_fiducial=20,
             num_gpu=1, output_channel=512, rgb=False,
             saved_model='saved_models/TPS-ResNet-BiLSTM-CTC-Seed1111/best_accuracy.pth',
             sensitive=False, workers=4):
    super(Model, self).__init__()
    self.stages = {'Trans': Transformation, 'Feat': FeatureExtraction,
                   'Seq': SequenceModeling, 'Pred': Prediction}

    """ Transformation """
    if Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=num_fiducial, I_size=(imgH, imgW), I_r_size=(imgH, imgW),
            I_channel_num=input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    if FeatureExtraction == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(input_channel, output_channel)
    elif FeatureExtraction == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(input_channel, output_channel)
    elif FeatureExtraction == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel)
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = output_channel  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    if SequenceModeling == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
        self.SequenceModeling_output = hidden_size
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    if Prediction == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)
    elif Prediction == 'Attn':
        self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class)
    else:
        raise Exception('Prediction is neither CTC nor Attn')
def __init__(self, cfg):
    super(OCR_MODEL, self).__init__()
    self.cfg = cfg
    self.stages = {'Trans': self.cfg.OCR.TRANSFORMATION,
                   'Feat': self.cfg.OCR.FEATURE_EXTRACTION,
                   'Seq': self.cfg.OCR.SEQUENCE_MODELING,
                   'Pred': self.cfg.OCR.PREDICTION}

    """ Transformation """
    if self.cfg.OCR.TRANSFORMATION == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=self.cfg.OCR.NUM_FIDUCIAL,
            I_size=(cfg.BASE.IMG_H, cfg.BASE.IMG_W),
            I_r_size=(cfg.BASE.IMG_H, cfg.BASE.IMG_W),
            I_channel_num=self.cfg.OCR.INPUT_CHANNEL)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    if self.cfg.OCR.FEATURE_EXTRACTION == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(
            self.cfg.OCR.INPUT_CHANNEL, self.cfg.OCR.OUTPUT_CHANNEL)
    elif self.cfg.OCR.FEATURE_EXTRACTION == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(
            self.cfg.OCR.INPUT_CHANNEL, self.cfg.OCR.OUTPUT_CHANNEL)
    elif self.cfg.OCR.FEATURE_EXTRACTION == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(
            self.cfg.OCR.INPUT_CHANNEL, self.cfg.OCR.OUTPUT_CHANNEL)
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = self.cfg.OCR.OUTPUT_CHANNEL  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    if self.cfg.OCR.SEQUENCE_MODELING == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output,
                              self.cfg.OCR.HIDDEN_SIZE, self.cfg.OCR.HIDDEN_SIZE),
            BidirectionalLSTM(self.cfg.OCR.HIDDEN_SIZE,
                              self.cfg.OCR.HIDDEN_SIZE, self.cfg.OCR.HIDDEN_SIZE))
        self.SequenceModeling_output = self.cfg.OCR.HIDDEN_SIZE
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    if self.cfg.OCR.PREDICTION == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, self.cfg.OCR.NUM_CLASS)
    elif self.cfg.OCR.PREDICTION == 'Attn':
        self.Prediction = Attention(self.SequenceModeling_output,
                                    self.cfg.OCR.HIDDEN_SIZE, self.cfg.OCR.NUM_CLASS)
    else:
        raise Exception('Prediction is neither CTC nor Attn')
def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.transform = TPS_SpatialTransformerNetwork(opt["fiducial_num"], opt["img_size"],
                                                   opt["img_r_size"], opt["img_channel"])
    self.features_extractor = ResNet_FeatureExtractor(opt["img_channel"])
    self.GAP = nn.AdaptiveAvgPool2d((None, 1))
    self.SequenceModeling = nn.Sequential(
        BidirectionalLSTM(512, opt["h_dim"], opt["h_dim"]),
        BidirectionalLSTM(opt["h_dim"], opt["h_dim"], opt["h_dim"]))
    self.Prediction = Attention(opt["h_dim"], opt["h_dim"], opt["out_dim"])
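# --- Usage sketch (added for illustration; not part of the original source) ---
# The key names below are the ones the constructor above reads from its `opt`
# dictionary; every value is an assumption chosen only to make the example concrete.
example_opt = {
    "fiducial_num": 20,       # number of TPS fiducial points (assumed)
    "img_size": (32, 100),    # (H, W) of the input image (assumed)
    "img_r_size": (32, 100),  # (H, W) of the rectified image (assumed)
    "img_channel": 1,         # grayscale input (assumed)
    "h_dim": 256,             # BiLSTM hidden size (assumed)
    "out_dim": 38,            # number of output classes (assumed)
}
model = Model(example_opt)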
def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                   'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

    """ Transformation """
    if opt.Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW),
            I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    if opt.FeatureExtraction == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'ResNet_PyTorch':
        self.FeatureExtraction = models.resnet34(pretrained=True)
        # Freeze all the parameters of the feature extractor
        if opt.freeze_FeatureExtraction:
            print("Freezing the parameters of the resnet !")
            for param in self.FeatureExtraction.parameters():
                param.requires_grad = False
        # delete the avgpool and the fc layer of the resnet
        self.FeatureExtraction = nn.Sequential(*(list(self.FeatureExtraction.children())[:-2]))
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    if opt.SequenceModeling == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        self.SequenceModeling_output = opt.hidden_size
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    if opt.Prediction == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
    elif opt.Prediction == 'Attn':
        self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
    else:
        raise Exception('Prediction is neither CTC nor Attn')
def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {'Trans': opt['Transformation'], 'Feat': opt['FeatureExtraction'],
                   'Seq': opt['SequenceModeling'], 'Pred': opt['Prediction']}

    """ Transformation """
    if opt['Transformation'] == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt['num_fiducial'], I_size=(opt['imgH'], opt['imgW']),
            I_r_size=(opt['imgH'], opt['imgW']), I_channel_num=opt['input_channel'])
        # F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW),
        # batch_size=int(opt.batch_size/opt.num_gpu), I_channel_num=opt.input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    if opt['FeatureExtraction'] == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(opt['input_channel'], opt['output_channel'])
    elif opt['FeatureExtraction'] == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(opt['input_channel'], opt['output_channel'])
    elif opt['FeatureExtraction'] == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(opt['input_channel'], opt['output_channel'])
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = opt['output_channel']  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    if opt['SequenceModeling'] == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, opt['hidden_size'], opt['hidden_size']),
            BidirectionalLSTM(opt['hidden_size'], opt['hidden_size'], opt['hidden_size']))
        self.SequenceModeling_output = opt['hidden_size']
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    if opt['Prediction'] == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, opt['num_class'])
    elif opt['Prediction'] == 'Attn':
        self.Prediction = Attention(self.SequenceModeling_output, opt['hidden_size'], opt['num_class'])
    else:
        raise Exception('Prediction is neither CTC nor Attn')
def __init__(self, opt, num_class):
    super(Model, self).__init__()
    self.opt = opt

    """ Transformation: TPS """
    self.Transformation = TPS_SpatialTransformerNetwork(
        F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW),
        I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)

    """ FeatureExtraction: ResNet """
    self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
    self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling: BiLSTM """
    self.SequenceModeling = nn.Sequential(
        BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
        BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
    self.SequenceModeling_output = opt.hidden_size

    """ Prediction: Attn """
    self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size,
                                opt.num_class, num_class)
def __init__(self, imgW, imgH, num_class, batch_max_length):
    super(Model, self).__init__()
    num_fiducial = 20
    input_channel = 1
    output_channel = 512
    hidden_size = 256
    self.batch_max_length = batch_max_length
    self.stages = {'Trans': "TPS", 'Feat': "ResNet", 'Seq': "BiLSTM", 'Pred': "Attn"}

    """ Transformation """
    self.Transformation = TPS_SpatialTransformerNetwork(
        F=num_fiducial, I_size=(imgH, imgW), I_r_size=(imgH, imgW),
        I_channel_num=input_channel)

    """ FeatureExtraction """
    self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel)
    self.FeatureExtraction_output = output_channel  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    self.SequenceModeling = nn.Sequential(
        BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
        BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
    self.SequenceModeling_output = hidden_size

    """ Prediction """
    self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class)
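# --- Usage sketch (added for illustration; the values below are assumptions, not from the source) ---
# The variant above hard-codes TPS-ResNet-BiLSTM-Attn, so only the image geometry,
# the class count, and the maximum label length need to be supplied.
model = Model(imgW=100, imgH=32, num_class=38, batch_max_length=25)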
def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                   'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

    """ Transformation """
    if opt.Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW),
            I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    if opt.FeatureExtraction == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1
    elif opt.FeatureExtraction == 'AsterRes':
        self.FeatureExtraction = ResNet_ASTER2(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'ResnetFpn':
        self.FeatureExtraction = ResNet_FPN()
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512

    """ Sequence modeling """
    if opt.SequenceModeling == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        self.SequenceModeling_output = opt.hidden_size
    elif opt.SequenceModeling == 'Bert':
        cfg = Config()
        cfg.dim = opt.output_channel
        cfg.dim_c = opt.output_channel  # reduce dimensionality to cut computation
        cfg.p_dim = opt.position_dim  # length of the feature sequence after CNN encoding of one image
        cfg.max_vocab_size = opt.batch_max_length + 1  # maximum number of characters in one image, +1 for EOS
        cfg.len_alphabet = opt.alphabet_size  # number of character classes
        self.SequenceModeling = Bert_Ocr(cfg)
    elif opt.SequenceModeling == 'SRN':
        self.SequenceModeling = Transforme_Encoder()
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    if opt.Prediction == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
    elif opt.Prediction == 'Attn':
        self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
    elif opt.Prediction == 'Bert_pred':
        pass
    elif opt.Prediction == 'SRN':
        self.Prediction = SRN_Decoder(n_position=opt.n_position)
    else:
        raise Exception('Prediction is neither CTC nor Attn')
def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                   'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

    """ Transformation """
    if opt.Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW),
            I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    if opt.FeatureExtraction == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'ResNet' or 'SEResNet' == opt.FeatureExtraction:
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel, opt)
    elif 'SEResNetXt' in opt.FeatureExtraction:
        opt.output_channel = 2048
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel, opt)
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    if opt.dropout > 0:
        self.dropout = nn.Dropout(opt.dropout)
    else:
        self.dropout = None

    """ Sequence modeling """
    if opt.SequenceModeling == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size,
                              opt.rnnlayers, opt.rnndropout),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size,
                              opt.rnnlayers, opt.rnndropout))
        self.SequenceModeling_output = opt.hidden_size
    elif opt.SequenceModeling == 'BiGRU':
        self.SequenceModeling = nn.Sequential(
            BidirectionalGRU(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size,
                             opt.rnnlayers, opt.rnndropout),
            BidirectionalGRU(opt.hidden_size, opt.hidden_size, opt.hidden_size,
                             opt.rnnlayers, opt.rnndropout))
        self.SequenceModeling_output = opt.hidden_size
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    if opt.Prediction == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class + 1)
    elif opt.Prediction == 'Attn':
        self.Prediction = Attention(opt, self.SequenceModeling_output, opt.hidden_size,
                                    opt.num_class + 2)
    elif opt.Prediction == 'CTC_Attn':
        self.Prediction_ctc = nn.Linear(self.SequenceModeling_output, opt.num_class + 1)
        self.Prediction_attn = Attention(opt, self.SequenceModeling_output, opt.hidden_size,
                                         opt.num_class + 2)
    else:
        raise Exception('Prediction is neither CTC nor Attn')
def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                   'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

    """ Transformation """
    if opt.Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW),
            I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    map_backbon_to_len_sequence = {'VGG': 24, 'RCNN': 26, 'ResNet': 26}
    self.len_sequence = map_backbon_to_len_sequence[opt.FeatureExtraction]
    if opt.FeatureExtraction == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
    self.GraphConvolution_output = opt.output_channel_GCN
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    if opt.SequenceModeling == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        self.SequenceModeling_output = opt.hidden_size
    elif opt.SequenceModeling == 'GCN-BiLSTM':
        self.SequenceModeling = nn.Sequential(
            GraphConvolution(opt.batch_size, self.len_sequence, self.FeatureExtraction_output,
                             self.GraphConvolution_output, bias=False, scale_factor=5),
            BidirectionalLSTM(self.GraphConvolution_output, opt.hidden_size, opt.hidden_size),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        self.SequenceModeling_output = opt.hidden_size
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    if opt.Prediction == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
    elif opt.Prediction == 'Attn':
        self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
    else:
        raise Exception('Prediction is neither CTC nor Attn')
def __init__(self, imgH, imgW, input_channel, output_channel, hidden_size, num_class,
             batch_max_length, Transformation="None", FeatureExtraction="VGG",
             SequenceModeling="BiLSTM", Prediction="CTC", F=20):
    super(Model, self).__init__()
    self.stages = {'Trans': Transformation, 'Feat': FeatureExtraction,
                   'Seq': SequenceModeling, 'Pred': Prediction}

    """ Transformation """
    if Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=F, I_size=(imgH, imgW), I_r_size=(imgH, imgW), I_channel_num=input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    if FeatureExtraction == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(input_channel, output_channel)
    elif FeatureExtraction == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(input_channel, output_channel)
    elif FeatureExtraction == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel)
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = output_channel  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    if SequenceModeling == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
        self.SequenceModeling_output = hidden_size
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    if Prediction == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)
    elif Prediction == 'Attn':
        self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class)
    else:
        raise Exception('Prediction is neither CTC nor Attn')
def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                   'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

    """ Transformation """
    if opt.Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW),
            I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    map_backbon_to_len_sequence = {'VGG': 24, 'RCNN': 26, 'ResNet': 26}
    self.len_sequence = map_backbon_to_len_sequence[opt.FeatureExtraction]
    if opt.FeatureExtraction == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
    self.GraphConvolution_output = opt.output_channel_GCN
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    # if opt.SequenceModeling == 'BiLSTM':
    #     self.SequenceModeling = nn.Sequential(
    #         BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
    #         BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
    #     self.SequenceModeling_output = opt.hidden_size
    # else:
    #     if opt.SequenceModeling == 'GCN-BiLSTM':
    #         self.SequenceModeling = nn.Sequential(
    #             GraphConvolution(opt.batch_size, self.len_sequence, self.FeatureExtraction_output,
    #                              self.GraphConvolution_output, bias=False, scale_factor=2.0, dropout=0.0),
    #             BidirectionalLSTM(self.GraphConvolution_output, opt.hidden_size, opt.hidden_size),
    #             BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
    #         self.SequenceModeling_output = opt.hidden_size
    #     else:
    #         print('No SequenceModeling module specified')
    #         self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    # if opt.guide_training:
    self.SequenceModeling_CTC = nn.Sequential(
        WeightDropBiLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size,
                         numclass=opt.num_class_ctc, dropouti=0.05, wdrop=0.2, dropouto=0.05),
        # GraphConvolution(opt.batch_size, self.len_sequence, self.FeatureExtraction_output,
        #                  self.GraphConvolution_output, bias=False, scale_factor=2.0,
        #                  dropout=0.0, isnormalize=True),
        # BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size,
        #                   dropouti=0.1, wdrop=0.2, dropouto=0.0),
        # BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size,
        #                   is_last_blstm=True, numclass=opt.num_class_ctc,
        #                   dropouti=0.1, wdrop=0.2, dropouto=0.1),
        # BidirectionalLSTM(self.GraphConvolution_output, opt.hidden_size, opt.hidden_size,
        #                   is_last_blstm=True, numclass=opt.num_class_ctc),
        # nn.Dropout(p=0.2),
        # BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size),
    )
    self.SequenceModeling_Attn = nn.Sequential(
        BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
        BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size),
    )
    # self.CTC = nn.Linear(opt.hidden_size, opt.num_class_ctc)
    self.Attention = Attention(opt.hidden_size, opt.hidden_size, opt.num_class_attn)
def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                   'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

    """ Transformation """
    if opt.Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW),
            I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    if opt.FeatureExtraction == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    if opt.SequenceModeling == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        self.SequenceModeling_output = opt.hidden_size
    elif opt.SequenceModeling == 'Transformer':
        self.SequenceModeling = LanguageTransformer(
            vocab_size=opt.num_class,
            input_size=self.FeatureExtraction_output,
            d_model=opt.d_model,
            nhead=opt.nhead,
            num_encoder_layers=opt.num_encoder_layers,
            num_decoder_layers=opt.num_decoder_layers,
            dim_feedforward=opt.dim_feedforward,
            max_seq_length=opt.max_seq_length,
            pos_dropout=opt.pos_dropout,
            trans_dropout=opt.trans_dropout,
        )
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    if opt.Prediction == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
    elif opt.Prediction == 'Attn':
        self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
    elif opt.Prediction == 'None':
        self.Prediction = noop
    else:
        raise Exception('Prediction is neither CTC nor Attn')
def __init__(self, cfg):
    super(Model, self).__init__()
    self.cfg = cfg
    self.stages = {'Trans': cfg['model']['transform'], 'Feat': cfg['model']['extraction'],
                   'Seq': cfg['model']['sequence'], 'Pred': cfg['model']['prediction']}

    """ Transformation """
    if cfg['model']['transform'] == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=cfg['transform']['num_fiducial'],
            I_size=(cfg['dataset']['imgH'], cfg['dataset']['imgW']),
            I_r_size=(cfg['dataset']['imgH'], cfg['dataset']['imgW']),
            I_channel_num=cfg['model']['input_channel'])
        print("Transformation module : {}".format(cfg['model']['transform']))
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    if cfg['model']['extraction'] == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(cfg['model']['input_channel'],
                                                      cfg['model']['output_channel'])
    elif cfg['model']['extraction'] == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(cfg['model']['input_channel'],
                                                       cfg['model']['output_channel'])
    elif cfg['model']['extraction'] == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(cfg['model']['input_channel'],
                                                         cfg['model']['output_channel'])
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = cfg['model']['output_channel']  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1
    print('Feature extractor : {}'.format(cfg['model']['extraction']))

    """ Sequence modeling """
    if cfg['model']['sequence'] == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, cfg['model']['hidden_size'],
                              cfg['model']['hidden_size']),
            BidirectionalLSTM(cfg['model']['hidden_size'], cfg['model']['hidden_size'],
                              cfg['model']['hidden_size']))
        self.SequenceModeling_output = cfg['model']['hidden_size']
    # SequenceModeling: Transformer
    elif cfg['model']['sequence'] == 'Transformer':
        self.SequenceModeling = Transformer(
            d_model=self.FeatureExtraction_output, nhead=2,
            num_encoder_layers=2, num_decoder_layers=2,
            dim_feedforward=cfg['model']['hidden_size'],
            dropout=0.1, activation='relu')
        print('SequenceModeling: Transformer initialized.')
        # the output has the same dimension as the input
        self.SequenceModeling_output = self.FeatureExtraction_output
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output
    print('Sequence modeling : {}'.format(cfg['model']['sequence']))

    """ Prediction """
    if cfg['model']['prediction'] == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, cfg['training']['num_class'])
    elif cfg['model']['prediction'] == 'Attn':
        self.Prediction = Attention(self.SequenceModeling_output, cfg['model']['hidden_size'],
                                    cfg['training']['num_class'])
    elif cfg['model']['prediction'] == 'Transformer':
        self.Prediction = nn.Linear(self.SequenceModeling_output, cfg['training']['num_class'])
    else:
        raise Exception('Prediction should be in [CTC | Attn | Transformer]')
    print("Prediction : {}".format(cfg['model']['prediction']))
def __init__(self, opt):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                   'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

    """ Transformation """
    if opt.Transformation == 'TPS':
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW),
            I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
    else:
        print('No Transformation module specified')

    """ FeatureExtraction """
    if opt.FeatureExtraction == 'VGG':
        self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'RCNN':
        self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == 'ResNet':
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
    else:
        raise Exception('No FeatureExtraction module specified')
    self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    """ Sequence modeling """
    if opt.SequenceModeling == 'BiLSTM':
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        self.SequenceModeling_output = opt.hidden_size
    else:
        print('No SequenceModeling module specified')
        self.SequenceModeling_output = self.FeatureExtraction_output

    """ Prediction """
    if opt.Prediction == 'CTC':
        self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
    elif opt.Prediction == 'Attn':
        self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
    elif opt.Prediction == 'Transformer':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        OUTPUT_DIM = opt.num_class
        HID_DIM = self.SequenceModeling_output
        ENC_LAYERS = 3
        DEC_LAYERS = 3
        ENC_HEADS = 8
        DEC_HEADS = 8
        ENC_PF_DIM = 512
        DEC_PF_DIM = 512
        ENC_DROPOUT = 0.1
        DEC_DROPOUT = 0.1
        enc = Encoder(HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device)
        dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device)
        TRG_PAD_IDX = 2  # TRG.vocab.stoi[TRG.pad_token]
        self.Prediction = Seq2Seq(enc, dec, TRG_PAD_IDX, device).to(device)
        self.Prediction.apply(self.initialize_weights)
        print("use transformer")
    else:
        raise Exception('Prediction is neither CTC nor Attn')
def __init__(self, opt, SelfSL_layer=False):
    super(Model, self).__init__()
    self.opt = opt
    self.stages = {"Trans": opt.Transformation, "Feat": opt.FeatureExtraction,
                   "Seq": opt.SequenceModeling, "Pred": opt.Prediction}

    """ Transformation """
    if opt.Transformation == "TPS":
        self.Transformation = TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial,
            I_size=(opt.imgH, opt.imgW),
            I_r_size=(opt.imgH, opt.imgW),
            I_channel_num=opt.input_channel,
        )
    else:
        print("No Transformation module specified")

    """ FeatureExtraction """
    if opt.FeatureExtraction == "VGG":
        self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == "RCNN":
        self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
    elif opt.FeatureExtraction == "ResNet":
        self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
    else:
        raise Exception("No FeatureExtraction module specified")
    self.FeatureExtraction_output = opt.output_channel
    self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

    if not SelfSL_layer:  # for STR
        """Sequence modeling"""
        if opt.SequenceModeling == "BiLSTM":
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
                BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size),
            )
            self.SequenceModeling_output = opt.hidden_size
        else:
            print("No SequenceModeling module specified")
            self.SequenceModeling_output = self.FeatureExtraction_output

    if not SelfSL_layer:  # for STR.
        """Prediction"""
        if opt.Prediction == "CTC":
            self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
        elif opt.Prediction == "Attn":
            self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
        else:
            raise Exception("Prediction is neither CTC nor Attn")
    else:
        """for self-supervised learning (SelfSL)"""
        self.AdaptiveAvgPool_2 = nn.AdaptiveAvgPool2d((None, 1))  # make width -> 1
        if SelfSL_layer == "CNN":
            self.SelfSL_FFN_input = self.FeatureExtraction_output
        if "RotNet" in self.opt.self:
            self.SelfSL = nn.Linear(self.SelfSL_FFN_input, 4)  # 4 = [0, 90, 180, 270] degrees
        elif "MoCo" in self.opt.self:
            self.SelfSL = nn.Linear(self.SelfSL_FFN_input, 128)  # 128 is used for MoCo paper.
class Model(nn.Module):

    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt
        self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                       'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

        """ Transformation """
        if opt.Transformation == 'TPS':
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW),
                I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
        else:
            print('No Transformation module specified')

        """ FeatureExtraction """
        if opt.FeatureExtraction == 'VGG':
            self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'RCNN':
            self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'ResNet':
            self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        else:
            raise Exception('No FeatureExtraction module specified')
        self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

        """ Sequence modeling """
        if opt.SequenceModeling == 'BiLSTM':
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
                BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
            self.SequenceModeling_output = opt.hidden_size
        else:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = self.FeatureExtraction_output

        """ Prediction """
        if opt.Prediction == 'CTC':
            self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
        elif opt.Prediction == 'Attn':
            self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
        elif opt.Prediction == 'Transformer':
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            OUTPUT_DIM = opt.num_class
            HID_DIM = self.SequenceModeling_output
            ENC_LAYERS = 3
            DEC_LAYERS = 3
            ENC_HEADS = 8
            DEC_HEADS = 8
            ENC_PF_DIM = 512
            DEC_PF_DIM = 512
            ENC_DROPOUT = 0.1
            DEC_DROPOUT = 0.1
            enc = Encoder(HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device)
            dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device)
            TRG_PAD_IDX = 2  # TRG.vocab.stoi[TRG.pad_token]
            self.Prediction = Seq2Seq(enc, dec, TRG_PAD_IDX, device).to(device)
            self.Prediction.apply(self.initialize_weights)
            print("use transformer")
        else:
            raise Exception('Prediction is neither CTC nor Attn')

    def initialize_weights(self, m):
        if hasattr(m, 'weight') and m.weight.dim() > 1:
            nn.init.xavier_uniform_(m.weight.data)

    def forward(self, input, text, is_train=True):
        """ Transformation stage """
        if not self.stages['Trans'] == "None":
            input = self.Transformation(input)

        """ Feature extraction stage """
        visual_feature = self.FeatureExtraction(input)
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
        visual_feature = visual_feature.squeeze(3)

        """ Sequence modeling stage """
        if self.stages['Seq'] == 'BiLSTM':
            contextual_feature = self.SequenceModeling(visual_feature)
        else:
            contextual_feature = visual_feature  # for convenience; this is NOT contextually modeled by BiLSTM

        """ Prediction stage """
        if self.stages['Pred'] == 'CTC':
            prediction = self.Prediction(contextual_feature.contiguous())
        elif self.stages['Pred'] == 'Transformer':
            prediction = self.Prediction(contextual_feature.contiguous(), text, is_train)
        else:
            prediction = self.Prediction(contextual_feature.contiguous(), text, is_train,
                                         batch_max_length=self.opt.batch_max_length)

        return prediction
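# --- Usage sketch (added for illustration; all option values below are assumptions) ---
# Builds a TPS-ResNet-BiLSTM-Attn configuration of the class above and runs a single
# forward pass on a dummy batch. This assumes the standard deep-text-recognition-benchmark
# modules (TPS_SpatialTransformerNetwork, ResNet_FeatureExtractor, BidirectionalLSTM,
# Attention) are importable from this file.
from types import SimpleNamespace

import torch

example_opt = SimpleNamespace(
    Transformation='TPS', FeatureExtraction='ResNet',
    SequenceModeling='BiLSTM', Prediction='Attn',
    num_fiducial=20, imgH=32, imgW=100,
    input_channel=1, output_channel=512, hidden_size=256,
    num_class=38, batch_max_length=25)

model = Model(example_opt)
images = torch.randn(2, example_opt.input_channel, example_opt.imgH, example_opt.imgW)
text = torch.zeros(2, example_opt.batch_max_length + 1, dtype=torch.long)  # [GO] tokens for Attn decoding
preds = model(images, text, is_train=False)  # -> [batch, num_steps, num_class]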