Example #1
0
    def __init__(self, channel=32):
        """Build PolypPVT: a PVTv2 backbone plus fusion and attention heads.

        Args:
            channel: Width that each of the four backbone stages is projected
                to by the ``Translayer*`` 1x1 convolutions.
        """
        super(PolypPVT, self).__init__()

        # Backbone stage widths: [64, 128, 320, 512].
        self.backbone = pvt_v2_b2()
        # NOTE(review): the backbone is pvt_v2_b2 but the checkpoint file is
        # named pvt_v2_b3; only entries whose key also exists in the b2 state
        # dict are copied below. Confirm this partial load is intended.
        # Raw string: a backslash followed by 'p' is an invalid escape in a
        # plain literal (SyntaxWarning on 3.12+); the path value is unchanged.
        path = r'F:\pretrain\pvt_v2_b3.pth'
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
        # hosts; load_state_dict copies the values into the backbone's own
        # parameters regardless of where they were loaded.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        # Keep only checkpoint entries the backbone actually has.
        state_dict = {k: v for k, v in save_model.items() if k in model_dict}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # 1x1 projections of the four backbone stages to `channel` channels.
        self.Translayer2_0 = BasicConv2d(64, channel, 1)
        self.Translayer2_1 = BasicConv2d(128, channel, 1)
        self.Translayer3_1 = BasicConv2d(320, channel, 1)
        self.Translayer4_1 = BasicConv2d(512, channel, 1)

        self.CFM = CFM(channel)
        # NOTE(review): width is fixed at 64 rather than `channel`; presumably
        # applied to the 64-channel first backbone stage -- confirm.
        self.ca = ChannelAttention(64)
        self.sa = SpatialAttention()
        self.SAM = SAM()

        # Bilinear resize by a factor of 0.5 (i.e. a downsample).
        self.down05 = nn.Upsample(scale_factor=0.5,
                                  mode='bilinear',
                                  align_corners=True)
        # Single-channel 1x1 output heads.
        self.out_SAM = nn.Conv2d(channel, 1, 1)
        self.out_CFM = nn.Conv2d(channel, 1, 1)
Example #2
0
    def __init__(self, channel=32):
        """Build the network: PVTv2 backbone, decoder, fusion and attention modules.

        Args:
            channel: Width that each backbone stage is projected to by the
                ``Translayer*`` 1x1 convolutions; most later modules use it.
        """
        super(MyNet, self).__init__()

        # Backbone stage widths: [64, 128, 320, 512].
        self.backbone = pvt_v2_b2()
        # NOTE(review): the backbone is pvt_v2_b2 but the checkpoint file is
        # named pvt_v2_b3; only entries whose key also exists in the b2 state
        # dict are copied below. Confirm this partial load is intended.
        # Raw string: a backslash followed by 'p' is an invalid escape in a
        # plain literal (SyntaxWarning on 3.12+); the path value is unchanged.
        path = r'F:\pretrain\pvt_v2_b3.pth'
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
        # hosts; load_state_dict copies the values into the backbone anyway.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        # Keep only checkpoint entries the backbone actually has.
        state_dict = {k: v for k, v in save_model.items() if k in model_dict}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # 1x1 projections of the four backbone stages to `channel` channels.
        self.Translayer1 = BasicConv2d(64, channel, 1)
        self.Translayer2 = BasicConv2d(128, channel, 1)
        self.Translayer3 = BasicConv2d(320, channel, 1)
        self.Translayer4 = BasicConv2d(512, channel, 1)

        # 2x bilinear upsampling.
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        self.out1_1 = BasicConv2d(channel * 3, channel, 1)
        self.out1_2 = BasicConv2d(channel * 3, channel, 1)
        self.out1_3 = BasicConv2d(channel * 3, channel, 1)
        self.out1_4 = BasicConv2d(channel, channel, 1)

        # Decoder: decoder3..decoder1 take channel * 2 input channels.
        self.decoder4 = DecoderBlock(in_channels=channel, out_channels=channel)
        self.decoder3 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder2 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder1 = nn.Sequential(BasicConv2d(channel * 2, channel, 1),
                                      BasicConv2d(channel, channel, 1))

        # Adaptive fusion modules.
        self.asm3 = ASM(channel, channel * 3)
        self.asm2 = ASM(channel, channel * 3)
        self.asm1 = ASM(channel, channel * 3)

        # Single-channel 1x1 output heads.
        self.unetout1 = nn.Conv2d(channel, 1, 1)
        self.unetout2 = nn.Conv2d(channel, 1, 1)

        self.conv_concat2 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv_concat3 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv_concat4 = BasicConv2d(2 * channel, channel, 3, padding=1)  # at most 64*4 = 256 -- not large

        self.selayer = SELayer(channel)

        self.cat1 = CasAtt(channel)
        self.cat2 = CasAtt(channel)
        self.cat3 = CasAtt(channel)
        self.cat4 = CasAtt(channel)
        self.CFM = CFM(channel)
        self.con4 = BasicConv2d(channel, channel, 1)
        self.con3 = BasicConv2d(channel * 2, channel, 1)
        self.con2 = BasicConv2d(channel * 2, channel, 1)
        self.con1 = BasicConv2d(channel * 2, channel, 1)

        # Single-channel side outputs.
        self.sideout4 = BasicConv2d(channel, 1, 1)
        self.sideout3 = BasicConv2d(channel, 1, 1)
        self.sideout2 = BasicConv2d(channel, 1, 1)
        self.sideout1 = BasicConv2d(channel, 1, 1)
