コード例 #1
0
    def __init__(self, args, num_classes=21):
        """Build a DeepLab model (backbone, ASPP, decoder, linear classifier).

        Args:
            args: Config namespace. This constructor reads ``out_stride``,
                ``backbone``, ``norm`` and ``freeze_bn``; ``build_backbone``
                receives the whole object.
            num_classes: Output size of the decoder and of ``classifier``.

        Raises:
            NotImplementedError: If ``args.norm`` is not ``'gn'``, ``'bn'``
                or ``'syncbn'``.
        """
        super(DeepLab, self).__init__()
        self.args = args
        output_stride = args.out_stride

        # Some backbones have a fixed output stride regardless of the config.
        if args.backbone == 'drn':
            output_stride = 8
        if args.backbone.split('-')[0] == 'efficientnet':
            output_stride = 32

        if args.norm == 'gn':
            norm = gn
        elif args.norm == 'bn':
            norm = bn
        elif args.norm == 'syncbn':
            norm = syncbn
        else:
            raise NotImplementedError(
                "{} normalization is not implemented".format(args.norm))

        self.backbone = build_backbone(args)
        # Pass the effective stride (after the backbone-specific overrides
        # above) so ASPP is configured for the stride the backbone really
        # produces; passing args.out_stride would leave the overrides dead.
        # NOTE(review): confirm build_aspp accepts output_stride=32 for the
        # efficientnet backbones.
        self.aspp = build_aspp(args.backbone, output_stride, norm)
        self.decoder = build_decoder(num_classes, args.backbone, norm)

        # NOTE(review): 300 is a hard-coded input width (presumably an
        # embedding size) — confirm against the forward pass.
        self.classifier = nn.Linear(300, num_classes)

        if self.args.freeze_bn:
            self.freeze_bn()
コード例 #2
0
    def __init__(self,
                 backbone='resnet',
                 output_stride=16,
                 num_classes=21,
                 sync_bn=True,
                 freeze_bn=False,
                 use_iou=True):
        """Assemble DeepLab (backbone, ASPP, decoder) plus an optional mask-IoU head.

        Args:
            backbone: Backbone name forwarded to the ``build_*`` factories;
                ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of segmentation classes.
            sync_bn: When equal to ``True`` use ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called at the end.
            use_iou: If truthy, a ``maskiou`` head is also built.
        """
        super().__init__()
        self.use_iou = use_iou
        stride = 8 if backbone == 'drn' else output_stride
        norm_layer = SynchronizedBatchNorm2d if sync_bn == True else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, stride, norm_layer)
        self.aspp = build_aspp(backbone, stride, norm_layer)
        self.decoder = build_decoder(num_classes, backbone, norm_layer)
        if self.use_iou:
            self.maskiou = build_maskiou(num_classes, norm_layer)

        if freeze_bn:
            self.freeze_bn()
コード例 #3
0
    def __init__(self,
                 backbone='resnet',
                 output_stride=16,
                 num_classes=21,
                 sync_bn=True,
                 freeze_bn=False):
        """Build DeepLab plus a super-resolution branch.

        Besides the usual backbone / ASPP / decoder, this creates an SR
        decoder, a 1x1 ``pointwise`` projection from ``num_classes`` to 3
        channels, and three upsampling stages. Each ``ConvTranspose2d`` has
        kernel 2 and stride 2, doubling H and W, followed by an ``EDSRConv``.
        The stages are followed by a final 1x1 conv to 3 channels.

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of segmentation classes.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: Stored as ``self.freeze_bn``; nothing is frozen here.
        """
        super().__init__()
        stride = output_stride
        if backbone == 'drn':
            stride = 8

        norm_layer = SynchronizedBatchNorm2d if sync_bn == True else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, stride, norm_layer)
        self.aspp = build_aspp(backbone, stride, norm_layer)
        self.decoder = build_decoder(num_classes, backbone, norm_layer)
        self.sr_decoder = build_sr_decoder(num_classes, backbone, norm_layer)
        # A BN layer was deliberately added after the 1x1 projection.
        self.pointwise = torch.nn.Sequential(
            torch.nn.Conv2d(num_classes, 3, 1),
            torch.nn.BatchNorm2d(3),
            torch.nn.ReLU(inplace=True),
        )

        # Upsampling stages: 64 -> 64 -> 32 -> 16 channels, each x2 spatially.
        self.up_sr_1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
        self.up_edsr_1 = EDSRConv(64, 64)
        self.up_sr_2 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.up_edsr_2 = EDSRConv(32, 32)
        self.up_sr_3 = nn.ConvTranspose2d(32, 16, 2, stride=2)
        self.up_edsr_3 = EDSRConv(16, 16)
        self.up_conv_last = nn.Conv2d(16, 3, 1)

        # NOTE(review): stored as a plain flag. If the class also defines a
        # freeze_bn() method, this instance attribute shadows it — confirm.
        self.freeze_bn = freeze_bn
