Ejemplo n.º 1
0
    def __init__(self, channel=32):
        """Build the PVT-v2 based network: backbone plus all decoder modules.

        Args:
            channel: Number of feature channels that each backbone stage is
                projected to and that the decoder / attention modules use.
        """
        super(MyNet, self).__init__()

        # ---- Backbone with partially loaded pretrained weights ----
        self.backbone = pvt_v2_b2()  # [64, 128, 320, 512]
        # Raw string: in a plain literal '\p' is an invalid escape sequence
        # (SyntaxWarning since Python 3.12). The resulting path is unchanged.
        # NOTE(review): the checkpoint file is named pvt_v2_b3 while the
        # backbone is pvt_v2_b2 -- confirm the b3 weights are intended.
        path = r'F:\pretrain\pvt_v2_b3.pth'
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on
        # a CPU-only machine; load_state_dict copies the values into the
        # backbone's own parameters regardless of where they were loaded.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        # Copy only checkpoint entries whose names exist in this backbone;
        # every other backbone entry keeps its current value.
        state_dict = {k: v for k, v in save_model.items() if k in model_dict}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # ---- 1x1 projections of the four backbone stages to `channel` ----
        self.Translayer1 = BasicConv2d(64, channel, 1)
        self.Translayer2 = BasicConv2d(128, channel, 1)
        self.Translayer3 = BasicConv2d(320, channel, 1)
        self.Translayer4 = BasicConv2d(512, channel, 1)

        # ---- Fixed bilinear resamplers (no learnable parameters) ----
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        self.down01 = nn.Upsample(scale_factor=1 / 2,
                                  mode='bilinear',
                                  align_corners=True)
        self.down02 = nn.Upsample(scale_factor=1 / 4,
                                  mode='bilinear',
                                  align_corners=True)
        self.down03 = nn.Upsample(scale_factor=1 / 8,
                                  mode='bilinear',
                                  align_corners=True)
        self.down04 = nn.Upsample(scale_factor=1 / 16,
                                  mode='bilinear',
                                  align_corners=True)

        self.upsample1 = nn.Upsample(scale_factor=32,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=16,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample3 = nn.Upsample(scale_factor=8,
                                     mode='bilinear',
                                     align_corners=True)

        # ---- 1x1 fusion convolutions ----
        self.out1_1 = BasicConv2d(channel * 3, channel, 1)
        self.out1_2 = BasicConv2d(channel * 3, channel, 1)
        self.out1_3 = BasicConv2d(channel * 3, channel, 1)
        self.out1_4 = BasicConv2d(channel, channel, 1)

        # Takes 3 input channels down to 1, followed by a 1-in/1-out RefineUNet.
        self.refineconv = BasicConv2d(3, 1, 1)
        self.refine = RefineUNet(1, 1)

        self.ca = ChannelAttention(channel)
        self.sa = SpatialAttention()
        self.outatte = nn.Conv2d(channel, channel, 1)

        # ---- First decoder path ----
        self.decoder4 = DecoderBlock(in_channels=channel, out_channels=channel)
        self.decoder3 = DecoderBlock(in_channels=channel * 3,
                                     out_channels=channel)
        self.decoder2 = DecoderBlock(in_channels=channel * 3,
                                     out_channels=channel)
        self.decoder1 = nn.Sequential(BasicConv2d(channel * 3, channel, 1),
                                      BasicConv2d(channel, channel, 1))

        # ---- Second decoder path ----
        self.decoder2_4 = DecoderBlock(in_channels=channel,
                                       out_channels=channel)
        self.decoder2_3 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=channel)
        self.decoder2_2 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=channel)
        self.decoder2_1 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=channel)

        # ---- Adaptive selection modules ----
        self.asm4 = ASM(channel, channel * 3)
        self.asm3 = ASM(channel, channel * 3)
        self.asm2 = ASM(channel, channel * 3)
        self.asm1 = ASM(channel, channel * 3)

        # 1x1 convolutions reducing `channel` features to a single channel.
        self.unetout1 = nn.Conv2d(channel, 1, 1)
        self.unetout2 = nn.Conv2d(channel, 1, 1)

        self.COM = COM(channel)
        self.cobv1 = BasicConv2d(3 * channel, channel, 1)
        self.cobv2 = BasicConv2d(3 * channel, channel, 1)
        self.nocal = NonLocalBlock(channel)
        self.selayer = SELayer(channel)