Example #3
0
    def __init__(self, channel=32):
        """Build BiDFNet: PVTv2 backbone, two decoder paths and fusion modules.

        Args:
            channel: Width that each backbone stage is projected to by the
                ``Translayer*`` 1x1 convolutions; all later modules use it.
        """
        super(BiDFNet, self).__init__()

        # Backbone stage widths: [64, 128, 320, 512].
        self.backbone = pvt_v2_b2()
        # NOTE(review): the backbone is pvt_v2_b2 but the checkpoint file is
        # named pvt_v2_b3; only entries whose key also exists in the b2 state
        # dict are copied below. Confirm this partial load is intended.
        # Raw string: a backslash followed by 'p' is an invalid escape in a
        # plain literal (SyntaxWarning on 3.12+); the path value is unchanged.
        path = r'F:\pretrain\pvt_v2_b3.pth'
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
        # hosts; load_state_dict copies the values into the backbone anyway.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        # Keep only checkpoint entries the backbone actually has.
        state_dict = {k: v for k, v in save_model.items() if k in model_dict}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # 1x1 projections of the four backbone stages to `channel` channels.
        self.Translayer1 = BasicConv2d(64, channel, 1)
        self.Translayer2 = BasicConv2d(128, channel, 1)
        self.Translayer3 = BasicConv2d(320, channel, 1)
        self.Translayer4 = BasicConv2d(512, channel, 1)

        # 2x bilinear upsampling.
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        # First decoder path (decoder4..decoder1).
        self.decoder4 = DecoderBlock(in_channels=channel, out_channels=channel)
        self.decoder3 = DecoderBlock(in_channels=channel, out_channels=channel)
        self.decoder2 = DecoderBlock(in_channels=channel, out_channels=channel)
        self.decoder1 = nn.Sequential(BasicConv2d(channel, channel, 1))
        # Second decoder path (decoder5..decoder8), single-conv blocks.
        self.decoder5 = nn.Sequential(BasicConv2d(channel * 2, channel, 1))
        self.decoder6 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel,
                                     doubleconv=False)
        self.decoder7 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel,
                                     doubleconv=False)
        self.decoder8 = DecoderBlock(in_channels=channel,
                                     out_channels=channel,
                                     doubleconv=False)

        # Adaptive fusion modules.
        self.afm3 = FSM(channel, channel)
        self.afm2 = FSM(channel, channel)
        self.afm1 = FSM(channel, channel)
        self.afm4 = FSM(channel, channel)

        self.rcm1 = RCM(channel, channel)
        self.rcm2 = RCM(channel, channel)
        self.rcm3 = RCM(channel, channel)
        self.rcm4 = RCM(channel, channel)

        self.fluse3 = CAB(channel)
        self.fluse2 = CAB(channel)
        self.fluse1 = CAB(channel)

        # Single-channel 1x1 output heads.
        self.unetout1 = nn.Conv2d(channel, 1, 1)
        self.unetout2 = nn.Conv2d(channel, 1, 1)