コード例 #4
0
    def __init__(self,
                 backbone='resnet',
                 output_stride=16,
                 num_classes=19,
                 sync_bn=True,
                 freeze_bn=False,
                 args=None,
                 separate=False):
        """Build DeepLab whose builders also receive ``args`` and ``separate``.

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of classes produced by the decoder.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.
            args: Extra config object forwarded to all three builders.
            separate: Flag forwarded to ``build_aspp`` and ``build_decoder``.
        """
        super().__init__()
        stride = output_stride
        if backbone == 'drn':
            stride = 8

        if sync_bn == True:
            norm_layer = SynchronizedBatchNorm2d
        else:
            norm_layer = nn.BatchNorm2d

        self.backbone = build_backbone(backbone, stride, norm_layer, args)
        self.aspp = build_aspp(backbone, stride, norm_layer, args, separate)
        self.decoder = build_decoder(num_classes, backbone, norm_layer, args,
                                     separate)

        if freeze_bn:
            self.freeze_bn()
コード例 #5
0
    def __init__(self,
                 backbone='resnet_multiscale',
                 output_stride=16,
                 num_classes=21,
                 sync_bn=True,
                 freeze_bn=False):
        """Build DeepLab with a channel-attention (CAM) module.

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of segmentation classes.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.
        """
        super(DeepLabCA, self).__init__()
        if backbone == 'drn':
            output_stride = 8

        if sync_bn == True:
            BatchNorm = SynchronizedBatchNorm2d
        else:
            BatchNorm = nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)

        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)
        # Pools any input to a fixed 32x32 spatial size.
        self.avg_pool = nn.AdaptiveAvgPool2d(32)
        self.ca = CAM_Module()

        if freeze_bn:
            self.freeze_bn()
コード例 #6
0
    def __init__(
        self,
        Norm,
        backbone="resnet",
        output_stride=16,
        num_classes=3,
        freeze_bn=False,
        abn=False,
    ):
        """Build DeepLabv3 (backbone + ASPP, no decoder).

        Args:
            Norm: Normalization name: ``"gn"``, ``"bn"`` or ``"syncbn"``.
            backbone: Backbone name; ``"drn"`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Not used by this constructor.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.
            abn: Stored on the instance and forwarded to ``build_backbone``.

        Raises:
            NotImplementedError: If ``Norm`` is not a supported name.
        """
        super(DeepLabv3, self).__init__()
        self.abn = abn

        if backbone == "drn":
            output_stride = 8

        if Norm == "gn":
            norm = gn
        elif Norm == "bn":
            norm = bn
        elif Norm == "syncbn":
            norm = syncbn
        else:
            # Without this branch `norm` stays unbound and build_aspp below
            # would fail with an UnboundLocalError.
            raise NotImplementedError(
                "{} normalization is not implemented".format(Norm))

        # NOTE(review): the backbone receives the normalization *name* (Norm)
        # while ASPP receives the resolved layer (norm); confirm this is what
        # build_backbone expects.
        self.backbone = build_backbone(backbone,
                                       output_stride,
                                       Norm,
                                       dec=False,
                                       abn=abn)
        self.aspp = build_aspp(backbone, output_stride, norm, dec=False)
        if freeze_bn:
            self.freeze_bn()
コード例 #7
0
    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        """Build DeepLab with a stacked residual (deformable) conv context module.

        No ASPP module is created here; ``stack_resudial`` is built after the
        decoder instead.

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of segmentation classes.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.
        """
        super().__init__()
        if backbone == 'drn':
            output_stride = 8

        if sync_bn == True:
            norm_layer = SynchronizedBatchNorm2d
        else:
            norm_layer = nn.BatchNorm2d

        # Per the original author's note, the backbone outputs three values:
        # x, feature_map, low_level_feat.
        self.backbone = build_backbone(backbone, output_stride, norm_layer)
        self.decoder = build_decoder(num_classes, backbone, norm_layer)
        self.stack_resudial = build_stack_resudial_conv(
            backbone,
            output_stride,
            BatchNorm=norm_layer,
            modulation=False,
            adaptive_d=False,
            deform=True,
        )

        if freeze_bn:
            self.freeze_bn()
