コード例 #1
0
    def __init__(self, num_class):
        """Build the spatial-OCR decoder layers.

        Args:
            num_class: number of output classes. Stored as ``self.num_classes``,
                passed to ``SpatialGather_Module`` and used as the output width
                of the final 1x1 conv in ``dsn_head``.

        Two input widths are hard-coded below: 2048 channels go into
        ``conv_3x3`` and 1024 channels go into ``dsn_head``.
        """
        # nn.Module.__init__ must run before any attribute is assigned on the
        # module (PyTorch requirement), so it is called first.
        super().__init__()
        # NOTE(review): not read in this method; presumably kept because other
        # code reads it -- confirm before removing.
        self.inplanes = 128
        self.num_classes = num_class
        # Channel widths of the two inputs: [0] feeds dsn_head, [1] feeds
        # conv_3x3. Presumably backbone stage-3 / stage-4 outputs -- TODO confirm.
        in_channels = [1024, 2048]
        # 3x3 conv + BN + ReLU reducing in_channels[1] (2048) to 512 channels.
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(in_channels[1], 512, kernel_size=3, stride=1, padding=1),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        # Function-scope import kept from the original; NOTE(review): possibly
        # to avoid an import cycle -- confirm before hoisting to module level.
        from models.ocr_modules.spatial_ocr_block import (
            SpatialGather_Module,
            SpatialOCR_Module,
        )
        self.spatial_context_head = SpatialGather_Module(self.num_classes)
        # OCR block operating at 512 channels in and out, 256 key channels.
        self.spatial_ocr_head = SpatialOCR_Module(
            in_channels=512,
            key_channels=256,
            out_channels=512,
            scale=1,
            dropout=0.05,
        )

        # No 1x1 classifier (``self.head``) is created in this decoder, so the
        # OCR output stays at 512 channels.
        # NOTE(review): confirm that forward() and callers never use self.head.

        # 3x3 conv + BN + ReLU + Dropout2d on the in_channels[0] (1024) input,
        # then a 1x1 conv to num_classes channels. Presumably an auxiliary
        # (deep-supervision) prediction branch -- confirm in forward().
        self.dsn_head = nn.Sequential(
            nn.Conv2d(in_channels[0], 512, kernel_size=3, stride=1, padding=1),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.05),
            nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True),
        )
コード例 #2
0
    def __init__(self, net_enc, crit, args, deep_sup_scale=None):
        """Assemble the encoder, criterion and OCR heads of ClipOCRNet.

        Args:
            net_enc: encoder network, stored as ``self.encoder``.
            crit: criterion object, stored as ``self.crit``.
            args: configuration object; ``args.use_memory`` and
                ``args.num_class`` are read here and ``args`` itself is kept
                as ``self.args``.
            deep_sup_scale: optional value stored as ``self.deep_sup_scale``.
                NOTE(review): presumably the weight of the ``dsn_head`` loss
                term -- confirm in forward().
        """
        super().__init__()

        self.args = args

        # ``self.memory`` exists only when memory is enabled in the config.
        if args.use_memory:
            self.memory = []
        self.crit = crit
        self.deep_sup_scale = deep_sup_scale
        self.encoder = net_enc
        self.inplanes = 128
        self.num_classes = args.num_class

        # Widths of the two incoming feature maps and the shared hidden width.
        low_ch, high_ch = 1024, 2048
        mid_ch = 512

        # Submodules are created in the original order so that parameter
        # ordering and random initialisation are unchanged.

        # 3x3 conv + BN + ReLU: high_ch -> mid_ch.
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(high_ch, mid_ch, kernel_size=3, stride=1, padding=1),
            BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
        )

        self.spatial_context_head = SpatialTemporalGather_Module(self.num_classes)
        self.spatial_ocr_head = SpatialOCR_Module(
            in_channels=mid_ch,
            key_channels=256,
            out_channels=mid_ch,
            scale=1,
            dropout=0.05,
        )

        # 1x1 classifier: mid_ch -> num_classes.
        self.head = nn.Conv2d(
            mid_ch, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True,
        )

        # 3x3 conv + BN + ReLU + Dropout2d on the low_ch input, then a 1x1 conv
        # to num_classes channels (presumably the deep-supervision branch).
        self.dsn_head = nn.Sequential(
            nn.Conv2d(low_ch, mid_ch, kernel_size=3, stride=1, padding=1),
            BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.05),
            nn.Conv2d(mid_ch, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True),
        )