Example #4
0
    def __init__(self, channel=32):
        """Build the network: PVTv2 backbone, two decoders, ASM and attention modules.

        Args:
            channel: Width that each backbone stage is projected to by the
                ``Translayer*`` 1x1 convolutions; most later modules use it.
        """
        super(MyNet, self).__init__()

        # Backbone stage widths: [64, 128, 320, 512].
        self.backbone = pvt_v2_b2()
        # NOTE(review): the backbone is pvt_v2_b2 but the checkpoint file is
        # named pvt_v2_b3; only entries whose key also exists in the b2 state
        # dict are copied below. Confirm this partial load is intended.
        # Raw string: a backslash followed by 'p' is an invalid escape in a
        # plain literal (SyntaxWarning on 3.12+); the path value is unchanged.
        path = r'F:\pretrain\pvt_v2_b3.pth'
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
        # hosts; load_state_dict copies the values into the backbone anyway.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        # Keep only checkpoint entries the backbone actually has.
        state_dict = {k: v for k, v in save_model.items() if k in model_dict}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # 1x1 projections of the four backbone stages to `channel` channels.
        self.Translayer1 = BasicConv2d(64, channel, 1)
        self.Translayer2 = BasicConv2d(128, channel, 1)
        self.Translayer3 = BasicConv2d(320, channel, 1)
        self.Translayer4 = BasicConv2d(512, channel, 1)

        # 2x bilinear upsampling.
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        # Bilinear downsampling by 1/2, 1/4, 1/8 and 1/16.
        self.down01 = nn.Upsample(scale_factor=1 / 2,
                                  mode='bilinear',
                                  align_corners=True)
        self.down02 = nn.Upsample(scale_factor=1 / 4,
                                  mode='bilinear',
                                  align_corners=True)
        self.down03 = nn.Upsample(scale_factor=1 / 8,
                                  mode='bilinear',
                                  align_corners=True)
        self.down04 = nn.Upsample(scale_factor=1 / 16,
                                  mode='bilinear',
                                  align_corners=True)

        # Bilinear upsampling by 32x, 16x and 8x.
        self.upsample1 = nn.Upsample(scale_factor=32,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=16,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample3 = nn.Upsample(scale_factor=8,
                                     mode='bilinear',
                                     align_corners=True)

        self.out1_1 = BasicConv2d(channel * 3, channel, 1)
        self.out1_2 = BasicConv2d(channel * 3, channel, 1)
        self.out1_3 = BasicConv2d(channel * 3, channel, 1)
        self.out1_4 = BasicConv2d(channel, channel, 1)

        self.refineconv = BasicConv2d(3, 1, 1)
        self.refine = RefineUNet(1, 1)

        self.ca = ChannelAttention(channel)
        self.sa = SpatialAttention()
        self.outatte = nn.Conv2d(channel, channel, 1)

        # First decoder: decoder3..decoder1 take channel * 3 input channels.
        self.decoder4 = DecoderBlock(in_channels=channel, out_channels=channel)
        self.decoder3 = DecoderBlock(in_channels=channel * 3,
                                     out_channels=channel)
        self.decoder2 = DecoderBlock(in_channels=channel * 3,
                                     out_channels=channel)
        self.decoder1 = nn.Sequential(BasicConv2d(channel * 3, channel, 1),
                                      BasicConv2d(channel, channel, 1))

        # Second decoder: decoder2_3..decoder2_1 take channel * 2 inputs.
        self.decoder2_4 = DecoderBlock(in_channels=channel,
                                       out_channels=channel)
        self.decoder2_3 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=channel)
        self.decoder2_2 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=channel)
        self.decoder2_1 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=channel)

        # Adaptive selection modules.
        self.asm4 = ASM(channel, channel * 3)
        self.asm3 = ASM(channel, channel * 3)
        self.asm2 = ASM(channel, channel * 3)
        self.asm1 = ASM(channel, channel * 3)

        # Single-channel 1x1 output heads.
        self.unetout1 = nn.Conv2d(channel, 1, 1)
        self.unetout2 = nn.Conv2d(channel, 1, 1)

        self.COM = COM(channel)
        self.cobv1 = BasicConv2d(3 * channel, channel, 1)
        self.cobv2 = BasicConv2d(3 * channel, channel, 1)
        self.nocal = NonLocalBlock(channel)
        self.selayer = SELayer(channel)