コード例 #8
0
ファイル: networks.py プロジェクト: DotWang/DFC2020
    def __init__(self, args, backbone='resnet', output_stride=16, num_classes=4,
                 sync_bn=False, freeze_bn=False, depth=50):
        """Build DeepLab whose input channel count is chosen from ``args`` flags.

        Input channels and the ``pretrn`` flag given to ``build_backbone`` are
        derived from ``args.oly_s1``, ``args.oly_s2`` and ``args.rgb``:

        - oly_s1 only              -> 2 channels, pretrn False (rgb ignored)
        - oly_s2 only, rgb         -> 3 channels, pretrn True
        - oly_s2 only, not rgb     -> 10 channels, pretrn False
        - neither, rgb             -> 5 channels, pretrn False
        - neither, not rgb         -> 12 channels, pretrn False
        - both oly_s1 and oly_s2   -> NotImplementedError

        Args:
            args: Config namespace; stored as ``self.args``.
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of segmentation classes.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.
            depth: Forwarded to ``build_backbone``.

        Raises:
            NotImplementedError: If both ``oly_s1`` and ``oly_s2`` are truthy.
        """
        super().__init__()
        if backbone == 'drn':
            output_stride = 8
        norm_layer = SynchronizedBatchNorm2d if sync_bn == True else nn.BatchNorm2d

        self.args = args
        only_s1 = self.args.oly_s1
        only_s2 = self.args.oly_s2
        if only_s1 and only_s2:
            raise NotImplementedError
        if only_s1:
            in_channel, pretrn = 2, False
        elif only_s2:
            in_channel, pretrn = (3, True) if self.args.rgb else (10, False)
        else:
            in_channel, pretrn = (5, False) if self.args.rgb else (12, False)

        self.backbone = build_backbone(backbone, in_channel, output_stride,
                                       norm_layer, depth, pretrn)
        self.aspp = build_aspp(backbone, output_stride, norm_layer)
        self.decoder = build_decoder(num_classes, backbone, norm_layer)

        if freeze_bn:
            self.freeze_bn()
コード例 #9
0
    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=False, freeze_bn=False):
        """Build DeepLab with a 1x1 link conv and a small 64-channel head.

        Args:
            backbone: ``'resnet'`` or ``'mobilenet'``; selects the width of
                ``link_conv``.
            output_stride: Requested backbone output stride (forced to 8 for
                ``'drn'``).
            num_classes: Output channels of the final 1x1 conv.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.

        Raises:
            NotImplementedError: If ``backbone`` has no known link-conv width.
        """
        super(DeepLab, self).__init__()
        if backbone == 'drn':
            output_stride = 8

        # (in, out) channels of the 1x1 link convolution per backbone.
        link_channels = {
            'resnet': (1024, 1024),
            'mobilenet': (64, 64),
        }
        if backbone not in link_channels:
            # Fail fast with a clear message, before any submodule is built;
            # otherwise link_in/link_out would be unbound below.
            raise NotImplementedError(
                "backbone '{}' is not supported; expected one of {}".format(
                    backbone, sorted(link_channels)))
        link_in, link_out = link_channels[backbone]

        if sync_bn == True:
            BatchNorm = SynchronizedBatchNorm2d
        else:
            BatchNorm = nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.link_conv = nn.Sequential(nn.Conv2d(link_in, link_out, kernel_size=1, stride=1, padding=0, bias=False))
        # NOTE(review): the head expects a 64-channel input — confirm against
        # the forward pass for the resnet configuration.
        self.last_conv = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm(64),
                                       nn.ReLU(),
                                       nn.Dropout(0.1),
                                       nn.Conv2d(64, num_classes, kernel_size=1, stride=1))

        self._init_weight()
        if freeze_bn:
            self.freeze_bn()
コード例 #10
0
    def __init__(self,
                 args,
                 num_classes=21,
                 freeze_bn=False,
                 abn=False,
                 deep_dec=True):
        """Build DeepLabv3+ (with decoder) or DeepLabv3 (without) from ``args``.

        Args:
            args: Config namespace. This constructor reads ``out_stride``,
                ``backbone`` and ``norm``; ``build_backbone`` receives the
                whole object.
            num_classes: Number of classes produced by the decoder.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.
            abn: Stored on the instance as ``self.abn``.
            deep_dec: If truthy, a decoder is built (DeepLabv3+); otherwise
                the model has no decoder (DeepLabv3).

        Raises:
            NotImplementedError: If ``args.norm`` is not ``"gn"``, ``"bn"``
                or ``"syncbn"``.
        """
        super(DeepLab, self).__init__()
        self.args = args
        self.abn = abn
        self.deep_dec = deep_dec  # if True, DeepLabv3+; otherwise DeepLabv3
        output_stride = args.out_stride

        # Some backbones have a fixed output stride regardless of the config.
        if args.backbone == "drn":
            output_stride = 8
        if args.backbone.split("-")[0] == "efficientnet":
            output_stride = 32

        if args.norm == "gn":
            norm = gn
        elif args.norm == "bn":
            norm = bn
        elif args.norm == "syncbn":
            norm = syncbn
        else:
            raise NotImplementedError(
                "{} normalization is not implemented".format(args.norm))

        self.backbone = build_backbone(args)
        # Pass the effective stride (after the backbone-specific overrides
        # above) so ASPP is configured for the stride the backbone really
        # produces; passing args.out_stride would leave the overrides dead.
        # NOTE(review): confirm build_aspp accepts output_stride=32 for the
        # efficientnet backbones.
        self.aspp = build_aspp(args.backbone, output_stride, norm)
        if self.deep_dec:
            self.decoder = build_decoder(num_classes, args.backbone, norm)

        if freeze_bn:
            self.freeze_bn()