Ejemplo n.º 2
0
    def __init__(self, channel=64):
        """Build the PVT-v2 based network with RF / BCA fusion modules.

        Args:
            channel: Number of feature channels the backbone stages are
                projected to and that the RF / BCA / output layers use.
        """
        super(MyNet, self).__init__()

        # ---- Backbone with partially loaded pretrained weights ----
        self.backbone = pvt_v2_b2()  # [64, 128, 320, 512]
        # Raw string: in a plain literal '\p' is an invalid escape sequence
        # (SyntaxWarning since Python 3.12). The resulting path is unchanged.
        # NOTE(review): the checkpoint file is named pvt_v2_b3 while the
        # backbone is pvt_v2_b2 -- confirm the b3 weights are intended.
        path = r'F:\pretrain\pvt_v2_b3.pth'
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
        # machine; load_state_dict copies the values into the backbone anyway.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        # Copy only checkpoint entries whose names exist in this backbone.
        state_dict = {k: v for k, v in save_model.items() if k in model_dict}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # ---- 1x1 projections of the four backbone stages to `channel` ----
        self.Translayer2_0 = BasicConv2d(64, channel, 1)
        self.Translayer2_1 = BasicConv2d(128, channel, 1)
        self.Translayer3_1 = BasicConv2d(320, channel, 1)
        self.Translayer4_1 = BasicConv2d(512, channel, 1)

        # self.CFM = CFM(channel)
        #
        # self.SAM = SAM()

        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        # NOTE: down05 and down03 (below) both halve the resolution; both are
        # kept because callers may reference either name.
        self.down05 = nn.Upsample(scale_factor=0.5,
                                  mode='bilinear',
                                  align_corners=True)
        # Single-channel 1x1 output convolutions. Defined once: a second,
        # identical construction after the BCA modules used to silently
        # replace these instances.
        self.out_SAM = nn.Conv2d(channel, 1, 1)
        self.out_CFM = nn.Conv2d(channel, 1, 1)

        # Input width 512 + 320 + 128 equals the sum of the last three stage
        # widths listed for the backbone (presumably their concatenation).
        self.rf1 = RF(512 + 320 + 128, channel)
        self.rf2 = RF(512 + 320 + 128, channel)
        self.rf3 = RF(512 + 320 + 128, channel)

        self.att1 = BCA(channel, channel, channel)
        self.att2 = BCA(channel, channel, channel)
        self.att3 = BCA(channel, channel, channel)

        # ---- Fixed bilinear resamplers (no learnable parameters) ----
        self.down01 = nn.Upsample(scale_factor=1 / 8,
                                  mode='bilinear',
                                  align_corners=True)
        self.down02 = nn.Upsample(scale_factor=1 / 4,
                                  mode='bilinear',
                                  align_corners=True)
        self.down03 = nn.Upsample(scale_factor=1 / 2,
                                  mode='bilinear',
                                  align_corners=True)

        self.upsample1 = nn.Upsample(scale_factor=32,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=16,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample3 = nn.Upsample(scale_factor=8,
                                     mode='bilinear',
                                     align_corners=True)

        # 1x1 convolutions from `channel` features down to a single channel.
        self.out1 = BasicConv2d(channel, 1, 1)
        self.out2 = BasicConv2d(channel, 1, 1)
        self.out3 = BasicConv2d(channel, 1, 1)
        self.out4 = BasicConv2d(channel, 1, 1)
        self.refineconv = BasicConv2d(3, 1, 1)
        self.refine = RefineUNet(1, 1)

        self.ca = ChannelAttention(channel)
        self.sa = SpatialAttention()
        self.outatte = nn.Conv2d(channel, channel, 1)

        # 3x3 convolutions with padding=1.
        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2 * channel,
                                          2 * channel,
                                          3,
                                          padding=1)