Example #5
0
    def __init__(self, channel=64):
        """Build the network: PVTv2 backbone, RF/BCA modules and output heads.

        Args:
            channel: Width that each backbone stage is projected to by the
                ``Translayer*`` 1x1 convolutions; most later modules use it.
        """
        super(MyNet, self).__init__()

        # Backbone stage widths: [64, 128, 320, 512].
        self.backbone = pvt_v2_b2()
        # NOTE(review): the backbone is pvt_v2_b2 but the checkpoint file is
        # named pvt_v2_b3; only entries whose key also exists in the b2 state
        # dict are copied below. Confirm this partial load is intended.
        # Raw string: a backslash followed by 'p' is an invalid escape in a
        # plain literal (SyntaxWarning on 3.12+); the path value is unchanged.
        path = r'F:\pretrain\pvt_v2_b3.pth'
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
        # hosts; load_state_dict copies the values into the backbone anyway.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        # Keep only checkpoint entries the backbone actually has.
        state_dict = {k: v for k, v in save_model.items() if k in model_dict}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # 1x1 projections of the four backbone stages to `channel` channels.
        self.Translayer2_0 = BasicConv2d(64, channel, 1)
        self.Translayer2_1 = BasicConv2d(128, channel, 1)
        self.Translayer3_1 = BasicConv2d(320, channel, 1)
        self.Translayer4_1 = BasicConv2d(512, channel, 1)

        # 2x bilinear upsampling.
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        # Bilinear resize by a factor of 0.5 (i.e. a downsample).
        self.down05 = nn.Upsample(scale_factor=0.5,
                                  mode='bilinear',
                                  align_corners=True)
        # Single-channel 1x1 output heads. These were previously constructed
        # a second time after att3, which only discarded these instances;
        # the redundant re-creation has been removed.
        self.out_SAM = nn.Conv2d(channel, 1, 1)
        self.out_CFM = nn.Conv2d(channel, 1, 1)

        # Each RF takes the concatenated widths of backbone stages 2-4.
        self.rf1 = RF(512 + 320 + 128, channel)
        self.rf2 = RF(512 + 320 + 128, channel)
        self.rf3 = RF(512 + 320 + 128, channel)

        self.att1 = BCA(channel, channel, channel)
        self.att2 = BCA(channel, channel, channel)
        self.att3 = BCA(channel, channel, channel)

        # Bilinear downsampling by 1/8, 1/4 and 1/2 (down03 matches down05).
        self.down01 = nn.Upsample(scale_factor=1 / 8,
                                  mode='bilinear',
                                  align_corners=True)
        self.down02 = nn.Upsample(scale_factor=1 / 4,
                                  mode='bilinear',
                                  align_corners=True)
        self.down03 = nn.Upsample(scale_factor=1 / 2,
                                  mode='bilinear',
                                  align_corners=True)

        # Bilinear upsampling by 32x, 16x and 8x.
        self.upsample1 = nn.Upsample(scale_factor=32,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=16,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample3 = nn.Upsample(scale_factor=8,
                                     mode='bilinear',
                                     align_corners=True)

        self.out1 = BasicConv2d(channel, 1, 1)
        self.out2 = BasicConv2d(channel, 1, 1)
        self.out3 = BasicConv2d(channel, 1, 1)
        self.out4 = BasicConv2d(channel, 1, 1)
        self.refineconv = BasicConv2d(3, 1, 1)
        self.refine = RefineUNet(1, 1)

        self.ca = ChannelAttention(channel)
        self.sa = SpatialAttention()
        self.outatte = nn.Conv2d(channel, channel, 1)

        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2 * channel,
                                          2 * channel,
                                          3,
                                          padding=1)