コード例 #11
0
    def __init__(self,
                 args,
                 backbone='resnet',
                 output_stride=16,
                 num_classes=21,
                 sync_bn=True,
                 freeze_bn=False):
        """Build DeepLab with either a standard or a kinematic decoder.

        When ``args.use_kinematic == False`` a standard decoder is built;
        otherwise a kinematic decoder plus a ``kinematic_layer`` graph module
        are built instead.

        Args:
            args: Config namespace (stored as ``self.args``).
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of classes (standard decoder only).
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: Stored as ``self.freeze_bn``; nothing is frozen here.
        """
        super().__init__()
        self.args = args
        if backbone == 'drn':
            output_stride = 8
        norm_layer = SynchronizedBatchNorm2d if sync_bn == True else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, norm_layer)
        self.aspp = build_aspp(backbone, output_stride, norm_layer)
        plain_decoder = self.args.use_kinematic == False
        if plain_decoder:
            self.decoder = build_decoder(num_classes, backbone, norm_layer)
        else:
            self.decoder = build_decoder_kinematic(backbone, norm_layer)
            self.kinematic_layer = build_kinematic_graph(norm_layer)

        # NOTE(review): stored as a plain flag. If the class also defines a
        # freeze_bn() method, this instance attribute shadows it — confirm.
        self.freeze_bn = freeze_bn
コード例 #12
0
    def __init__(self,
                 backbone='resnet',
                 output_stride=16,
                 sync_bn=True,
                 freeze_bn=False):
        """Build a backbone followed by an MLP emitting 4 values in (0, 1).

        Args:
            backbone: ``'resnet'`` (2048-wide input to the MLP) or
                ``'mobilenet'`` (320-wide). ``'drn'`` forces an output
                stride of 8 but has no MLP width, so it raises.
            output_stride: Requested backbone output stride.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.

        Raises:
            KeyError: If ``backbone`` is not ``'resnet'`` or ``'mobilenet'``
                (raised after the backbone has been built).
        """
        super().__init__()
        if backbone == 'drn':
            output_stride = 8
        norm_layer = SynchronizedBatchNorm2d if sync_bn == True else nn.BatchNorm2d

        # Input width of the first Linear layer per backbone (presumably the
        # channel count of the backbone's pooled output — confirm in forward).
        feature_width = {
            'resnet': 2048,
            'mobilenet': 320,
        }

        self.backbone = build_backbone(backbone, output_stride, norm_layer)
        mlp_in = feature_width[backbone]
        self.classifier = nn.Sequential(
            nn.Linear(mlp_in, 128),
            nn.ReLU(),
            nn.Linear(128, 4),
            nn.Sigmoid(),
        )

        if freeze_bn:
            self.freeze_bn()
コード例 #13
0
    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False, pretrain=True):
        """Build DeepLab, optionally load pretrained weights, then switch to 2 classes.

        The model is first built with ``num_classes`` outputs so that
        ``_load_pretrain()`` can run against it. The decoder's final layer
        is then replaced by a 1x1 conv with 2 output channels (binary
        segmentation mask).

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Classes of the initial decoder (stored as
                ``self.num_classes``).
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called.
            pretrain: If truthy, ``self._load_pretrain()`` is called before
                the final layer is replaced.
        """
        super(DeepLabX, self).__init__()
        self.num_classes = num_classes
        if backbone == 'drn':
            output_stride = 8

        if sync_bn == True:
            BatchNorm = SynchronizedBatchNorm2d
        else:
            BatchNorm = nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        if freeze_bn:
            self.freeze_bn()

        if pretrain:
            self._load_pretrain()

        # Replace the last inference layer for a binary segmentation mask.
        # The new layer takes over the replaced layer's index as its name
        # (len of the remaining layers), not a fixed '8', which only matches
        # a 9-layer head and would overwrite a real layer in a deeper one.
        head_layers = list(self.decoder.last_conv.children())[:-1]
        self.decoder.last_conv = nn.Sequential(*head_layers)
        self.decoder.last_conv.add_module(str(len(head_layers)),
                                          nn.Conv2d(256, 2, 1, 1))
コード例 #14
0
    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        """Build DeepLab (backbone, ASPP, decoder) and initialise its weights.

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of segmentation classes.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called after
                ``self._init_weight()``.
        """
        super().__init__()
        if backbone == 'drn':
            output_stride = 8
        norm_layer = SynchronizedBatchNorm2d if sync_bn == True else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, norm_layer)
        self.aspp = build_aspp(backbone, output_stride, norm_layer)
        self.decoder = build_decoder(num_classes, backbone, norm_layer)

        # Weight initialisation runs before the optional BN freeze.
        self._init_weight()
        if freeze_bn:
            self.freeze_bn()
