예제 #1
0
    def __init__(self,
                 backbone,
                 BatchNorm,
                 output_stride,
                 num_classes,
                 freeze_bn=False):
        """Build a DeepLab-style model with five parallel decoders and a DANet head.

        Args:
            backbone: Forwarded to ``build_aspp``, ``build_decoder`` and
                ``get_inchannels``; stored unchanged as ``self.backbone``.
            BatchNorm: Normalization layer class passed to every builder and
                to ``DANetHead``.
            output_stride: Passed to ``build_aspp``; stored as
                ``self.output_stride``.
            num_classes: Number of output classes for each decoder and the head.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called once all
                submodules exist.
        """
        super(SplitDeepLabDANet, self).__init__()
        self.backbone = backbone
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)

        # Five independently built decoders, constructed in order 1..5.
        self.decoder1 = build_decoder(num_classes, backbone, BatchNorm)
        self.decoder2 = build_decoder(num_classes, backbone, BatchNorm)
        self.decoder3 = build_decoder(num_classes, backbone, BatchNorm)
        self.decoder4 = build_decoder(num_classes, backbone, BatchNorm)
        self.decoder5 = build_decoder(num_classes, backbone, BatchNorm)
        # Plain list kept only for convenient iteration over the decoders.
        # NOTE(review): assuming this class is an nn.Module, the decoders are
        # registered through the decoderN attributes above; wrapping this list
        # in an nn.ModuleList would register them a second time under new
        # state_dict keys, so it is deliberately left as a list.
        self.decoders = [
            self.decoder1, self.decoder2, self.decoder3, self.decoder4,
            self.decoder5
        ]
        self.output_stride = output_stride

        # Only the first entry returned by get_inchannels is used as the
        # head's input width.
        in_channels = get_inchannels(backbone)
        self.head = DANetHead(in_channels[0], num_classes, BatchNorm)

        if freeze_bn:
            self.freeze_bn()
예제 #2
0
 def __init__(self, backbone, BatchNorm, output_stride, num_classes,
              freeze_bn=False):
     """Attach an ASPP module and a decoder to the supplied backbone.

     Args:
         backbone: Stored as ``self.backbone`` and forwarded unchanged to
             ``build_aspp`` and ``build_decoder``.
         BatchNorm: Normalization layer class handed to both builders.
         output_stride: Passed to ``build_aspp``; stored as
             ``self.output_stride``.
         num_classes: Number of output classes for the decoder.
         freeze_bn: When truthy, ``self.freeze_bn()`` runs after construction.
     """
     super(DeepLab, self).__init__()
     self.backbone = backbone
     # ASPP is built before the decoder; keep this order so submodule
     # registration (and any init-time RNG use) matches the original.
     self.aspp = build_aspp(backbone, output_stride, BatchNorm)
     self.decoder = build_decoder(num_classes, backbone, BatchNorm)
     self.output_stride = output_stride
     if not freeze_bn:
         return
     self.freeze_bn()
예제 #3
0
    def __init__(self,
                 backbone='resnet',
                 output_stride=16,
                 num_classes=21,
                 sync_bn=True,
                 freeze_bn=False):
        """Build a DeepLab model (backbone, ASPP, decoder) from a backbone name.

        Args:
            backbone: Backbone identifier passed to ``build_backbone``,
                ``build_aspp`` and ``build_decoder``.
            output_stride: Passed to ``build_backbone`` and ``build_aspp``.
            num_classes: Number of output classes for the decoder.
            sync_bn: If truthy, ``SynchronizedBatchNorm2d`` is used as the
                normalization layer; otherwise ``nn.BatchNorm2d``.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called after all
                submodules are built.
        """
        super(DeepLab, self).__init__()

        # Truthiness test rather than ``== True`` (PEP 8 / E712), so any
        # truthy flag selects synchronized batch norm.
        BatchNorm = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d

        # self.backbone holds the built network, while the ASPP and decoder
        # builders still receive the backbone *name*.
        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        if freeze_bn:
            self.freeze_bn()
예제 #4
0
    def __init__(self,
                 backbone,
                 BatchNorm,
                 output_stride,
                 num_classes,
                 freeze_bn=False):
        """Build the DRAN segmentation model for a named ResNet backbone.

        Args:
            backbone: Backbone name; only ``"resnet50"`` and ``"resnet101"``
                are accepted. Also forwarded to ``build_aspp`` and
                ``build_decoder`` and stored as ``self.backbone``.
            BatchNorm: Normalization layer class passed to the builders and to
                ``DranHead``.
            output_stride: Passed to ``build_aspp``; stored as
                ``self.output_stride``.
            num_classes: Number of output classes for the head and classifier.
            freeze_bn: If truthy, ``self.freeze_bn()`` is called after all
                submodules are built.

        Raises:
            NotImplementedError: If ``backbone`` is not a supported name.
        """
        super(DeepDran, self).__init__()
        # Validate before building anything so an unsupported backbone fails
        # fast instead of constructing (and discarding) ASPP and decoder.
        if backbone not in ("resnet50", "resnet101"):
            raise NotImplementedError(
                "DeepDran does not support backbone {!r}; expected 'resnet50' "
                "or 'resnet101'".format(backbone))
        # Input widths for DranHead and for the final 1x1 classifier.
        # NOTE(review): presumably the ResNet last-stage width (2048) and the
        # head's output feature width (256) -- confirm against DranHead.
        in_channels = 2048
        in_channels_seg = 256

        self.backbone = backbone
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)
        self.head = DranHead(in_channels, num_classes, BatchNorm)
        # Channel dropout followed by a 1x1 conv mapping features to classes.
        self.cls_seg = nn.Sequential(
            nn.Dropout2d(0.1, inplace=False),
            nn.Conv2d(in_channels_seg, num_classes, kernel_size=1))
        self.output_stride = output_stride
        if freeze_bn:
            self.freeze_bn()