Example #6
0
    def __init__(self, channel=64):
        """Build the network: PVTv2 backbone, decoders, ASM/ACFM/DGCM modules and heads.

        Args:
            channel: Width that each backbone stage is projected to by the
                ``Translayer*`` 1x1 convolutions; many later modules use it.
        """
        super(MyNet, self).__init__()

        # Backbone stage widths: [64, 128, 320, 512].
        self.backbone = pvt_v2_b2()
        # NOTE(review): the backbone is pvt_v2_b2 but the checkpoint file is
        # named pvt_v2_b3; only entries whose key also exists in the b2 state
        # dict are copied below. Confirm this partial load is intended.
        # Raw string: a backslash followed by 'p' is an invalid escape in a
        # plain literal (SyntaxWarning on 3.12+); the path value is unchanged.
        path = r'F:\pretrain\pvt_v2_b3.pth'
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
        # hosts; load_state_dict copies the values into the backbone anyway.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        # Keep only checkpoint entries the backbone actually has.
        state_dict = {k: v for k, v in save_model.items() if k in model_dict}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # 1x1 projections of the four backbone stages to `channel` channels.
        self.Translayer1 = BasicConv2d(64, channel, 1)
        self.Translayer2 = BasicConv2d(128, channel, 1)
        self.Translayer3 = BasicConv2d(320, channel, 1)
        self.Translayer4 = BasicConv2d(512, channel, 1)

        # 2x bilinear upsampling.
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        # Bilinear resize by a factor of 0.5 (same factor as down01).
        self.down05 = nn.Upsample(scale_factor=0.5,
                                  mode='bilinear',
                                  align_corners=True)

        # Bilinear downsampling by 1/2, 1/4, 1/8 and 1/16.
        self.down01 = nn.Upsample(scale_factor=1 / 2,
                                  mode='bilinear',
                                  align_corners=True)
        self.down02 = nn.Upsample(scale_factor=1 / 4,
                                  mode='bilinear',
                                  align_corners=True)
        self.down03 = nn.Upsample(scale_factor=1 / 8,
                                  mode='bilinear',
                                  align_corners=True)
        self.down04 = nn.Upsample(scale_factor=1 / 16,
                                  mode='bilinear',
                                  align_corners=True)

        # Bilinear upsampling by 32x, 16x and 8x.
        self.upsample1 = nn.Upsample(scale_factor=32,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=16,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample3 = nn.Upsample(scale_factor=8,
                                     mode='bilinear',
                                     align_corners=True)

        # 1x1 convolutions sized to the raw backbone stage widths.
        self.out1_1 = BasicConv2d(64, channel, 1)
        self.out1_2 = BasicConv2d(128, channel, 1)
        self.out1_3 = BasicConv2d(320, channel, 1)
        self.out1_4 = BasicConv2d(512, channel, 1)

        self.out2_1 = BasicConv2d(64, 1, 1)
        self.out2_2 = BasicConv2d(128, 1, 1)
        self.out2_3 = BasicConv2d(320, 1, 1)
        self.out2_4 = BasicConv2d(512, 1, 1)

        self.refineconv = BasicConv2d(3, 1, 1)
        self.refine = RefineUNet(1, 1)

        self.ca = ChannelAttention(channel)
        self.sa = SpatialAttention()
        self.outatte = nn.Conv2d(channel, channel, 1)

        # Decoder; output widths step back down through 320, 128, 64, 64.
        self.decoder2_4 = DecoderBlock(in_channels=channel, out_channels=320)
        self.decoder2_3 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=128)
        self.decoder2_2 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=64)
        self.decoder2_1 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=64)

        # Adaptive selection modules.
        self.asm4 = ASM(channel, 512)
        self.asm3 = ASM(channel, channel * 3)
        self.asm2 = ASM(channel, channel * 3)
        self.asm1 = ASM(channel, channel * 3)

        # Output head. unetout1 used to be assigned twice in a row; the first,
        # immediately-discarded module (64 -> 32 -> 1) has been removed.
        # NOTE(review): confirm the first assignment was not meant to be a
        # separate head (e.g. unetout2).
        self.unetout1 = nn.Sequential(
            BasicConv2d(channel, channel, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(channel, 1, 1))

        # Side outputs. NOTE(review): input width is fixed at 32 while
        # `channel` defaults to 64 here -- confirm what feeds these.
        self.sideout4 = SideoutBlock(32, 1)
        self.sideout3 = SideoutBlock(32, 1)
        self.sideout2 = SideoutBlock(32, 1)
        self.sideout1 = SideoutBlock(32, 1)

        self.decoder1 = Decoder()
        self.decoder2 = Decoder()

        self.acfm3 = ACFM()
        self.acfm2 = ACFM()

        self.dgcm3 = DGCM()
        self.dgcm2 = DGCM()
        self.upconv3 = BasicConv2d(64,
                                   64,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1,
                                   relu=True)

        self.classifier = nn.Conv2d(64, 1, 1)
        self.upconv2 = BasicConv2d(64,
                                   64,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1,
                                   relu=True)
Example #7
0
    def __init__(self, channel=64):
        """Build the Res2Net-based variant of the network and its heads.

        Every attribute below is registered in a fixed order; that order
        determines the state_dict key order and, under a fixed seed, which
        random initial weights each module receives.

        Args:
            channel: Width that the ``Translayer*`` 1x1 convolutions project
                the Res2Net features to; most later modules use it.
        """
        super(MyNet, self).__init__()

        # NOTE(review): a pvt_v2_b2 backbone is constructed but no pretrained
        # weights are loaded into it here, and the Translayer widths below
        # follow the Res2Net feature sizes -- confirm whether it is still used.
        self.backbone = pvt_v2_b2()  # [64, 128, 320, 512]
        # Presumably loads ImageNet-pretrained Res2Net-50 weights -- confirm.
        self.resnet = res2net50_v1b_26w_4s(pretrained=True)

        # 1x1 projections to `channel`; input widths 256/512/1024/2048,
        # presumably the four Res2Net stage outputs.
        self.Translayer1 = BasicConv2d(256, channel, 1)
        self.Translayer2 = BasicConv2d(512, channel, 1)
        self.Translayer3 = BasicConv2d(1024, channel, 1)
        self.Translayer4 = BasicConv2d(2048, channel, 1)

        # 2x bilinear upsampling.
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        # Bilinear resize by a factor of 0.5 (same factor as down01).
        self.down05 = nn.Upsample(scale_factor=0.5,
                                  mode='bilinear',
                                  align_corners=True)

        # Bilinear downsampling by 1/2, 1/4, 1/8 and 1/16.
        self.down01 = nn.Upsample(scale_factor=1 / 2,
                                  mode='bilinear',
                                  align_corners=True)
        self.down02 = nn.Upsample(scale_factor=1 / 4,
                                  mode='bilinear',
                                  align_corners=True)
        self.down03 = nn.Upsample(scale_factor=1 / 8,
                                  mode='bilinear',
                                  align_corners=True)
        self.down04 = nn.Upsample(scale_factor=1 / 16,
                                  mode='bilinear',
                                  align_corners=True)

        # Bilinear upsampling by 32x, 16x and 8x.
        self.upsample1 = nn.Upsample(scale_factor=32,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=16,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample3 = nn.Upsample(scale_factor=8,
                                     mode='bilinear',
                                     align_corners=True)

        self.out1_1 = BasicConv2d(channel * 3, channel, 1)
        self.out1_2 = BasicConv2d(channel * 3, channel, 1)
        self.out1_3 = BasicConv2d(channel * 3, channel, 1)
        self.out1_4 = BasicConv2d(channel * 3, channel, 1)

        # Single-channel 1x1 heads on raw feature widths.
        # NOTE(review): out2_4 takes 512 channels, breaking the
        # 256/512/1024/(2048) pattern of out2_1..out2_3 and the Translayers --
        # confirm this is the intended input width.
        self.out2_1 = BasicConv2d(256, 1, 1)
        self.out2_2 = BasicConv2d(512, 1, 1)
        self.out2_3 = BasicConv2d(1024, 1, 1)
        self.out2_4 = BasicConv2d(512, 1, 1)

        self.refineconv = BasicConv2d(3, 1, 1)
        self.refine = RefineUNet(1, 1)

        self.ca = ChannelAttention(channel)
        self.sa = SpatialAttention()
        self.outatte = nn.Conv2d(channel, channel, 1)

        # First decoder: decoder4..decoder2 take channel * 3 input channels,
        # decoder1 takes channel * 2.
        # self.decoder5 = DecoderBlock(in_channels=512, out_channels=512)
        self.decoder4 = DecoderBlock(in_channels=channel * 3,
                                     out_channels=channel)
        self.decoder3 = DecoderBlock(in_channels=channel * 3,
                                     out_channels=channel)
        self.decoder2 = DecoderBlock(in_channels=channel * 3,
                                     out_channels=channel)
        self.decoder1 = nn.Sequential(BasicConv2d(channel * 2, channel, 1),
                                      BasicConv2d(channel, channel, 1))

        # Second decoder: decoder2_3..decoder2_1 take channel * 2 inputs.
        # self.decoder5 = DecoderBlock(in_channels=512, out_channels=512)
        self.decoder2_4 = DecoderBlock(in_channels=channel,
                                       out_channels=channel)
        self.decoder2_3 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=channel)
        self.decoder2_2 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=channel)
        self.decoder2_1 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=channel)

        # adaptive selection module
        self.asm4 = ASM(channel, channel * 3)
        self.asm3 = ASM(channel, channel * 3)
        self.asm2 = ASM(channel, channel * 3)
        self.asm1 = ASM(channel, channel * 3)

        # Single-channel 1x1 output heads.
        self.unetout1 = nn.Conv2d(channel, 1, 1)
        self.unetout2 = nn.Conv2d(channel, 1, 1)

        self.COM = COM(channel)
        self.cobv1 = BasicConv2d(3 * channel, channel, 1)
        self.cobv2 = BasicConv2d(3 * channel, channel, 1)
        self.nocal = NonLocalBlock(channel)