コード例 #15
0
    def __init__(self, backbone='resnet',output_stride=16, num_classes=1,nInputChannels=5,freeze_bn=False):
        """Build IOG_loop: backbone, PSP module, global/refine nets and a point encoder.

        ``ex_points`` encodes a 2-channel input through four stride-2 3x3
        conv-BN-ReLU stages (2 -> 64 -> 128 -> 256 -> 256 channels) and a
        final 1x1 conv-BN-ReLU down to 64 channels.

        Args:
            backbone: Backbone name forwarded to ``build_backbone``.
            output_stride: Backbone output stride.
            num_classes: Output classes of the global and refine nets.
            nInputChannels: Input channels forwarded to ``build_backbone``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called at the end.
        """
        super().__init__()
        output_shape = 128
        BatchNorm = nn.BatchNorm2d
        # An alternative setting in the original was [2048, 1024, 512, 256].
        channel_settings = [512, 1024, 512, 256]
        self.global_net = globalNet(channel_settings, output_shape, num_classes)
        self.refine_net = refineNet(channel_settings[-1], output_shape, num_classes)
        self.backbone = build_backbone(backbone, output_stride, BatchNorm, nInputChannels)
        # NOTE(review): 2048 + 64 presumably = backbone features concatenated
        # with the 64-channel ex_points output — confirm in forward().
        self.psp4 = PSPModule(in_features=2048+64, out_features=512, sizes=(1, 2, 3, 6))

        def conv_bn_relu(c_in, c_out, kernel_size, stride, padding=0):
            # One bias-free conv followed by BatchNorm2d and ReLU.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=kernel_size, stride=stride,
                          padding=padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        point_layers = []
        for c_in, c_out in ((2, 64), (64, 128), (128, 256), (256, 256)):
            point_layers += conv_bn_relu(c_in, c_out, 3, 2, padding=1)
        point_layers += conv_bn_relu(256, 64, 1, 1)
        self.ex_points = nn.Sequential(*point_layers)

        if freeze_bn:
            self.freeze_bn()
コード例 #16
0
 def __init__(self,
              backbone='resnet101',
              output_stride=16,
              num_classes=21,
              sync_bn=True,
              freeze_bn=False,
              enable_interpolation=True,
              pretrained_path=None,
              norm_layer=nn.BatchNorm2d,
              enable_aspp=True):
     """Build DeepLab with a caller-supplied normalization layer.

     Args:
         backbone: Backbone name; ``'drn'`` forces an output stride of 8.
         output_stride: Requested backbone output stride.
         num_classes: Number of segmentation classes.
         sync_bn: Not used by this constructor (``norm_layer`` decides).
         freeze_bn: Not used by this constructor.
             NOTE(review): sibling variants call self.freeze_bn() here —
             confirm omitting it is intended.
         enable_interpolation: Stored as ``self.enable_interpolation``.
         pretrained_path: Forwarded to ``build_backbone``.
         norm_layer: Normalization layer class used by all components.
         enable_aspp: Stored and forwarded to ``build_aspp``.
     """
     super().__init__()
     self.enable_aspp = enable_aspp
     stride = 8 if backbone == 'drn' else output_stride
     self.backbone = build_backbone(backbone, stride, norm_layer,
                                    pretrained_path=pretrained_path)
     self.aspp = build_aspp(backbone, stride, norm_layer,
                            enable_aspp=self.enable_aspp)
     self.decoder = build_decoder(num_classes, backbone, norm_layer)
     self.enable_interpolation = enable_interpolation
コード例 #17
0
ファイル: SCNN.py プロジェクト: mukaman84/pytorch-template
    def __init__(self,
                 backbone='resnet',
                 output_stride=16,
                 nclass=19,
                 cuda=True,
                 extension=None):
        """Build the SCNN model: backbone, RNN module, conv heads and an extension.

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            nclass: Output channels of ``conv1`` and of the extension.
            cuda: When equal to ``True`` the RNN is built for ``"cuda"``,
                otherwise for ``"cpu"``.
            extension: Extension kind forwarded to ``build_extension``.
        """
        super(SCNN, self).__init__()
        if backbone == 'drn':
            output_stride = 8

        BatchNorm = nn.BatchNorm2d
        if cuda == True:
            device = "cuda"
        else:
            device = "cpu"
        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.rnn = build_rnn(32, device)
        # Expects a 2048-channel input (presumably the backbone output) and
        # reduces it to 32 channels.
        self.conv0 = nn.Sequential(
            nn.Conv2d(2048, 32, 3, padding=1, bias=False), BatchNorm(32),
            nn.ReLU())
        self.conv1 = nn.Sequential(
            nn.Conv2d(32, nclass, 3, padding=1, bias=False), BatchNorm(nclass),
            nn.ReLU())
        # Softmax over the channel axis. BatchNorm2d requires (N, C, H, W)
        # input, and for 4-D input the deprecated implicit dim also resolved
        # to 1, so outputs are unchanged; only the warning goes away.
        self.conv2 = nn.Sequential(nn.Conv2d(32, 2, 3, padding=1, bias=False),
                                   BatchNorm(2), nn.Softmax(dim=1))
        self.dropout = nn.Dropout2d()
        self.extension = build_extension(ext=extension,
                                         out_channels=nclass,
                                         kernel_size=3,
                                         padding=1,
                                         n_resblocks=3)
