def on_build(self, patch_size, in_ch, base_ch=16, use_fp16=False):
        """Create the layers of the U-Net style patch discriminator.

        The number of encoder levels and each level's (kernel_size, strides)
        pair come from ``self.find_archi(patch_size)``. Channel widths double
        per level and are capped at 512 (see ``level_chs`` below).

        Args:
            patch_size: target patch size; passed unchanged to ``self.find_archi``.
            in_ch: channel count of the discriminator input.
            base_ch: channel width produced by ``in_conv`` (capped at 512).
            use_fp16: when True every layer is created with ``tf.float16``
                weights instead of ``tf.float32``; also stored as ``self.use_fp16``.

        Sets ``self.in_conv``, ``self.convs`` (encoder convs in level order),
        ``self.upconvs`` (decoder transposed convs, deepest level first),
        ``self.out_conv``, ``self.center_out`` and ``self.center_conv``.
        """
        self.use_fp16 = use_fp16
        conv_dtype = tf.float16 if use_fp16 else tf.float32

        self.convs = []
        self.upconvs = []
        layers = self.find_archi(patch_size)
        deepest = len(layers) - 1

        # level_chs[k] is the channel width at level k, for k = -1 (output of
        # in_conv) through `deepest` (output of the last encoder conv):
        # base_ch doubled (k + 1) times, capped at 512.
        level_chs = {k: min(base_ch * (2 ** (k + 1)), 512) for k in range(-1, deepest + 1)}

        self.in_conv = nn.Conv2D(in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)

        for i, (kernel_size, strides) in enumerate(layers):
            self.convs.append(nn.Conv2D(level_chs[i - 1], level_chs[i],
                                        kernel_size=kernel_size, strides=strides,
                                        padding='SAME', dtype=conv_dtype))

            # The deepest upconv's input width equals center_conv's output width;
            # every other upconv takes twice its level's width.
            # NOTE(review): presumably a decoder/encoder concatenation performed
            # in forward(), which is not visible here -- confirm there.
            up_in_ch = level_chs[i] * (1 if i == deepest else 2)
            # insert(0, ...) keeps upconvs ordered deepest level first.
            self.upconvs.insert(0, nn.Conv2DTranspose(up_in_ch, level_chs[i - 1],
                                                      kernel_size=kernel_size, strides=strides,
                                                      padding='SAME', dtype=conv_dtype))

        # NOTE(review): the doubled input width presumably matches the in_conv
        # features concatenated with the last decoder output in forward() -- confirm.
        self.out_conv = nn.Conv2D(level_chs[-1] * 2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)

        self.center_out = nn.Conv2D(level_chs[deepest], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
        self.center_conv = nn.Conv2D(level_chs[deepest], level_chs[deepest], kernel_size=1, padding='VALID', dtype=conv_dtype)
Example #2
0
 def on_build(self, in_ch, out_ch):
     """Create this block's layers: a transposed conv, an FRN norm and a TLU.

     ``self.conv`` is a 3x3 ``nn.Conv2DTranspose`` from ``in_ch`` to ``out_ch``
     channels with 'SAME' padding; ``self.frn`` (``nn.FRNorm2D``) and
     ``self.tlu`` (``nn.TLU``) are both sized to ``out_ch``. The order in which
     they are applied is defined by the forward pass, which is not shown here.
     """
     ch = out_ch  # all three layers agree on this channel width
     # Stride is left at Conv2DTranspose's default.
     # NOTE(review): presumably an upsampling stride -- confirm in the nn layer definition.
     self.conv = nn.Conv2DTranspose(in_ch, ch, kernel_size=3, padding='SAME')
     self.frn = nn.FRNorm2D(ch)
     self.tlu = nn.TLU(ch)
Example #3
0
            def on_build(self, in_ch, base_ch):
                """Create every layer of the encoder/decoder segmentation network.

                All convolutions use 3x3 kernels with 'SAME' padding and all
                blur-pool layers use ``filt_size=3``. Channel widths are
                multiples of ``base_ch``:

                * encoder: ``features_*`` convs widen in_ch -> 1x -> 2x -> 4x -> 8x,
                  with ``blurpool_*`` layers paired with features_0/3/8/13/18;
                * ``conv_center`` keeps 8x;
                * decoder: each ``convN_up`` transposed conv is paired with a
                  ``convN`` whose input width equals the ``convN_up`` output width
                  plus the width of one encoder ``features_*`` output.
                  NOTE(review): presumably a skip concatenation done in forward(),
                  which is not shown here -- confirm there;
                * ``out_conv`` maps 1x -> 1 channel.

                Layers are created and assigned in the same order as before, since
                the owning model may enumerate sub-layers in assignment order.
                """
                def conv(ch_in, ch_out):
                    return nn.Conv2D(ch_in, ch_out, kernel_size=3, padding='SAME')

                def upconv(ch_in, ch_out):
                    return nn.Conv2DTranspose(ch_in, ch_out, kernel_size=3, padding='SAME')

                def blurpool():
                    return nn.BlurPool(filt_size=3)

                c1, c2, c4, c8 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8
                half = base_ch // 2

                # Encoder.
                self.features_0 = conv(in_ch, c1)
                self.blurpool_0 = blurpool()

                self.features_3 = conv(c1, c2)
                self.blurpool_3 = blurpool()

                self.features_6 = conv(c2, c4)
                self.features_8 = conv(c4, c4)
                self.blurpool_8 = blurpool()

                self.features_11 = conv(c4, c8)
                self.features_13 = conv(c8, c8)
                self.blurpool_13 = blurpool()

                self.features_16 = conv(c8, c8)
                self.features_18 = conv(c8, c8)
                self.blurpool_18 = blurpool()

                # Bottleneck.
                self.conv_center = conv(c8, c8)

                # Decoder: each convN input width = convN_up output width + an encoder width.
                self.conv1_up = upconv(c8, c4)
                self.conv1 = conv(c4 + c8, c8)

                self.conv2_up = upconv(c8, c4)
                self.conv2 = conv(c4 + c8, c8)

                self.conv3_up = upconv(c8, c2)
                self.conv3 = conv(c2 + c4, c4)

                self.conv4_up = upconv(c4, c1)
                self.conv4 = conv(c1 + c2, c2)

                self.conv5_up = upconv(c2, half)
                self.conv5 = conv(half + c1, c1)

                self.out_conv = conv(c1, 1)