Ejemplo n.º 3
0
    def __init__(self, channel=64):
        """Build the PVT-v2 variant with ASM, ACFM and DGCM modules.

        Args:
            channel: Number of feature channels the backbone stages are
                projected to by the ``Translayer*`` / ``out1_*`` layers.
        """
        super(MyNet, self).__init__()

        # ---- Backbone with partially loaded pretrained weights ----
        self.backbone = pvt_v2_b2()  # [64, 128, 320, 512]
        # Raw string: in a plain literal '\p' is an invalid escape sequence
        # (SyntaxWarning since Python 3.12). The resulting path is unchanged.
        # NOTE(review): the checkpoint file is named pvt_v2_b3 while the
        # backbone is pvt_v2_b2 -- confirm the b3 weights are intended.
        path = r'F:\pretrain\pvt_v2_b3.pth'
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
        # machine; load_state_dict copies the values into the backbone anyway.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        # Copy only checkpoint entries whose names exist in this backbone.
        state_dict = {k: v for k, v in save_model.items() if k in model_dict}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # ---- 1x1 projections of the four backbone stages to `channel` ----
        self.Translayer1 = BasicConv2d(64, channel, 1)
        self.Translayer2 = BasicConv2d(128, channel, 1)
        self.Translayer3 = BasicConv2d(320, channel, 1)
        self.Translayer4 = BasicConv2d(512, channel, 1)

        # ---- Fixed bilinear resamplers (no learnable parameters) ----
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        # NOTE: down05 and down01 both halve the resolution; both are kept
        # because callers may reference either name.
        self.down05 = nn.Upsample(scale_factor=0.5,
                                  mode='bilinear',
                                  align_corners=True)

        self.down01 = nn.Upsample(scale_factor=1 / 2,
                                  mode='bilinear',
                                  align_corners=True)
        self.down02 = nn.Upsample(scale_factor=1 / 4,
                                  mode='bilinear',
                                  align_corners=True)
        self.down03 = nn.Upsample(scale_factor=1 / 8,
                                  mode='bilinear',
                                  align_corners=True)
        self.down04 = nn.Upsample(scale_factor=1 / 16,
                                  mode='bilinear',
                                  align_corners=True)

        self.upsample1 = nn.Upsample(scale_factor=32,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=16,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample3 = nn.Upsample(scale_factor=8,
                                     mode='bilinear',
                                     align_corners=True)

        # Per-stage 1x1 projections: to `channel` (out1_*) and to 1 (out2_*).
        self.out1_1 = BasicConv2d(64, channel, 1)
        self.out1_2 = BasicConv2d(128, channel, 1)
        self.out1_3 = BasicConv2d(320, channel, 1)
        self.out1_4 = BasicConv2d(512, channel, 1)

        self.out2_1 = BasicConv2d(64, 1, 1)
        self.out2_2 = BasicConv2d(128, 1, 1)
        self.out2_3 = BasicConv2d(320, 1, 1)
        self.out2_4 = BasicConv2d(512, 1, 1)

        self.refineconv = BasicConv2d(3, 1, 1)
        self.refine = RefineUNet(1, 1)

        self.ca = ChannelAttention(channel)
        self.sa = SpatialAttention()
        self.outatte = nn.Conv2d(channel, channel, 1)

        # Decoder
        # self.decoder5 = DecoderBlock(in_channels=512, out_channels=512)
        #  self.decoder4 = DecoderBlock(in_channels=channel, out_channels=channel)
        #  self.decoder3 = DecoderBlock(in_channels=channel*2, out_channels=channel)
        #  self.decoder2 = DecoderBlock(in_channels=channel*2, out_channels=channel)
        #  self.decoder1 = DecoderBlock(in_channels=channel*2, out_channels=channel)
        #

        # ---- Second decoder path (output widths follow the stage widths) ----
        self.decoder2_4 = DecoderBlock(in_channels=channel, out_channels=320)
        self.decoder2_3 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=128)
        self.decoder2_2 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=64)
        self.decoder2_1 = DecoderBlock(in_channels=channel * 2,
                                       out_channels=64)

        # ---- Adaptive selection modules ----
        self.asm4 = ASM(channel, 512)
        self.asm3 = ASM(channel, channel * 3)
        self.asm2 = ASM(channel, channel * 3)
        self.asm1 = ASM(channel, channel * 3)

        # NOTE(review): `unetout1` is assigned twice. The first Sequential
        # (64 -> 32 -> 1) is immediately replaced by the second one and is
        # discarded. Sibling variants define `unetout1` and `unetout2`, so
        # the second assignment may have been meant as `self.unetout2` --
        # confirm against forward() before renaming or deleting either.
        self.unetout1 = nn.Sequential(
            BasicConv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(32, 1, 1))
        self.unetout1 = nn.Sequential(
            BasicConv2d(channel, channel, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(channel, 1, 1))

        # Sideout
        self.sideout4 = SideoutBlock(32, 1)
        self.sideout3 = SideoutBlock(32, 1)
        self.sideout2 = SideoutBlock(32, 1)
        self.sideout1 = SideoutBlock(32, 1)

        self.decoder1 = Decoder()
        self.decoder2 = Decoder()

        self.acfm3 = ACFM()
        self.acfm2 = ACFM()

        self.dgcm3 = DGCM()
        self.dgcm2 = DGCM()
        self.upconv3 = BasicConv2d(64,
                                   64,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1,
                                   relu=True)

        self.classifier = nn.Conv2d(64, 1, 1)
        self.upconv2 = BasicConv2d(64,
                                   64,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1,
                                   relu=True)
Ejemplo n.º 4
0
    def __init__(self, channel=64):
        """Register the Res2Net-50 driven variant of the network.

        A randomly initialised PVT-v2 b2 backbone is also constructed here,
        but no weights are loaded into it in this method.

        Args:
            channel: Width shared by the projection, fusion, attention and
                decoder layers (default 64).
        """
        super(MyNet, self).__init__()

        def bilinear(scale):
            # Every resampler below is bilinear with aligned corners; only the
            # scale factor differs between them.
            return nn.Upsample(scale_factor=scale, mode='bilinear',
                               align_corners=True)

        # Backbones.
        self.backbone = pvt_v2_b2()  # [64, 128, 320, 512]
        self.resnet = res2net50_v1b_26w_4s(pretrained=True)

        # 1x1 projections from the listed input widths down to `channel`.
        (self.Translayer1, self.Translayer2,
         self.Translayer3, self.Translayer4) = [
            BasicConv2d(width, channel, 1) for width in (256, 512, 1024, 2048)
        ]

        # Fixed resamplers (no learnable parameters).
        self.upsample = bilinear(2)
        self.down05 = bilinear(0.5)
        self.down01, self.down02, self.down03, self.down04 = [
            bilinear(1 / factor) for factor in (2, 4, 8, 16)
        ]
        self.upsample1, self.upsample2, self.upsample3 = [
            bilinear(factor) for factor in (32, 16, 8)
        ]

        self.out1_1, self.out1_2, self.out1_3, self.out1_4 = [
            BasicConv2d(channel * 3, channel, 1) for _ in range(4)
        ]
        # NOTE(review): out2_4 takes 512 input channels, whereas Translayer4
        # takes 2048 -- confirm which tensor feeds it.
        self.out2_1, self.out2_2, self.out2_3, self.out2_4 = [
            BasicConv2d(width, 1, 1) for width in (256, 512, 1024, 512)
        ]

        self.refineconv = BasicConv2d(3, 1, 1)
        self.refine = RefineUNet(1, 1)

        self.ca = ChannelAttention(channel)
        self.sa = SpatialAttention()
        self.outatte = nn.Conv2d(channel, channel, 1)

        # First decoder path.
        self.decoder4, self.decoder3, self.decoder2 = [
            DecoderBlock(in_channels=channel * 3, out_channels=channel)
            for _ in range(3)
        ]
        self.decoder1 = nn.Sequential(
            BasicConv2d(channel * 2, channel, 1),
            BasicConv2d(channel, channel, 1),
        )

        # Second decoder path.
        self.decoder2_4 = DecoderBlock(in_channels=channel,
                                       out_channels=channel)
        self.decoder2_3, self.decoder2_2, self.decoder2_1 = [
            DecoderBlock(in_channels=channel * 2, out_channels=channel)
            for _ in range(3)
        ]

        # Adaptive selection modules.
        self.asm4, self.asm3, self.asm2, self.asm1 = [
            ASM(channel, channel * 3) for _ in range(4)
        ]

        # 1x1 convolutions from `channel` features to a single channel.
        self.unetout1, self.unetout2 = [
            nn.Conv2d(channel, 1, 1) for _ in range(2)
        ]

        self.COM = COM(channel)
        self.cobv1, self.cobv2 = [
            BasicConv2d(3 * channel, channel, 1) for _ in range(2)
        ]
        self.nocal = NonLocalBlock(channel)
Ejemplo n.º 5
0
    def __init__(self, channel=64):
        """Register the Res2Net-50 backbone and every layer of this variant.

        Args:
            channel: Width shared by the projection, attention, decoder and
                RFB layers (default 64).
        """
        super(MyNet, self).__init__()

        def bilinear(scale):
            # Every resampler below is bilinear with aligned corners; only the
            # scale factor differs between them.
            return nn.Upsample(scale_factor=scale, mode='bilinear',
                               align_corners=True)

        self.resnet = res2net50_v1b_26w_4s(pretrained=True)

        # NOTE(review): Translayer1-4 take 64/128/320/512 input channels while
        # the RFB blocks at the end take 256/512/1024/2048 -- confirm which
        # tensors Translayer1-4 receive.
        (self.Translayer1, self.Translayer2,
         self.Translayer3, self.Translayer4) = [
            BasicConv2d(width, channel, 1) for width in (64, 128, 320, 512)
        ]
        self.Translayerup1, self.Translayerup2 = [
            BasicConv2d(width, channel, 1) for width in (256, 512)
        ]

        # Fixed resamplers (no learnable parameters).
        self.upsample = bilinear(2)
        self.downsample = bilinear(1 / 4)
        self.down01, self.down02, self.down03, self.down04 = [
            bilinear(1 / factor) for factor in (2, 4, 8, 16)
        ]
        self.upsample1, self.upsample2, self.upsample3 = [
            bilinear(factor) for factor in (32, 16, 8)
        ]

        self.out1_1, self.out1_2, self.out1_3 = [
            BasicConv2d(channel * 2, channel, 1) for _ in range(3)
        ]
        self.out1_4 = BasicConv2d(channel, channel, 1)

        self.refineconv = BasicConv2d(3, 1, 1)
        self.refine = RefineUNet(1, 1)

        self.ca1 = ChannelAttention(channel)
        self.sa4 = SpatialAttention()
        self.outatte = nn.Conv2d(channel, channel, 1)

        # First decoder path.
        self.decoder4 = DecoderBlock(in_channels=channel, out_channels=channel)
        self.decoder3, self.decoder2 = [
            DecoderBlock(in_channels=channel * 2, out_channels=channel)
            for _ in range(2)
        ]
        self.decoder1 = nn.Sequential(
            BasicConv2d(channel * 2, channel, 1),
            BasicConv2d(channel, channel, 1),
        )

        # Second decoder path.
        self.decoder2_4 = DecoderBlock(in_channels=channel,
                                       out_channels=channel)
        self.decoder2_3, self.decoder2_2, self.decoder2_1 = [
            DecoderBlock(in_channels=channel * 2, out_channels=channel)
            for _ in range(3)
        ]

        # 1x1 convolutions down to a single channel.
        self.unetout1, self.unetout2 = [
            nn.Conv2d(channel, 1, 1) for _ in range(2)
        ]
        self.detailout = nn.Conv2d(channel * 2, 1, 1)
        self.com = COM(channel)

        self.cobv1, self.cobv2 = [
            BasicConv2d(3 * channel, channel, 1) for _ in range(2)
        ]
        self.nocal = NonLocalBlock(channel)
        self.selayer = SELayer(channel)
        self.upconv1 = BasicConv2d(channel, channel, 1)

        self.noncal = BCA(channel, channel, 16)
        self.duatt = _DAHead(channel)
        self.ca = ChannelAttention(channel)
        self.sa = SpatialAttention()

        self.conv = nn.Sequential(
            BasicConv2d(channel * 2, channel * 2, 1),
            BasicConv2d(channel * 2, channel, 1),
        )
        self.edgeconv, self.downconv = [
            BasicConv2d(channel, channel, 1) for _ in range(2)
        ]
        # RFB blocks, constructed in the order 2, 1, 3, 4.
        self.rfb2_1, self.rfb1_1, self.rfb3_1, self.rfb4_1 = [
            RFB_modified(width, channel) for width in (512, 256, 1024, 2048)
        ]
Ejemplo n.º 6
0
    def __init__(self, channel=32):
        """Register the PVT-v2 b2 backbone and this variant's layers.

        Builds the backbone (partially initialised from a checkpoint on disk),
        eight 1x1 stage projections, fixed bilinear resamplers, two groups of
        decoder blocks, three encoder blocks and assorted attention modules.

        Args:
            channel: Number of feature channels the backbone stages are
                projected to and that most other layers use (default 32).
        """
        super(MyNet, self).__init__()

        self.backbone = pvt_v2_b2()  # [64, 128, 320, 512]
        # NOTE(review): this literal relies on '\p' being an unrecognised
        # escape (kept as a backslash, SyntaxWarning on Python 3.12+); a raw
        # string would be safer. The file is named pvt_v2_b3 while the
        # backbone is pvt_v2_b2 -- confirm the b3 weights are intended.
        path = 'F:\pretrain\pvt_v2_b3.pth'
        # NOTE(review): no map_location, so a checkpoint saved from GPU
        # tensors will need a GPU to load.
        save_model = torch.load(path)
        model_dict = self.backbone.state_dict()
        # Keep only checkpoint entries whose names exist in this backbone;
        # all other backbone entries keep their current values.
        state_dict = {
            k: v
            for k, v in save_model.items() if k in model_dict.keys()
        }
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # 1x1 projections of the four stage widths to `channel`, in two
        # independent sets (1-4 and 5-8).
        self.Translayer1 = BasicConv2d(64, channel, 1)
        self.Translayer2 = BasicConv2d(128, channel, 1)
        self.Translayer3 = BasicConv2d(320, channel, 1)
        self.Translayer4 = BasicConv2d(512, channel, 1)
        self.Translayer5 = BasicConv2d(64, channel, 1)
        self.Translayer6 = BasicConv2d(128, channel, 1)
        self.Translayer7 = BasicConv2d(320, channel, 1)
        self.Translayer8 = BasicConv2d(512, channel, 1)

        # Fixed bilinear resamplers (no learnable parameters).
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)
        self.downsample = nn.Upsample(scale_factor=1 / 4,
                                      mode='bilinear',
                                      align_corners=True)

        self.down01 = nn.Upsample(scale_factor=1 / 2,
                                  mode='bilinear',
                                  align_corners=True)
        self.down02 = nn.Upsample(scale_factor=1 / 4,
                                  mode='bilinear',
                                  align_corners=True)
        self.down03 = nn.Upsample(scale_factor=1 / 8,
                                  mode='bilinear',
                                  align_corners=True)
        self.down04 = nn.Upsample(scale_factor=1 / 16,
                                  mode='bilinear',
                                  align_corners=True)

        self.upsample1 = nn.Upsample(scale_factor=32,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=16,
                                     mode='bilinear',
                                     align_corners=True)
        self.upsample3 = nn.Upsample(scale_factor=8,
                                     mode='bilinear',
                                     align_corners=True)

        # 1x1 fusion convolutions (first three take 2 * channel inputs).
        self.out1_1 = BasicConv2d(channel * 2, channel, 1)
        self.out1_2 = BasicConv2d(channel * 2, channel, 1)
        self.out1_3 = BasicConv2d(channel * 2, channel, 1)
        self.out1_4 = BasicConv2d(channel, channel, 1)

        self.refineconv = BasicConv2d(3, 1, 1)
        self.refine = RefineUNet(1, 1)

        self.ca1 = ChannelAttention(channel)
        self.sa4 = SpatialAttention()
        self.outatte = nn.Conv2d(channel, channel, 1)

        # First decoder group (decoder4 -> decoder1).
        # self.decoder5 = DecoderBlock(in_channels=512, out_channels=512)
        self.decoder4 = DecoderBlock(in_channels=channel, out_channels=channel)
        self.decoder3 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder2 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder1 = nn.Sequential(BasicConv2d(channel * 2, channel, 1),
                                      BasicConv2d(channel, channel, 1))

        # Second decoder group (decoder5 -> decoder8), all taking 2 * channel.
        self.decoder5 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder6 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder7 = DecoderBlock(in_channels=channel * 2,
                                     out_channels=channel)
        self.decoder8 = nn.Sequential(BasicConv2d(channel * 2, channel, 1),
                                      BasicConv2d(channel, channel, 1))

        # Encoder blocks (class name `EncooderBlock` is spelled as defined
        # elsewhere).
        self.encoder2 = EncooderBlock(in_channels=channel * 2,
                                      out_channels=channel)
        self.encoder3 = EncooderBlock(in_channels=channel * 2,
                                      out_channels=channel)
        self.encoder4 = EncooderBlock(in_channels=channel * 2,
                                      out_channels=channel)

        # 1x1 convolutions down to a single channel.
        self.unetout1 = nn.Conv2d(channel, 1, 1)
        self.unetout2 = nn.Conv2d(channel, 1, 1)
        self.detailout = nn.Conv2d(channel * 2, 1, 1)

        self.cobv1 = BasicConv2d(3 * channel, channel, 1)
        self.cobv2 = BasicConv2d(3 * channel, channel, 1)
        self.selayer = SELayer(channel)
        self.upconv1 = BasicConv2d(channel, channel, 1)

        self.noncal = BCA(channel, channel, 16)
        self.duatt = _DAHead(channel)
        # Channel attention here is built for 2 * channel features.
        self.ca = ChannelAttention(channel * 2)
        self.sa = SpatialAttention()

        # 1x1 stack: 3 * channel -> 2 * channel -> 2 * channel -> channel.
        self.conv = nn.Sequential(BasicConv2d(channel * 3, channel * 2, 1),
                                  BasicConv2d(channel * 2, channel * 2, 1),
                                  BasicConv2d(channel * 2, channel, 1))
        self.edgeconv = BasicConv2d(channel, channel, 1)
        self.downconv = BasicConv2d(channel * 2, channel, 1)
        self.catt1 = CasAtt(channel)
        self.catt2 = CasAtt(channel)