Ejemplo n.º 1
0
    def __init__(self, num_classes, in_channels=3, backbone='resnet152', pretrained=True, use_aux=True, freeze_bn=False, freeze_backbone=False):
        """Build a PSPNet segmentation model on top of a ResNet-style backbone.

        Args:
            num_classes: number of output channels of both prediction heads.
            in_channels: channels of the input image. When it is not 3, the
                first module of the stem is replaced by a freshly created
                ``nn.Conv2d(in_channels, 64, 7, stride=2, padding=3)``.
            backbone: name of a constructor looked up on the ``resnet`` module.
            pretrained: forwarded positionally to that constructor.
            use_aux: stored as ``self.use_aux``. The auxiliary branch is built
                regardless of this flag.
            freeze_bn: if true, ``self.freeze_bn()`` is called at the end.
            freeze_backbone: if true, the stem and ``layer1``-``layer4`` are
                made non-trainable through ``set_trainable``.
        """
        super(PSPNet, self).__init__()
        # TODO: Use synch batchnorm
        norm_layer = nn.BatchNorm2d
        model = getattr(resnet, backbone)(pretrained, norm_layer=norm_layer)
        # Width of the features entering the backbone's classifier (layer4 output).
        m_out_sz = model.fc.in_features
        self.use_aux = use_aux

        # Stem: the first four children of the backbone. Build it as a plain
        # list so the optional replacement happens before wrapping it once.
        stem = list(model.children())[:4]
        if in_channels != 3:
            # The backbone's first conv expects 3 input channels; swap in a new one.
            stem[0] = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.initial = nn.Sequential(*stem)

        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4

        # Main head: pyramid pooling over the layer4 features, then a 1x1
        # classifier. The classifier takes m_out_sz // 4 channels, which is
        # presumably the _PSPModule output width (TODO confirm).
        self.master_branch = nn.Sequential(
            _PSPModule(m_out_sz, bin_sizes=[1, 2, 3, 6], norm_layer=norm_layer),
            nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1)
        )

        # Auxiliary head. It takes m_out_sz // 2 input channels, presumably the
        # layer3 output width of bottleneck ResNets; forward() is not visible here.
        self.auxiliary_branch = nn.Sequential(
            nn.Conv2d(m_out_sz//2, m_out_sz//4, kernel_size=3, padding=1, bias=False),
            norm_layer(m_out_sz//4),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1)
        )

        # Only the newly created heads are passed to initialize_weights; the
        # backbone modules are left as constructed.
        initialize_weights(self.master_branch, self.auxiliary_branch)
        if freeze_bn:
            self.freeze_bn()
        if freeze_backbone:
            set_trainable([self.initial, self.layer1, self.layer2, self.layer3, self.layer4], False)
Ejemplo n.º 2
0
    def __init__(self,
                 num_classes,
                 in_channels=3,
                 backbone='xception',
                 pretrained=True,
                 output_stride=16,
                 freeze_bn=False,
                 freeze_backbone=False,
                 **_):
        """Build a DeepLab model with a ResNet or Xception-65 backbone.

        Args:
            num_classes: number of output channels of the decoder.
            in_channels: input channels. Only forwarded to the ResNet backbone.
            backbone: any name containing ``'resnet'`` selects ``ResNet``; any
                name containing ``'xception'`` selects ``xception_65``.
            pretrained: forwarded to the selected backbone.
            output_stride: forwarded to the backbone and to ``ASSP``.
            freeze_bn: if true, ``self.freeze_bn()`` is called at the end.
            freeze_backbone: if true, the backbone is made non-trainable.
            **_: extra keyword arguments are accepted and ignored.

        Raises:
            ValueError: if ``backbone`` names neither a ResNet nor an Xception.
        """
        super(DeepLab, self).__init__()
        # Explicit check: the expression `'xception' or 'resnet' in backbone`
        # is always truthy, so it never rejected anything. An unknown name
        # would silently fall through to the Xception branch below.
        if 'resnet' not in backbone and 'xception' not in backbone:
            raise ValueError(
                f"Unsupported backbone {backbone!r}: expected a name containing "
                "'resnet' or 'xception'")

        if 'resnet' in backbone:
            self.backbone = ResNet(in_channels=in_channels,
                                   output_stride=output_stride,
                                   pretrained=pretrained)
            low_level_channels = 256
        else:
            # NOTE(review): in_channels is not forwarded here, so a non-RGB
            # input presumably only works with the ResNet backbone — confirm.
            self.backbone = xception_65(
                output_stride=output_stride,
                pretrained=pretrained,
                global_pool=False,
                checkpoint_path='./pretrained/xception_65.pth')
            low_level_channels = 128

        self.ASSP = ASSP(in_channels=2048, output_stride=output_stride)
        self.decoder = Decoder(low_level_channels, num_classes)

        if freeze_bn:
            self.freeze_bn()
        if freeze_backbone:
            set_trainable([self.backbone], False)
Ejemplo n.º 3
0
    def __init__(
        self,
        num_classes,
        in_channels=3,
        backbone="xception",
        pretrained=True,
        output_stride=16,
        **kwargs,
    ):
        """Build a DeepLab model with a ResNet or Xception backbone.

        Args:
            num_classes: number of output channels of the decoder.
            in_channels: input channels. Only forwarded to the ResNet backbone.
            backbone: any name containing ``"resnet"`` selects ``ResNet``; any
                name containing ``"xception"`` selects ``Xception``.
            pretrained: forwarded to the selected backbone.
            output_stride: forwarded to the backbone and to ``ASSP``.
            **kwargs: optional ``freeze_bn`` and ``freeze_backbone`` flags,
                both defaulting to False. Other keys are ignored.

        Raises:
            ValueError: if ``backbone`` names neither a ResNet nor an Xception.
        """
        super(DeepLab, self).__init__()
        # Explicit check: the expression `"xception" or "resnet" in backbone`
        # is always truthy and would never reject an unsupported name.
        if "resnet" not in backbone and "xception" not in backbone:
            raise ValueError(
                f"Unsupported backbone {backbone!r}: expected a name containing "
                "'resnet' or 'xception'"
            )

        if "resnet" in backbone:
            self.backbone = ResNet(
                in_channels=in_channels,
                output_stride=output_stride,
                pretrained=pretrained,
            )
            low_level_channels = 256
        else:
            # NOTE(review): in_channels is not forwarded to Xception — confirm
            # whether non-RGB input is expected to work with this backbone.
            self.backbone = Xception(output_stride=output_stride,
                                     pretrained=pretrained)
            low_level_channels = 128

        self.ASSP = ASSP(in_channels=2048, output_stride=output_stride)
        self.decoder = Decoder(low_level_channels, num_classes)

        # unpack kwargs; missing flags default to False instead of raising KeyError
        freeze_bn = kwargs.get("freeze_bn", False)
        freeze_backbone = kwargs.get("freeze_backbone", False)
        if freeze_bn:
            self.freeze_bn()
        if freeze_backbone:
            set_trainable([self.backbone], False)
Ejemplo n.º 4
0
    def __init__(self,
                 num_classes,
                 in_channels=3,
                 backbone="resnet50",
                 pretrained=True,
                 freeze_bn=False,
                 freeze_backbone=False,
                 **kwargs):
        """Build a U-Net-style decoder on top of a ResNet encoder.

        Args:
            num_classes: number of output channels of the final 1x1 conv.
            in_channels: input channels. When not 3, the stem's first module is
                replaced by a fresh ``nn.Conv2d(in_channels, 64, 7, 2, 3)``.
            backbone: name of a constructor looked up on the ``resnet`` module.
            pretrained: forwarded positionally to that constructor. It also
                decides which modules are passed to ``initialize_weights``.
            freeze_bn: if true, ``self.freeze_bn()`` is called at the end.
            freeze_backbone: if true, the stem and ``layer1``-``layer4`` are
                made non-trainable via ``set_trainable``.
            **kwargs: accepted and ignored.
        """
        super(UNetResnet, self).__init__()
        model = getattr(resnet, backbone)(pretrained,
                                          norm_layer=nn.BatchNorm2d)

        self.initial = list(model.children())[:4]
        if in_channels != 3:
            self.initial[0] = nn.Conv2d(in_channels,
                                        64,
                                        kernel_size=7,
                                        stride=2,
                                        padding=3,
                                        bias=False)
        self.initial = nn.Sequential(*self.initial)

        # encoder
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4

        # decoder
        # NOTE(review): the widths below are hard-coded. 2048 and the
        # concatenated inputs (1152 = 128 + 1024, 608 = 96 + 512, 320 = 64 + 256)
        # presumably assume a bottleneck ResNet (resnet50/101/152) — confirm
        # against forward().
        self.conv1 = nn.Conv2d(2048, 192, kernel_size=3, stride=1, padding=1)
        self.upconv1 = nn.ConvTranspose2d(192, 128, 4, 2, 1, bias=False)

        self.conv2 = nn.Conv2d(1152, 128, kernel_size=3, stride=1, padding=1)
        self.upconv2 = nn.ConvTranspose2d(128, 96, 4, 2, 1, bias=False)

        self.conv3 = nn.Conv2d(608, 96, kernel_size=3, stride=1, padding=1)
        self.upconv3 = nn.ConvTranspose2d(96, 64, 4, 2, 1, bias=False)

        self.conv4 = nn.Conv2d(320, 64, kernel_size=3, stride=1, padding=1)
        self.upconv4 = nn.ConvTranspose2d(64, 48, 4, 2, 1, bias=False)

        self.conv5 = nn.Conv2d(48, 48, kernel_size=3, stride=1, padding=1)
        self.upconv5 = nn.ConvTranspose2d(48, 32, 4, 2, 1, bias=False)

        self.conv6 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(32, num_classes, kernel_size=1, bias=False)

        if pretrained:
            # Initialize only the freshly created modules, as the sibling
            # models do with their heads. Passing `self` would also reach the
            # pretrained encoder, presumably overwriting its loaded weights.
            fresh = [
                self.conv1, self.upconv1, self.conv2, self.upconv2,
                self.conv3, self.upconv3, self.conv4, self.upconv4,
                self.conv5, self.upconv5, self.conv6, self.conv7,
            ]
            if in_channels != 3:
                fresh.append(self.initial[0])  # replacement stem conv is new too
            initialize_weights(*fresh)
        else:
            # No pretrained weights to protect: initialize the whole model.
            initialize_weights(self)

        if freeze_bn:
            self.freeze_bn()
        if freeze_backbone:
            set_trainable(
                [
                    self.initial, self.layer1, self.layer2, self.layer3,
                    self.layer4
                ],
                False,
            )
Ejemplo n.º 5
0
    def __init__(self,
                 num_classes,
                 in_channels=3,
                 backbone='resnet152',
                 pretrained=True,
                 use_aux=True,
                 freeze_bn=False,
                 freeze_backbone=False):
        """Build a multi-level PSP model with DAUM/AUM modules on a ResNet.

        Args:
            num_classes: output channels of each pyramid-pooling head.
            in_channels: input channels. When not 3, the stem's first module is
                replaced by a fresh ``nn.Conv2d(in_channels, 64, 7, 2, 3)``.
                It is also passed to the DAUM/AUM constructors.
            backbone: name of a constructor looked up on ``models``. The last
                block of layer2-layer4 must expose a ``bn3`` attribute.
            pretrained: forwarded positionally to that constructor.
            use_aux: accepted for signature compatibility; unused here.
            freeze_bn: if true, ``self.freeze_bn()`` is called at the end.
            freeze_backbone: if true, the stem and ``layer1``-``layer4`` are
                made non-trainable via ``set_trainable``.
        """
        super().__init__()
        # TODO: Use synch batchnorm
        norm_layer = nn.BatchNorm2d
        # NOTE: the backbone comes from `models`, not the project's `resnet`
        # module, so norm_layer is only used by the PPM heads below.
        model = getattr(models, backbone)(pretrained)

        # Stem: the first four children of the backbone, wrapped once after
        # the optional first-conv replacement.
        stem = list(model.children())[:4]
        if in_channels != 3:
            stem[0] = nn.Conv2d(in_channels,
                                64,
                                kernel_size=7,
                                stride=2,
                                padding=3,
                                bias=False)
        self.initial = nn.Sequential(*stem)

        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4
        # Output widths of layer2-layer4, read from the last block's bn3.
        out_size2 = model.layer2[-1].bn3.num_features
        out_size3 = model.layer3[-1].bn3.num_features
        out_size4 = model.layer4[-1].bn3.num_features

        bin_sizes = [1, 2, 3, 6]
        # One pyramid-pooling head per level. Each 1x1 classifier takes
        # out_size // len(bin_sizes) channels, presumably the PPM output
        # width (TODO confirm).
        self.ppm4 = nn.Sequential(
            PPM(out_size4, bin_sizes, norm_layer=norm_layer),
            nn.Conv2d(out_size4 // len(bin_sizes), num_classes, kernel_size=1))
        self.ppm3 = nn.Sequential(
            PPM(out_size3, bin_sizes, norm_layer=norm_layer),
            nn.Conv2d(out_size3 // len(bin_sizes), num_classes, kernel_size=1))
        self.ppm2 = nn.Sequential(
            PPM(out_size2, bin_sizes, norm_layer=norm_layer),
            nn.Conv2d(out_size2 // len(bin_sizes), num_classes, kernel_size=1))

        self.daum = DAUM(in_channels, (9, 9))
        self.daum1 = AUM(in_channels, (5, 5))
        self.daum2 = AUM(in_channels, (5, 5))
        self.daum3 = AUM(in_channels, (5, 5))

        self.smoother = GaussianSmoother()

        # NOTE(review): self.daum (and the backbone) are not passed to
        # initialize_weights while daum1-daum3 are — confirm this is intended.
        initialize_weights(self.ppm4, self.ppm3, self.ppm2, self.daum1,
                           self.daum2, self.daum3)
        if freeze_bn:
            self.freeze_bn()
        if freeze_backbone:
            set_trainable([
                self.initial, self.layer1, self.layer2, self.layer3,
                self.layer4
            ], False)