Example #1
    def __init__(self, opt, device):
        super(STR, self).__init__()
        self.opt = opt
        
#         Trans
#         self.Trans = Trans.TPS_SpatialTransformerNetwork(F = opt.num_fiducial,
#                                                   i_size = (opt.imgH, opt.imgW), 
#                                                   i_r_size= (opt.imgH, opt.imgW), 
#                                                   i_channel_num=opt.input_channel,
#                                                         device = device)
        # Extract
        if self.opt.extract == 'RCNN':
            self.Extract = Extract.RCNN_extractor(opt.input_channel, opt.output_channel)
        elif 'efficientnet' in self.opt.extract:
            self.Extract = Extract.EfficientNet(opt)
        elif 'resnet' in self.opt.extract:
            self.Extract = Extract.ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        else:
            raise ValueError('invalid extract model name!')

#         self.Extract = Extract.RCNN_extractor(opt.input_channel, opt.output_channel)
#         self.Extract = Extract.ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        self.FeatureExtraction_output = opt.output_channel # (imgH/16 -1 )* 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None,1)) # imgH/16-1   ->  1
            
        # Sequence
        self.Seq = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
#             BidirectionalLSTM(1536, opt.hidden_size, opt.hidden_size),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        self.Seq_output = opt.hidden_size
        
        # Pred
        self.Pred = Pred.Attention(self.Seq_output, opt.hidden_size, opt.num_classes, device=device)
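
The sequence stage above stacks two BidirectionalLSTM blocks that are not defined in the snippet. As a reference point, here is a minimal sketch of what such a block typically looks like in this family of STR code (a one-layer bidirectional nn.LSTM followed by a linear projection); the repository's actual implementation may differ.

import torch.nn as nn

class BidirectionalLSTM(nn.Module):
    """Minimal sketch: bidirectional LSTM whose concatenated forward/backward
    states are projected back to output_size. Illustrative only."""

    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, x):
        # x: (batch, seq_len, input_size) -> (batch, seq_len, output_size)
        recurrent, _ = self.rnn(x)  # (batch, seq_len, 2 * hidden_size)
        return self.linear(recurrent)
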
Example #2
    def __init__(self, opt, device):
        super(model, self).__init__()
        self.opt = opt

        #Trans
        self.Trans = Trans.TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial,
            i_size=(opt.imgH, opt.imgW),
            i_r_size=(opt.imgH, opt.imgW),
            i_channel_num=opt.input_channel,
            device=device)
        # Extract
        if self.opt.extract == 'RCNN':
            self.Extract = Extract.RCNN_extractor(
                opt.input_channel, opt.output_channel)
        elif 'efficientnet' in self.opt.extract:
            self.Extract = Extract.EfficientNet(opt)
        elif 'resnet' in self.opt.extract:
            self.Extract = Extract.ResNet_FeatureExtractor(
                opt.input_channel, opt.output_channel)
        else:
            raise ValueError('invalid extract model name!')

#         self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None,1)) # imgH/16-1   ->  1

        # Position aware module
        self.PAM = PositionEnhancement.PositionAwareModule(
            opt.output_channel, opt.hidden_size, opt.output_channel, 2)

        self.PAttnM_bot = PositionEnhancement.AttnModule(
            opt, opt.hidden_size, opt.bot_n_cls, device)
        self.PAttnM_mid = PositionEnhancement.AttnModule(
            opt, opt.hidden_size, opt.mid_n_cls, device)
        self.PAttnM_top = PositionEnhancement.AttnModule(
            opt, opt.hidden_size, opt.top_n_cls, device)

        # Hybrid branch
        self.Hybrid_bot = Hybrid.HybridBranch(opt.output_channel,
                                              opt.batch_max_length + 1,
                                              opt.bot_n_cls, device)
        self.Hybrid_mid = Hybrid.HybridBranch(opt.output_channel,
                                              opt.batch_max_length + 1,
                                              opt.mid_n_cls, device)
        self.Hybrid_top = Hybrid.HybridBranch(opt.output_channel,
                                              opt.batch_max_length + 1,
                                              opt.top_n_cls, device)

        # Dynamically fusing module
        self.Dynamic_fuser_top = PositionEnhancement.DynamicallyFusingModule(
            opt.top_n_cls)
        self.Dynamic_fuser_mid = PositionEnhancement.DynamicallyFusingModule(
            opt.mid_n_cls)
        self.Dynamic_fuser_bot = PositionEnhancement.DynamicallyFusingModule(
            opt.bot_n_cls)
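
Because the constructor only reads attributes from opt, a lightweight namespace is enough to instantiate the model. A hypothetical sketch, assuming the class above is importable together with its Trans/Extract/PositionEnhancement/Hybrid modules; every value below is a placeholder, not a setting taken from the original project.

import argparse
import torch

# Placeholder configuration covering the attributes the constructor reads.
opt = argparse.Namespace(
    num_fiducial=20, imgH=32, imgW=100,
    input_channel=1, output_channel=512, hidden_size=256,
    extract='resnet', batch_max_length=25,
    top_n_cls=20, mid_n_cls=22, bot_n_cls=29)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = model(opt, device).to(device)  # `model` is the class whose __init__ is shown above
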
Example #3
    def __init__(self, opt, device):
        super(SCATTER, self).__init__()
        self.opt = opt
        
        # Trans
        self.Trans = Trans.TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial, i_size=(opt.imgH, opt.imgW),
            i_r_size=(opt.imgH, opt.imgW), i_channel_num=opt.input_channel,
            device=device)
        
        # Extract
        if self.opt.extract == 'RCNN':
            self.Extract = Extract.RCNN_extractor(opt.input_channel, opt.output_channel)
        elif 'efficientnet' in self.opt.extract:
            self.Extract = Extract.EfficientNet(opt)
        else:
            raise ValueError('invalid extract model name!')
        #         self.Extract = Extract.ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        self.FeatureExtraction_output = opt.output_channel # (imgH/16 -1 )* 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None,1)) # imgH/16-1   ->  1
            
        # VISUAL FEATURES
        self.VFR = VFR.Visual_Features_Refinement(
            kernel_size=(3, 1), num_classes=opt.num_classes,
            in_channels=self.FeatureExtraction_output, out_channels=1, stride=1)
        
        # CTC DECODER
        self.CTC = CTC.CTC_decoder(opt.output_channel, opt.output_channel, opt.num_classes, device)
            
        # Selective Contextual Refinement Block
#         self.SCR_1 = SCR.Selective_Contextual_refinement_block(input_size = self.FeatureExtraction_output, 
#                                                          hidden_size = int(self.FeatureExtraction_output/2),
#                                                         output_size = self.FeatureExtraction_output,
#                                                         num_classes = opt.num_classes, decoder_fix = False, device = device, 
#                                                         batch_max_length = opt.batch_max_length)
        
#         self.SCR_2 = SCR.Selective_Contextual_refinement_block(input_size = self.FeatureExtraction_output, 
#                                                          hidden_size = int(self.FeatureExtraction_output/2),
#                                                         output_size = self.FeatureExtraction_output,
#                                                         num_classes = opt.num_classes, decoder_fix = False, device = device,
#                                                         batch_max_length = opt.batch_max_length)
        
#         self.SCR_3 = SCR.Selective_Contextual_refinement_block(input_size = self.FeatureExtraction_output, 
#                                                          hidden_size = int(self.FeatureExtraction_output/2),
#                                                         output_size = self.FeatureExtraction_output,
#                                                         num_classes = opt.num_classes, decoder_fix = True, device = device,
#                                                         batch_max_length = opt.batch_max_length)

        self.SCR = SCR.SCR_Blocks(
            input_size=self.FeatureExtraction_output,
            hidden_size=int(self.FeatureExtraction_output / 2),
            output_size=self.FeatureExtraction_output,
            num_classes=opt.num_classes, device=device,
            batch_max_length=opt.batch_max_length,
            n_blocks=opt.scr_n_blocks)
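
The AdaptiveAvgPool2d((None, 1)) layer above collapses the height of the extractor's feature map so that only the width axis survives as the sequence dimension. The forward pass is not shown in these snippets, but in the benchmark-style code this pattern comes from it is usually applied after a permute; a shape sketch under that assumption (the tensor sizes are illustrative):

import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d((None, 1))

# Hypothetical extractor output: (batch, output_channel, H', W')
feat = torch.randn(2, 512, 3, 65)

# Move width in front of channels, then average the height axis down to 1.
seq = pool(feat.permute(0, 3, 1, 2))  # (2, 65, 512, 1)
seq = seq.squeeze(3)                  # (2, 65, 512): one feature vector per horizontal position
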
Example #4
    def __init__(self, opt, device):
        super(STR, self).__init__()
        self.opt = opt
        
        # Trans
        self.Trans = Trans.TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial,
            i_size=(opt.imgH, opt.imgW),
            i_r_size=(opt.imgH, opt.imgW),
            i_channel_num=opt.input_channel,
            device=device)
        # Extract
        if self.opt.extract == 'RCNN':
            self.Extract = Extract.RCNN_extractor(opt.input_channel, opt.output_channel)
        elif 'efficientnet' in self.opt.extract:
            self.Extract = Extract.EfficientNet(opt)
        else:
            raise ValueError('invalid extract model name!')
            
#         self.Extract = Extract.ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        self.FeatureExtraction_output = opt.output_channel # (imgH/16 -1 )* 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None,1)) # imgH/16-1   ->  1

            
        # Sequence
        self.Seq = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        self.Seq_output = opt.hidden_size
        
        
        # Pred
        if opt.pred == 'arcface':
            print('using ArcFace Loss')
            self.Pred_bot = Pred_jamo_arcface.Attention(self.Seq_output, opt.hidden_size, opt.bottom_n_cls, device=device)
            self.Pred_mid = Pred_jamo_arcface.Attention_mid(self.Seq_output, opt.hidden_size, opt.middle_n_cls, opt.bottom_n_cls, device=device)
            self.Pred_top = Pred_jamo_arcface.Attention_top(self.Seq_output, opt.hidden_size, opt.top_n_cls, opt.middle_n_cls, opt.bottom_n_cls, device=device)
        else:
            self.Pred_bot = Pred_jamo.Attention(self.Seq_output, opt.hidden_size, opt.bottom_n_cls, device=device)
            self.Pred_mid = Pred_jamo.Attention_mid(self.Seq_output, opt.hidden_size, opt.middle_n_cls, opt.bottom_n_cls, device=device)
            self.Pred_top = Pred_jamo.Attention_top(self.Seq_output, opt.hidden_size, opt.top_n_cls, opt.middle_n_cls, opt.bottom_n_cls, device=device)
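
The arcface branch above swaps in Pred_jamo_arcface, whose margin head is not shown. For context, a generic ArcFace-style margin head looks roughly like the sketch below (purely illustrative; the scale s and margin m are assumed hyper-parameters, and the repository's head may differ):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ArcMarginHead(nn.Module):
    """Illustrative ArcFace head: cosine logits with an additive angular
    margin m on the target class, scaled by s. Not the repository's code."""

    def __init__(self, in_features, n_classes, s=30.0, m=0.5):
        super(ArcMarginHead, self).__init__()
        self.weight = nn.Parameter(torch.empty(n_classes, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.s, self.m = s, m

    def forward(self, x, labels):
        # Cosine similarity between L2-normalized features and class weights.
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
        target = F.one_hot(labels, cosine.size(1)).bool()
        # Add the angular margin only on the ground-truth class.
        logits = torch.where(target, torch.cos(theta + self.m), cosine)
        return self.s * logits
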
Example #5
    def __init__(self, opt, device):
        super(STR, self).__init__()
        self.opt = opt

        #Trans
        self.Trans = Trans.TPS_SpatialTransformerNetwork(
            F=opt.num_fiducial,
            i_size=(opt.imgH, opt.imgW),
            i_r_size=(opt.imgH, opt.imgW),
            i_channel_num=opt.input_channel,
            device=device)
        # Extract
        if self.opt.extract == 'RCNN':
            self.Extract = Extract.RCNN_extractor(
                opt.input_channel, opt.output_channel)
        elif 'efficientnet' in self.opt.extract:
            self.Extract = Extract.EfficientNet(opt)
        else:
            raise ValueError('invalid extract model name!')

#         self.Extract = Extract.ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        self.FeatureExtraction_output = opt.output_channel  # (imgH/16 -1 )* 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
            (None, 1))  # imgH/16-1   ->  1

        # Sequence
        self.Seq = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size,
                              opt.hidden_size),
            BidirectionalLSTM(opt.hidden_size, opt.hidden_size,
                              opt.hidden_size))
        self.Seq_output = opt.hidden_size

        # Pred (Hybrid branch)

        self.Pred_bot = Pred_jamo_position.Attention(self.Seq_output,
                                                     opt.hidden_size,
                                                     opt.bottom_n_cls,
                                                     device=device)
        self.Pred_mid = Pred_jamo_position.Attention_mid(self.Seq_output,
                                                         opt.hidden_size,
                                                         opt.middle_n_cls,
                                                         opt.bottom_n_cls,
                                                         device=device)
        self.Pred_top = Pred_jamo_position.Attention_top(self.Seq_output,
                                                         opt.hidden_size,
                                                         opt.top_n_cls,
                                                         opt.middle_n_cls,
                                                         opt.bottom_n_cls,
                                                         device=device)

        # Position enhancement module

        self.Position_attn_top = Pred_jamo_position.PositionAttnModule(
            opt, opt.top_n_cls, device)
        self.Position_attn_mid = Pred_jamo_position.PositionAttnModule(
            opt, opt.middle_n_cls, device)
        self.Position_attn_bot = Pred_jamo_position.PositionAttnModule(
            opt, opt.bottom_n_cls, device)

        # Dynamically fusing module

        self.Dynamic_fuser_top = Pred_jamo_position.DynamicallyFusingModule(
            opt.top_n_cls)
        self.Dynamic_fuser_mid = Pred_jamo_position.DynamicallyFusingModule(
            opt.middle_n_cls)
        self.Dynamic_fuser_bot = Pred_jamo_position.DynamicallyFusingModule(
            opt.bottom_n_cls)
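
Several of the examples instantiate a DynamicallyFusingModule without showing its definition. In RobustScanner-style recognizers, a module of this kind mixes the hybrid-branch and position-branch outputs with a learned gate; a minimal sketch of one plausible scheme under that assumption (not the repository's actual code):

import torch
import torch.nn as nn

class GatedFusion(nn.Module):
    """Illustrative RobustScanner-style fusion: concatenate the two branch
    outputs, predict a per-channel sigmoid gate, and mix them with it."""

    def __init__(self, n_channels):
        super(GatedFusion, self).__init__()
        self.gate = nn.Sequential(nn.Linear(n_channels * 2, n_channels), nn.Sigmoid())

    def forward(self, hybrid_out, position_out):
        # Both inputs: (batch, T, n_channels); output has the same shape.
        w = self.gate(torch.cat([hybrid_out, position_out], dim=-1))
        return w * hybrid_out + (1.0 - w) * position_out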