コード例 #18
0
    def __init__(self,
                 backbone='resnet',
                 output_stride=16,
                 num_classes=19,
                 use_ABN=True,
                 freeze_bn=False,
                 args=None,
                 separate=False):
        """Build DeepLab using ``ABN`` or ``NaiveBN`` as the normalization layer.

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of classes produced by the decoder.
            use_ABN: If truthy use ``ABN``, otherwise ``NaiveBN``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.
            args: Extra config object forwarded to all three builders.
            separate: Flag forwarded to ``build_aspp`` and ``build_decoder``.
        """
        super().__init__()
        stride = 8 if backbone == 'drn' else output_stride
        norm_layer = ABN if use_ABN else NaiveBN

        self.backbone = build_backbone(backbone, stride, norm_layer, args)
        self.aspp = build_aspp(backbone, stride, norm_layer, args, separate)
        self.decoder = build_decoder(num_classes, backbone, norm_layer, args,
                                     separate)

        if freeze_bn:
            self.freeze_bn()
コード例 #19
0
ファイル: deeplab.py プロジェクト: mapooon/signate_AI_Edge
    def __init__(self,
                 backbone='seresnext101',
                 output_stride=16,
                 num_classes=5,
                 sync_bn=True,
                 freeze_bn=False):
        """Build DeepLab (backbone, ASPP, decoder); defaults to a SE-ResNeXt101 backbone.

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of segmentation classes.
            sync_bn: If truthy use ``SynchronizedBatchNorm2d``, otherwise
                ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.
        """
        super().__init__()
        stride = 8 if backbone == 'drn' else output_stride
        norm_layer = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, stride, norm_layer)
        self.aspp = build_aspp(backbone, stride, norm_layer)
        self.decoder = build_decoder(num_classes, backbone, norm_layer)

        if freeze_bn:
            self.freeze_bn()
コード例 #20
0
ファイル: deeplab.py プロジェクト: aimi-lab/unet_region
    def __init__(self,
                 backbone='resnet',
                 output_stride=16,
                 num_classes=21,
                 sync_bn=True,
                 freeze_bn=False,
                 cp_path=None):
        """Build DeepLab with a 2-channel (background / foreground) output head.

        The decoder is always built for 21 classes, whatever ``num_classes``
        is (presumably so a 21-class checkpoint can be loaded). The optional
        checkpoint is loaded, and then the decoder's final layer is replaced
        by a 1x1 conv with 2 output channels.

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Not used by this constructor.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called after the
                checkpoint is loaded.
            cp_path: Optional checkpoint path; its ``'state_dict'`` entry is
                loaded into the model.
        """
        super().__init__()
        stride = 8 if backbone == 'drn' else output_stride
        norm_layer = SynchronizedBatchNorm2d if sync_bn == True else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, stride, norm_layer)

        self.aspp = build_aspp(backbone, stride, norm_layer)
        self.decoder = build_decoder(21, backbone, norm_layer)

        if cp_path is not None:
            # SECURITY: torch.load unpickles the file and can execute
            # arbitrary code; only pass checkpoints from trusted sources.
            checkpoint = torch.load(cp_path, map_location='cpu')
            self.load_state_dict(checkpoint['state_dict'])

        if freeze_bn:
            self.freeze_bn()

        # Two output classes: background / foreground.
        self.decoder.last_conv[-1] = nn.Conv2d(256, 2, kernel_size=1, stride=1)
コード例 #21
0
 def test_vgg(self):
     """Smoke-test building the VGG backbone and the FCN model from ``cfg``.

     The original ended in ``IPython.embed()``, which opened an interactive
     shell and blocked non-interactive test runs; it has been removed.
     """
     # Exercise build_backbone on its own as well; the result is not needed.
     build_backbone(cfg)
     model = build_fcn_model(cfg)
     print(model.backbone.conv1_1.weight[0, 0, 0, 0])