Example #8
0
    def __init__(self, channel=32):
        """Build the network: PVTv2 backbone, decoder, ASM, SEB and RRB modules.

        Args:
            channel: Width that each backbone stage is projected to by the
                ``Translayer*`` 1x1 convolutions; most later modules use it.
        """
        super(MyNet, self).__init__()

        # Backbone stage widths: [64, 128, 320, 512].
        self.backbone = pvt_v2_b2()
        # NOTE(review): the backbone is pvt_v2_b2 but the checkpoint file is
        # named pvt_v2_b3; only entries whose key also exists in the b2 state
        # dict are copied below. Confirm this partial load is intended.
        # Raw string: a backslash followed by 'p' is an invalid escape in a
        # plain literal (SyntaxWarning on 3.12+); the path value is unchanged.
        path = r'F:\pretrain\pvt_v2_b3.pth'
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
        # hosts; load_state_dict copies the values into the backbone anyway.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        # Keep only checkpoint entries the backbone actually has.
        state_dict = {k: v for k, v in save_model.items() if k in model_dict}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # 1x1 projections of the four backbone stages to `channel` channels.
        self.Translayer1 = BasicConv2d(64, channel, 1)
        self.Translayer2 = BasicConv2d(128, channel, 1)
        self.Translayer3 = BasicConv2d(320, channel, 1)
        self.Translayer4 = BasicConv2d(512, channel, 1)

        # 2x bilinear upsampling.
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        # Bilinear resize by a factor of 0.5 (same factor as down01).
        self.down05 = nn.Upsample(scale_factor=0.5,
                                  mode='bilinear',
                                  align_corners=True)

        # Bilinear downsampling by 1/2, 1/4, 1/8 and 1/16.
        self.down01 = nn.Upsample(scale_factor=1 / 2,
                                  mode='bilinear',
                                  align_corners=True)
        self.down02 = nn.Upsample(scale_factor=1 / 4,
                                  mode='bilinear',
                                  align_corners=True)
        self.down03 = nn.Upsample(scale_factor=1 / 8,
                                  mode='bilinear',
                                  align_corners=True)
        self.down04 = nn.Upsample(scale_factor=1 / 16,
                                  mode='bilinear',
                                  align_corners=True)

        # Bilinear upsampling by 32x, 16x and 8x.
        self.upsample1 = nn.Upsample(scale_factor=32,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=16,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample3 = nn.Upsample(scale_factor=8,
                                     mode='bilinear',
                                     align_corners=True)

        # Decoder: decoder3..decoder1 take channel * 2 input channels.
        self.decoder4 = DecoderBlock(in_channels=channel, out_channels=channel)
        self.decoder3 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder2 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder1 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)

        # Adaptive selection modules.
        self.asm4 = ASM(channel, 512)
        self.asm3 = ASM(channel, channel * 3)
        self.asm2 = ASM(channel, channel * 3)
        self.asm1 = ASM(channel, channel * 3)

        # Output head. unetout1 used to be assigned twice in a row; the first,
        # immediately-discarded module (64 -> 32 -> 1, which did not even
        # match the default channel=32) has been removed.
        self.unetout1 = nn.Sequential(nn.Conv2d(channel, 1, 1))

        # Side outputs; input width is fixed at 32 (equal to the default
        # `channel`).
        self.sideout4 = SideoutBlock(32, 1)
        self.sideout3 = SideoutBlock(32, 1)
        self.sideout2 = SideoutBlock(32, 1)
        self.sideout1 = SideoutBlock(32, 1)
        self.seb1 = SEB1(channel, channel)
        self.seb2 = SEB2(channel, channel)
        self.seb3 = SEB3(channel, channel)
        self.rrb4 = RRB(channel, channel)
        self.rrb3 = RRB(channel, channel)
        self.rrb2 = RRB(channel, channel)
        self.rrb1 = RRB(channel, channel)
