Example #1
from mxnet.gluon import HybridBlock
from mxnet.gluon.nn import HybridSequential

class net3(HybridBlock):
    def __init__(self):
        super(net3, self).__init__()
        with self.name_scope():
            # encoder, decoder, and CA_M2 are project-specific blocks
            # defined elsewhere.
            encoder1 = encoder(3, 16)
            encoder2 = encoder(16, 32)
            encoder3 = encoder(32, 64)
            decoder1 = decoder(64, 32)
            decoder2 = decoder(32, 16)
            decoder3 = decoder(16, 1)
            att2 = CA_M2(32)
            att3 = CA_M2(64)
            att4 = CA_M2(32)
        # Chain the stages in execution order, with attention inserted
        # after the second and third encoders and the first decoder.
        blocks = [encoder1, encoder2, att2, encoder3, att3,
                  decoder1, att4, decoder2, decoder3]
        self.net1 = HybridSequential()
        with self.net1.name_scope():
            for block in blocks:
                self.net1.add(block)
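Examples #1 and #2 call encoder, decoder, and CA_M2 blocks whose definitions are not shown. The following is a minimal sketch of stand-ins that match the call signatures used above; the layer choices are assumptions, not the project's real blocks:

# Hypothetical minimal stand-ins for the project blocks used above; they
# only match the (in_channels, out_channels) call pattern so that net3
# can be instantiated and hybridized.
from mxnet.gluon import HybridBlock
from mxnet.gluon.nn import Conv2D, Conv2DTranspose, HybridSequential

def encoder(in_ch, out_ch):
    # Strided conv halves the spatial size while widening channels.
    block = HybridSequential()
    block.add(Conv2D(out_ch, kernel_size=4, strides=2, padding=1,
                     in_channels=in_ch, activation='relu'))
    return block

def decoder(in_ch, out_ch):
    # Transposed conv doubles the spatial size while narrowing channels.
    block = HybridSequential()
    block.add(Conv2DTranspose(out_ch, kernel_size=4, strides=2, padding=1,
                              in_channels=in_ch, activation='relu'))
    return block

class CA_M2(HybridBlock):
    # Identity placeholder for the attention block.
    def __init__(self, in_channel):
        super(CA_M2, self).__init__()

    def hybrid_forward(self, F, x):
        return x

With these stand-ins, net3() can be initialized and run on a (1, 3, 256, 256) input: three stride-2 encoders downsample to 32x32 and three decoders return to the input resolution.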
Example #2
class net2(HybridBlock):
    def __init__(self):
        super(net2, self).__init__()
        # First two encoder stages.
        self.net = HybridSequential()
        with self.net.name_scope():
            self.net.add(encoder(3, 16))
            self.net.add(encoder(16, 32))
        # Attention is registered separately so the forward pass can apply
        # it between self.net and self.net1.
        self.att = CA_M2(32)
        # Remaining encoder stage and the decoder path.
        self.net1 = HybridSequential()
        with self.net1.name_scope():
            self.net1.add(encoder(32, 64))
            self.net1.add(decoder(64, 32))
            self.net1.add(decoder(32, 16))
            self.net1.add(decoder(16, 1))
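Since self.att is added to neither HybridSequential, the forward pass presumably wires it in between the two stages. A plausible hybrid_forward for net2, assuming CA_M2 returns a feature map of the same shape it receives:

    # Hypothetical forward pass: attention sits between the encoder
    # stage (self.net, 3 -> 32 channels) and the rest of the network.
    def hybrid_forward(self, F, x):
        x = self.net(x)
        x = self.att(x)
        return self.net1(x)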
Example #3
from mxnet.gluon import HybridBlock
from mxnet.gluon.nn import (Activation, BatchNorm, Conv2D, Conv2DTranspose,
                            Dropout, HybridSequential, LeakyReLU)

class UnetSkipUnit(HybridBlock):
    def __init__(self,
                 inner_channels,
                 outer_channels,
                 inner_block=None,
                 innermost=False,
                 outermost=False,
                 use_dropout=False,
                 use_bias=False,
                 use_attention=True,
                 use_resblock=True,
                 use_p_at=False,
                 use_c_at=False,
                 save_att=False):
        super(UnetSkipUnit, self).__init__()

        with self.name_scope():
            # Flags stored for use by the forward pass (not shown here).
            self.save_att = save_att
            self.outermost = outermost
            self.innermost = innermost
            self.use_attention = use_attention
            # Default the residual slots to None so the assignments to
            # self.res1..self.res4 below are safe when use_resblock is False.
            res1 = res2 = res3 = res4 = None
            if not self.outermost:
                res_block_1 = Res_Block(outer_channels=outer_channels)
                res_block_2 = Res_Block(outer_channels=inner_channels)
            en_conv = Conv2D(channels=inner_channels,
                             kernel_size=4,
                             strides=2,
                             padding=1,
                             in_channels=outer_channels,
                             use_bias=use_bias)
            en_relu = LeakyReLU(alpha=0.2)
            en_norm = BatchNorm(momentum=0.1, in_channels=inner_channels)

            de_relu = Activation(activation='relu')
            de_norm = BatchNorm(momentum=0.1, in_channels=outer_channels)

            # Innermost unit: no nested inner_block between encoder and decoder.
            if innermost:
                de_conv = Conv2DTranspose(channels=outer_channels,
                                          kernel_size=4,
                                          strides=2,
                                          padding=1,
                                          in_channels=inner_channels,
                                          use_bias=use_bias)
                if use_p_at:
                    self.p_at = CA_M2(in_channel=inner_channels)
                else:
                    self.p_at = CA_M3()
                if use_c_at:
                    self.c_at = CA_M1()
                else:
                    self.c_at = CA_M3()
                res_block_3 = Res_Block(outer_channels=inner_channels)
                res_block_4 = Res_Block(outer_channels=outer_channels)
                if use_resblock:
                    res1 = res_block_1
                    encoder = [en_conv, en_norm, en_relu]
                    res2 = res_block_2
                    res3 = res_block_3
                    decoder = [de_conv, de_norm, de_relu]
                    res4 = res_block_4
                else:
                    encoder = [en_relu, en_conv]
                    decoder = [de_relu, de_conv, de_norm]

            # Outermost unit: a 1x1 conv (channel_trans) maps the decoder
            # output down to a single channel.
            elif outermost:
                de_conv = Conv2DTranspose(channels=outer_channels,
                                          kernel_size=4,
                                          strides=2,
                                          padding=1,
                                          in_channels=inner_channels)
                channel_trans = Conv2D(channels=1,
                                       in_channels=outer_channels,
                                       kernel_size=1,
                                       prefix='')

                if use_resblock:
                    res1 = None
                    encoder = [en_conv, en_norm, en_relu]
                    res2 = None
                    res3 = None
                    decoder = [de_conv, de_norm, de_relu, channel_trans]
                    res4 = None
                else:
                    encoder = [en_conv]
                    decoder = [de_relu, de_conv, de_norm, channel_trans]

                if use_p_at:
                    self.p_at = CA_M2(in_channel=inner_channels)
                else:
                    self.p_at = CA_M3()
                if use_c_at:
                    self.c_at = CA_M1()
                else:
                    self.c_at = CA_M3()

            # Intermediate units: standard encoder/decoder pair with
            # optional residual blocks around each.
            else:
                de_conv = Conv2DTranspose(channels=outer_channels,
                                          kernel_size=4,
                                          strides=2,
                                          padding=1,
                                          in_channels=inner_channels,
                                          use_bias=use_bias)

                if use_p_at:
                    self.p_at = CA_M2(in_channel=inner_channels)
                else:
                    self.p_at = CA_M3()
                if use_c_at:
                    self.c_at = CA_M1()
                else:
                    self.c_at = CA_M3()

                res_block_3 = Res_Block(outer_channels=inner_channels)
                res_block_4 = Res_Block(outer_channels=outer_channels)

                if use_resblock:
                    res1 = res_block_1
                    encoder = [en_conv, en_norm, en_relu]
                    res2 = res_block_2
                    res3 = res_block_3
                    decoder = [de_conv, de_norm, de_relu]
                    res4 = res_block_4
                else:
                    encoder = [en_relu, en_conv, en_norm]
                    decoder = [de_relu, de_conv, de_norm]

            if use_dropout:
                decoder += [Dropout(rate=0.5)]

            self.encoder = HybridSequential()
            with self.encoder.name_scope():
                for block in encoder:
                    self.encoder.add(block)

            self.inner_block = inner_block

            self.res1 = res1
            self.res2 = res2
            self.res3 = res3
            self.res4 = res4

            self.decoder = HybridSequential()

            with self.decoder.name_scope():
                for block in decoder:
                    self.decoder.add(block)
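The inner_block argument is how these units nest: the innermost unit is built first and each shallower unit wraps it, with the forward pass (not shown here) presumably running encoder, inner block, then decoder. A three-level construction might look like the following; the channel widths and flags are illustrative choices, not values from the original code:

# Hypothetical nesting of UnetSkipUnit into a small U-Net.
inner = UnetSkipUnit(inner_channels=256, outer_channels=128,
                     innermost=True, use_p_at=True, use_c_at=True)
middle = UnetSkipUnit(inner_channels=128, outer_channels=64,
                      inner_block=inner)
unet = UnetSkipUnit(inner_channels=64, outer_channels=3,
                    inner_block=middle, outermost=True)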
Example #4
class UnetSkipUnit(HybridBlock):
    # NOTE: use_dropout is accepted but never applied in this variant.
    def __init__(self, inner_channels, outer_channels, inner_block=None,
                 innermost=False, outermost=False, use_dropout=False,
                 use_bias=False, use_position_attention=False,
                 use_channel_attention=False):
        super(UnetSkipUnit, self).__init__()

        with self.name_scope():
            self.outermost = outermost

            # Residual blocks and the three attention variants; CA_M3 is the
            # fallback used when attention is disabled.
            res1 = Res_Block(outer_channels=outer_channels)
            res2 = Res_Block(outer_channels=inner_channels)
            res3 = Res_Block(outer_channels=inner_channels)
            res4 = Res_Block(outer_channels=outer_channels)
            attention_non = CA_M3()
            attention_position = CA_M2(in_channel=inner_channels)
            attention_channel = CA_M1()
            
            en_conv = Conv2D(channels=inner_channels, kernel_size=4, strides=2, padding=1,
                             in_channels=outer_channels, use_bias=use_bias)
            en_relu = LeakyReLU(alpha=0.2)
            en_norm = BatchNorm(momentum=0.1, in_channels=inner_channels)
            de_relu = Activation(activation='relu')
            de_norm = BatchNorm(momentum=0.1, in_channels=outer_channels)

            # The attention choice is identical in every branch, so it is
            # resolved once up front.
            if use_position_attention:
                p_attention = attention_position
            else:
                p_attention = attention_non
            if use_channel_attention:
                c_attention = attention_channel
            else:
                c_attention = attention_non

            if innermost:
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4,
                                          strides=2, padding=1,
                                          in_channels=inner_channels, use_bias=use_bias)
                encoder = [res1, en_relu, en_conv, p_attention, res2]
                decoder = [res3, de_relu, de_conv, de_norm, res4]
                model = encoder + decoder
            elif outermost:
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4,
                                          strides=2, padding=1,
                                          in_channels=inner_channels)
                encoder = [en_conv, p_attention]
                decoder = [de_relu, de_conv, de_norm]
                model = encoder + [inner_block] + decoder
            else:
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4,
                                          strides=2, padding=1,
                                          in_channels=inner_channels, use_bias=use_bias)
                encoder = [res1, en_relu, en_conv, en_norm, p_attention, res2]
                decoder = [res3, de_relu, de_conv, de_norm, res4]
                model = encoder + [inner_block] + decoder

            self.c_attention = c_attention
            self.model = HybridSequential()
            with self.model.name_scope():
                for block in model:
                    self.model.add(block)
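Unlike Example #3, this variant folds the encoder, inner_block, and decoder into a single self.model, leaving only the skip connection and channel attention for the forward pass. A sketch under the usual U-Net convention of concatenating input and output at every non-outermost level; the placement of self.c_attention here is an assumption:

    # Hypothetical forward pass: run the assembled pipeline, apply channel
    # attention, and concatenate the input as the skip connection except at
    # the outermost level.
    def hybrid_forward(self, F, x):
        out = self.c_attention(self.model(x))
        if self.outermost:
            return out
        return F.concat(out, x, dim=1)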