コード例 #22
0
    def __init__(self, cfg):
        """Build the detector's components from ``cfg``.

        Components are created in order: backbone, RPN, ROI heads, and a
        mimicking head that receives the same backbone and ROI-head objects.
        """
        super().__init__()

        backbone = build_backbone(cfg)
        self.backbone = backbone
        self.rpn = build_rpn(cfg)
        roi_heads = build_roi_heads(cfg)
        self.roi_heads = roi_heads
        self.mimicking_head = Mimicking_head(cfg, backbone, roi_heads)
コード例 #23
0
 def __init__(self, backbone='resnet',output_stride=16, num_classes=1,nInputChannels=5,freeze_bn=False):
     """Build IOG: global/refine nets, backbone and a PSP module.

     Args:
         backbone: Backbone name forwarded to ``build_backbone``.
         output_stride: Backbone output stride.
         num_classes: Output classes of the global and refine nets.
         nInputChannels: Input channels forwarded to ``build_backbone``.
         freeze_bn: If truthy, ``self.freeze_bn()`` is called at the end.
     """
     super().__init__()
     out_size = 128
     stage_channels = [512, 1024, 512, 256]
     norm_layer = nn.BatchNorm2d
     self.global_net = globalNet(stage_channels, out_size, num_classes)
     self.refine_net = refineNet(stage_channels[-1], out_size, num_classes)
     self.backbone = build_backbone(backbone, output_stride, norm_layer, nInputChannels)
     self.psp4 = PSPModule(in_features=2048, out_features=512,
                           sizes=(1, 2, 3, 6), n_classes=256)
     if freeze_bn:
         self.freeze_bn()
コード例 #24
0
    def __init__(self, backbone: str, BatchNorm, in_channels=None,
                 pretrained=True, upsample_ratio=8, n_class=21):
        """Build an FCN-8s segmentation model on top of a named backbone.

        Args:
            backbone: Backbone name passed to ``build_backbone``.
            BatchNorm: Normalization layer class for backbone and head.
            in_channels: Per-stage channel widths for the FCN8 head; defaults
                to ``[512, 512, 512, 512]``. Always replaced by
                ``[320, 32, 24, 16]`` when ``backbone == 'mobilenet'``.
            pretrained: Not used by this constructor.
                NOTE(review): presumably meant to be forwarded to
                build_backbone — confirm.
            upsample_ratio: Only 8 (FCN-8s) is implemented.
            n_class: Number of output classes.

        Raises:
            NotImplementedError: If ``upsample_ratio`` is not 8.
        """
        super(FCN, self).__init__()

        # Validate before building the (potentially expensive) backbone.
        if upsample_ratio != 8:
            raise NotImplementedError(
                "upsample_ratio={} is not implemented; only 8 is supported"
                .format(upsample_ratio))

        backbone_model = build_backbone(backbone=backbone, BatchNorm=BatchNorm, output_stride=16)

        if backbone == 'mobilenet':
            in_channels = [320, 32, 24, 16]
        elif in_channels is None:
            # A fresh list per instance; a list default would be shared by
            # every FCN constructed without an explicit in_channels.
            in_channels = [512, 512, 512, 512]
        self.model = FCN8(backbone=backbone_model, BatchNorm=BatchNorm,
                          in_channels=in_channels, num_classes=n_class)
コード例 #25
0
ファイル: deeplab.py プロジェクト: yimengli46/Deeplabv3
    def __init__(self,
                 backbone='resnet',
                 output_stride=16,
                 num_classes=21,
                 freeze_bn=False):
        """Build DeepLab (backbone, ASPP, decoder) using plain ``nn.BatchNorm2d``.

        Args:
            backbone: Backbone name forwarded to the builders.
            output_stride: Backbone output stride (no backbone override).
            num_classes: Number of segmentation classes.
            freeze_bn: Stored as ``self.freeze_bn``; nothing is frozen here.
        """
        super().__init__()

        norm_layer = nn.BatchNorm2d
        self.backbone = build_backbone(backbone, output_stride, norm_layer)
        self.aspp = build_aspp(backbone, output_stride, norm_layer)
        self.decoder = build_decoder(num_classes, backbone, norm_layer)

        # NOTE(review): stored as a plain flag. If the class also defines a
        # freeze_bn() method, this instance attribute shadows it — confirm.
        self.freeze_bn = freeze_bn
コード例 #26
0
    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        """Build the model's backbone, ASPP and decoder.

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of segmentation classes.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: Stored as ``self.freeze_bn``; nothing is frozen here.
        """
        super().__init__()
        stride = 8 if backbone == 'drn' else output_stride
        norm_layer = SynchronizedBatchNorm2d if sync_bn == True else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, stride, norm_layer)
        self.aspp = build_aspp(backbone, stride, norm_layer)
        self.decoder = build_decoder(num_classes, backbone, norm_layer)

        # NOTE(review): stored as a plain flag. If the class also defines a
        # freeze_bn() method, this instance attribute shadows it — confirm.
        self.freeze_bn = freeze_bn
