Example #1
import torch.nn as nn
import torch.nn.functional as F

# BackboneSelector and PPMBilinearDeepsup are project-specific modules assumed to be
# importable from the surrounding repository.
class PSPNet(nn.Module):
    def __init__(self, configer):
        super(PSPNet, self).__init__()
        self.configer = configer
        self.num_classes = self.configer.get('data', 'num_classes')
        self.backbone = BackboneSelector(configer).get_backbone()

        num_features = self.backbone.get_num_features()

        self.low_features = nn.Sequential(
            self.backbone.conv1, self.backbone.bn1, self.backbone.relu,
            self.backbone.maxpool,
            self.backbone.layer1,
        )

        self.high_features1 = nn.Sequential(self.backbone.layer2, self.backbone.layer3)
        self.high_features2 = nn.Sequential(self.backbone.layer4)
        self.decoder = PPMBilinearDeepsup(num_class=self.num_classes, fc_dim=num_features,
                                          bn_type=self.configer.get('network', 'bn_type'))

    def forward(self, x_):
        low = self.low_features(x_)          # stem + layer1
        aux = self.high_features1(low)       # layer2 + layer3 (input to the auxiliary branch)
        x = self.high_features2(aux)         # layer4
        x, aux = self.decoder([x, aux])      # pyramid pooling decoder with deep supervision
        x = F.interpolate(x, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=False)

        return x, aux
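
The decoder above is the project's PPMBilinearDeepsup module. As a rough, self-contained illustration of the pyramid-pooling idea behind it (not the repository's actual implementation), the sketch below pools the feature map over several bin sizes, projects each pooled map with a 1x1 convolution, upsamples it back to the input resolution, and concatenates everything with the original features. The class name ToyPPM, the bin sizes, and the channel counts are assumptions chosen for the example.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyPPM(nn.Module):
    """Illustrative pyramid pooling module (names and sizes are assumptions)."""
    def __init__(self, in_channels=2048, branch_channels=512, bins=(1, 2, 3, 6)):
        super(ToyPPM, self).__init__()
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(b),                                  # pool to a b x b grid
                nn.Conv2d(in_channels, branch_channels, 1, bias=False),   # project to 512 channels
                nn.BatchNorm2d(branch_channels),
                nn.ReLU(inplace=True),
            ) for b in bins
        ])

    def forward(self, x):
        h, w = x.size(2), x.size(3)
        pooled = [F.interpolate(branch(x), size=(h, w), mode='bilinear',
                                align_corners=False) for branch in self.branches]
        # concatenate the original features with the upsampled pooled branches
        return torch.cat([x] + pooled, dim=1)

feats = torch.randn(1, 2048, 32, 32)
ppm = ToyPPM().eval()                  # eval mode so BatchNorm accepts a single sample
with torch.no_grad():
    out = ppm(feats)
print(out.shape)                       # torch.Size([1, 2048 + 4 * 512, 32, 32])
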
Example #2
import torch.nn as nn
import torch.nn.functional as F

class PSPNet(nn.Module):
    def __init__(self, configer):
        super(PSPNet, self).__init__()
        self.configer = configer
        self.num_classes = self.configer.get('data', 'num_classes')
        self.backbone = BackboneSelector(configer).get_backbone()
        num_features = self.backbone.get_num_features()
        self.dsn = nn.Sequential(
            _ConvBatchNormReluBlock(num_features // 2, num_features // 4, 3, 1,
                                    bn_type=self.configer.get('network', 'bn_type')),
            nn.Dropout2d(0.1),
            nn.Conv2d(num_features // 4, self.num_classes, 1, 1, 0)
        )
        self.ppm = PPMBilinearDeepsup(fc_dim=num_features, bn_type=self.configer.get('network', 'bn_type'))

        self.cls = nn.Sequential(
            nn.Conv2d(num_features + 4 * 512, 512, kernel_size=3, padding=1, bias=False),
            ModuleHelper.BNReLU(512, bn_type=self.configer.get('network', 'bn_type')),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, self.num_classes, kernel_size=1)
        )

    def forward(self, x_):
        x = self.backbone(x_)       # backbone returns a list of stage feature maps
        aux_x = self.dsn(x[-2])     # auxiliary head on the penultimate stage
        x = self.ppm(x[-1])         # pyramid pooling on the final stage
        x = self.cls(x)
        aux_x = F.interpolate(aux_x, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True)
        x = F.interpolate(x, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True)
        return aux_x, x
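
For a concrete sense of what the dsn auxiliary head above computes, here is a self-contained stand-in, assuming _ConvBatchNormReluBlock is simply a 3x3 convolution followed by BatchNorm and ReLU. The values num_features = 2048 (a ResNet-style backbone) and num_classes = 19 are assumptions for the example, and toy_dsn is a hypothetical name.

import torch
import torch.nn as nn
import torch.nn.functional as F

num_features, num_classes = 2048, 19
toy_dsn = nn.Sequential(
    nn.Conv2d(num_features // 2, num_features // 4, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(num_features // 4),
    nn.ReLU(inplace=True),
    nn.Dropout2d(0.1),
    nn.Conv2d(num_features // 4, num_classes, kernel_size=1),    # per-pixel class logits
)

stage3 = torch.randn(1, num_features // 2, 64, 64)               # penultimate backbone stage
aux_logits = toy_dsn(stage3)
aux_logits = F.interpolate(aux_logits, size=(512, 512), mode='bilinear', align_corners=True)
print(aux_logits.shape)                                          # torch.Size([1, 19, 512, 512])
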
Example #3
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseASPP(nn.Module):
    """
    * output_scale can only be set to 8 or 16
    """
    def __init__(self, configer):
        super(DenseASPP, self).__init__()
        self.configer = configer
        dropout0 = 0.1
        dropout1 = 0.1

        self.backbone = BackboneSelector(configer).get_backbone()

        num_features = self.backbone.get_num_features()

        # Transition block halving the backbone channels (defined here but not
        # applied in forward() below).
        self.trans = _Transition(num_input_features=num_features,
                                 num_output_features=num_features // 2,
                                 bn_type=self.configer.get(
                                     'network', 'bn_type'))

        self.num_features = num_features // 2

        self.ASPP_3 = _DenseAsppBlock(input_num=num_features,
                                      num1=256,
                                      num2=64,
                                      dilation_rate=3,
                                      drop_out=dropout0,
                                      bn_type=self.configer.get(
                                          'network', 'bn_type'))

        self.ASPP_6 = _DenseAsppBlock(input_num=num_features + 64 * 1,
                                      num1=256,
                                      num2=64,
                                      dilation_rate=6,
                                      drop_out=dropout0,
                                      bn_type=self.configer.get(
                                          'network', 'bn_type'))

        self.ASPP_12 = _DenseAsppBlock(input_num=num_features + 64 * 2,
                                       num1=256,
                                       num2=64,
                                       dilation_rate=12,
                                       drop_out=dropout0,
                                       bn_type=self.configer.get(
                                           'network', 'bn_type'))

        self.ASPP_18 = _DenseAsppBlock(input_num=num_features + 64 * 3,
                                       num1=256,
                                       num2=64,
                                       dilation_rate=18,
                                       drop_out=dropout0,
                                       bn_type=self.configer.get(
                                           'network', 'bn_type'))

        self.ASPP_24 = _DenseAsppBlock(input_num=num_features + 64 * 4,
                                       num1=256,
                                       num2=64,
                                       dilation_rate=24,
                                       drop_out=dropout0,
                                       bn_type=self.configer.get(
                                           'network', 'bn_type'))

        num_features = num_features + 5 * 64

        self.classification = nn.Sequential(
            nn.Dropout2d(p=dropout1),
            nn.Conv2d(in_channels=num_features,
                      out_channels=self.configer.get('network',
                                                     'out_channels'),
                      kernel_size=1,
                      padding=0))

        # Xavier-initialize all convolution weights; ModuleHelper.BatchNorm2d(bn_type=...)
        # is assumed to return the BatchNorm class selected by the configuration.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight.data)

            elif isinstance(
                    m,
                    ModuleHelper.BatchNorm2d(
                        bn_type=self.configer.get('network', 'bn_type'))):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        feature = self.backbone(x)

        # Each atrous block adds 64 channels that are concatenated back onto its input,
        # so later blocks see all earlier outputs (dense connectivity).
        aspp3 = self.ASPP_3(feature)
        feature = torch.cat((aspp3, feature), dim=1)

        aspp6 = self.ASPP_6(feature)
        feature = torch.cat((aspp6, feature), dim=1)

        aspp12 = self.ASPP_12(feature)
        feature = torch.cat((aspp12, feature), dim=1)

        aspp18 = self.ASPP_18(feature)
        feature = torch.cat((aspp18, feature), dim=1)

        aspp24 = self.ASPP_24(feature)
        feature = torch.cat((aspp24, feature), dim=1)

        cls = self.classification(feature)

        # Upsample the logits back to the input resolution; assumes an output stride of 8.
        out = F.interpolate(cls,
                            scale_factor=8,
                            mode='bilinear',
                            align_corners=True)

        return out
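
The input_num arguments above grow by 64 per block because every _DenseAsppBlock output is concatenated back onto its input, so each block sees the backbone features plus all earlier ASPP outputs. The self-contained sketch below reproduces just that channel arithmetic with a simplified block, assumed to be a 1x1 convolution to num1 channels followed by a dilated 3x3 convolution to num2 channels (num1 = 256, num2 = 64 as in the example); toy_aspp_block and the backbone channel count of 1024 are assumptions for illustration.

import torch
import torch.nn as nn

def toy_aspp_block(in_ch, mid_ch=256, out_ch=64, dilation=3):
    # 1x1 bottleneck followed by a dilated 3x3 conv; padding=dilation keeps the spatial size
    return nn.Sequential(
        nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False),
        nn.BatchNorm2d(mid_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=dilation, dilation=dilation, bias=False),
    )

num_features = 1024                            # assumed backbone channel count
feature = torch.randn(1, num_features, 64, 64)
for k, rate in enumerate((3, 6, 12, 18, 24)):
    block = toy_aspp_block(num_features + 64 * k, dilation=rate)
    out = block(feature)                       # 64 new channels at this dilation rate
    feature = torch.cat((out, feature), dim=1)
print(feature.shape)                           # torch.Size([1, 1024 + 5 * 64, 64, 64])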