Example #9
0
    def __init__(self, channel=32):
        """Build the network: PVTv2 backbone, two decoders, encoders and attention modules.

        Args:
            channel: Width that each backbone stage is projected to by the
                ``Translayer*`` 1x1 convolutions; most later modules use it.
        """
        super(MyNet, self).__init__()

        # Backbone stage widths: [64, 128, 320, 512].
        self.backbone = pvt_v2_b2()
        # NOTE(review): the backbone is pvt_v2_b2 but the checkpoint file is
        # named pvt_v2_b3; only entries whose key also exists in the b2 state
        # dict are copied below. Confirm this partial load is intended.
        # Raw string: a backslash followed by 'p' is an invalid escape in a
        # plain literal (SyntaxWarning on 3.12+); the path value is unchanged.
        path = r'F:\pretrain\pvt_v2_b3.pth'
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
        # hosts; load_state_dict copies the values into the backbone anyway.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        # Keep only checkpoint entries the backbone actually has.
        state_dict = {k: v for k, v in save_model.items() if k in model_dict}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # Two sets of 1x1 projections of the four backbone stage widths to
        # `channel` channels.
        self.Translayer1 = BasicConv2d(64, channel, 1)
        self.Translayer2 = BasicConv2d(128, channel, 1)
        self.Translayer3 = BasicConv2d(320, channel, 1)
        self.Translayer4 = BasicConv2d(512, channel, 1)
        self.Translayer5 = BasicConv2d(64, channel, 1)
        self.Translayer6 = BasicConv2d(128, channel, 1)
        self.Translayer7 = BasicConv2d(320, channel, 1)
        self.Translayer8 = BasicConv2d(512, channel, 1)

        # 2x bilinear upsampling and 1/4 bilinear downsampling.
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)
        self.downsample = nn.Upsample(scale_factor=1 / 4,
                                      mode='bilinear',
                                      align_corners=True)

        # Bilinear downsampling by 1/2, 1/4, 1/8 and 1/16.
        self.down01 = nn.Upsample(scale_factor=1 / 2,
                                  mode='bilinear',
                                  align_corners=True)
        self.down02 = nn.Upsample(scale_factor=1 / 4,
                                  mode='bilinear',
                                  align_corners=True)
        self.down03 = nn.Upsample(scale_factor=1 / 8,
                                  mode='bilinear',
                                  align_corners=True)
        self.down04 = nn.Upsample(scale_factor=1 / 16,
                                  mode='bilinear',
                                  align_corners=True)

        # Bilinear upsampling by 32x, 16x and 8x.
        self.upsample1 = nn.Upsample(scale_factor=32,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=16,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample3 = nn.Upsample(scale_factor=8,
                                     mode='bilinear',
                                     align_corners=True)

        self.out1_1 = BasicConv2d(channel * 2, channel, 1)
        self.out1_2 = BasicConv2d(channel * 2, channel, 1)
        self.out1_3 = BasicConv2d(channel * 2, channel, 1)
        self.out1_4 = BasicConv2d(channel, channel, 1)

        self.refineconv = BasicConv2d(3, 1, 1)
        self.refine = RefineUNet(1, 1)

        self.ca1 = ChannelAttention(channel)
        self.sa4 = SpatialAttention()
        self.outatte = nn.Conv2d(channel, channel, 1)

        # First decoder: decoder3..decoder1 take channel * 2 input channels.
        self.decoder4 = DecoderBlock(in_channels=channel, out_channels=channel)
        self.decoder3 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder2 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder1 = nn.Sequential(BasicConv2d(channel * 2, channel, 1),
                                      BasicConv2d(channel, channel, 1))

        # Second decoder: all stages take channel * 2 input channels.
        self.decoder5 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder6 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder7 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder8 = nn.Sequential(BasicConv2d(channel * 2, channel, 1),
                                      BasicConv2d(channel, channel, 1))

        self.encoder2 = EncooderBlock(in_channels=channel * 2,
                                      out_channels=channel)
        self.encoder3 = EncooderBlock(in_channels=channel * 2,
                                      out_channels=channel)
        self.encoder4 = EncooderBlock(in_channels=channel * 2,
                                      out_channels=channel)

        # Single-channel 1x1 output heads.
        self.unetout1 = nn.Conv2d(channel, 1, 1)
        self.unetout2 = nn.Conv2d(channel, 1, 1)
        self.detailout = nn.Conv2d(channel * 2, 1, 1)

        self.cobv1 = BasicConv2d(3 * channel, channel, 1)
        self.cobv2 = BasicConv2d(3 * channel, channel, 1)
        self.selayer = SELayer(channel)
        self.upconv1 = BasicConv2d(channel, channel, 1)

        self.noncal = BCA(channel, channel, 16)
        self.duatt = _DAHead(channel)
        self.ca = ChannelAttention(channel * 2)
        self.sa = SpatialAttention()

        # 1x1 conv stack reducing channel * 3 -> channel * 2 -> channel.
        self.conv = nn.Sequential(BasicConv2d(channel * 3, channel * 2, 1),
                                  BasicConv2d(channel * 2, channel * 2, 1),
                                  BasicConv2d(channel * 2, channel, 1))
        self.edgeconv = BasicConv2d(channel, channel, 1)
        self.downconv = BasicConv2d(channel * 2, channel, 1)
        self.catt1 = CasAtt(channel)
        self.catt2 = CasAtt(channel)