コード例 #27
0
    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        """Build DeepLab: backbone, ASPP module and decoder.

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of segmentation classes.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.
        """
        super().__init__()
        stride = output_stride
        if backbone == 'drn':
            stride = 8

        norm_layer = nn.BatchNorm2d
        if sync_bn == True:
            norm_layer = SynchronizedBatchNorm2d

        self.backbone = build_backbone(backbone, stride, norm_layer)
        self.aspp = build_aspp(backbone, stride, norm_layer)
        self.decoder = build_decoder(num_classes, backbone, norm_layer)

        if freeze_bn:
            self.freeze_bn()
コード例 #28
0
    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        """Assemble DeepLab from its backbone, ASPP and decoder builders.

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            num_classes: Number of segmentation classes.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.
        """
        super().__init__()
        stride = 8 if backbone == 'drn' else output_stride
        norm_layer = SynchronizedBatchNorm2d if sync_bn == True else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, stride, norm_layer)
        self.aspp = build_aspp(backbone, stride, norm_layer)
        self.decoder = build_decoder(num_classes, backbone, norm_layer)

        if freeze_bn:
            self.freeze_bn()
コード例 #29
0
def build_skmtnet(backbone: str,
                  auxiliary_head,
                  trunk_head,
                  num_classes,
                  output_stride=32,
                  sync_bn=False):
    """Assemble a SkmtNet from named components.

    :param backbone: name of the backbone; if falsy, no backbone is built and
        ``None`` is passed to SkmtNet (the name is still forwarded to the
        head builders).
    :param auxiliary_head: name of the auxiliary head; if falsy, none is built.
    :param trunk_head: name of the trunk head; always built.
    :param num_classes: number of output classes, forwarded to every builder
        and to SkmtNet.
    :param output_stride: output stride forwarded to the builders.
    :param sync_bn: if truthy use SynchronizedBatchNorm2d, else nn.BatchNorm2d.
    :return: the assembled SkmtNet instance.
    """
    # Choose the normalization layer.
    norm_layer = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d

    # Backbone (optional).
    backbone_model = (
        build_backbone(backbone, output_stride, norm_layer, num_classes)
        if backbone else None
    )

    # Auxiliary head (optional).
    auxiliary_head_model = (
        build_auxiliary_head(auxiliary_head, backbone, norm_layer,
                             output_stride, num_classes)
        if auxiliary_head else None
    )

    # Trunk head (always built).
    trunk_head_model = build_head(trunk_head,
                                  backbone,
                                  norm_layer,
                                  output_stride=output_stride,
                                  num_classes=num_classes)

    # Combine the parts into the final model.
    return SkmtNet(backbone_model, auxiliary_head_model, trunk_head_model,
                   num_classes)
コード例 #30
0
    def __init__(self, backbone='resnet', output_stride=16, num_classes=21):
        """Build a backbone plus a multi-dilation prediction head.

        Always uses ``nn.BatchNorm2d`` (synchronized BN is not supported).

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
                Also selects the head's input width: 512 for ``'drn'``, 320
                for ``'mobilenet'``, 2048 otherwise.
            output_stride: Requested backbone output stride.
            num_classes: Number of classes predicted by the head.
        """
        super(DeepLab, self).__init__()
        # Assign attributes only after nn.Module.__init__ has run: doing it
        # earlier happens to work for a plain int but fails for modules or
        # parameters, so it breaks easily when edited.
        self.inplanes = 64
        if backbone == 'drn':
            output_stride = 8
        # Input channels of the prediction head (presumably the backbone's
        # final feature width — confirm against build_backbone).
        head_in_channels = {'drn': 512, 'mobilenet': 320}.get(backbone, 2048)

        BatchNorm = nn.BatchNorm2d
        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.head = self._make_pred_layer(Classifier_Module, head_in_channels,
                                          [6, 12, 18, 24], [6, 12, 18, 24],
                                          num_classes)
コード例 #31
0
    def __init__(self,
                 backbone='resnet',
                 output_stride=16,
                 sync_bn=True,
                 freeze_bn=False,
                 num_classes=5):
        """Build DeepLab (backbone, ASPP, decoder).

        Args:
            backbone: Backbone name; ``'drn'`` forces an output stride of 8.
            output_stride: Requested backbone output stride.
            sync_bn: Equal to ``True`` selects ``SynchronizedBatchNorm2d``,
                otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called afterwards.
            num_classes: Number of decoder output classes. Defaults to 5, the
                value previously hard-coded; added last so existing
                positional calls keep working.
        """
        super().__init__()
        if backbone == 'drn':
            output_stride = 8

        if sync_bn == True:
            BatchNorm = SynchronizedBatchNorm2d
        else:
            BatchNorm = nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        if freeze_bn:
            self.freeze_bn()