def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False):
    """Create the convolution layers of a U-Net style patch discriminator.

    Args:
        patch_size: forwarded to ``self.find_archi``, whose result is iterated
            as a sequence of ``(kernel_size, strides)`` pairs, one per level.
        in_ch: number of channels of the input fed to ``self.in_conv``.
        base_ch: width of ``self.in_conv``'s output. Level ``k`` (``k`` from
            -1 to ``len(layers) - 1``) has ``min(base_ch * 2**(k + 1), 512)``
            channels.
        use_fp16: stored on ``self.use_fp16``; selects ``tf.float16`` instead of
            ``tf.float32`` as the ``dtype`` passed to every convolution.

    Sets ``self.use_fp16``, ``self.convs``, ``self.upconvs``, ``self.in_conv``,
    ``self.out_conv``, ``self.center_out`` and ``self.center_conv``.
    """
    self.use_fp16 = use_fp16
    conv_dtype = tf.float16 if use_fp16 else tf.float32

    self.convs = []
    self.upconvs = []
    layers = self.find_archi(patch_size)
    n_levels = len(layers)
    # Index of the deepest level; equals -1 when find_archi returns no levels,
    # in which case it aliases the in_conv level below.
    deepest = n_levels - 1

    # Channel width per level. Key -1 is the width of in_conv's output; key i
    # (0..deepest) is the width of self.convs[i]'s output. Capped at 512.
    level_chs = {lvl: min(base_ch * (2 ** (lvl + 1)), 512) for lvl in range(-1, n_levels)}

    self.in_conv = nn.Conv2D(in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)

    for i, (kernel_size, strides) in enumerate(layers):
        self.convs.append(
            nn.Conv2D(level_chs[i - 1], level_chs[i],
                      kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype))

        # Every up-conv except the deepest one takes twice the level width
        # (presumably the upsampled tensor concatenated with the encoder
        # skip feature in forward() -- NOTE(review): confirm). insert(0, ...)
        # leaves self.upconvs ordered deepest level first.
        up_in_ch = level_chs[i] * (2 if i != deepest else 1)
        self.upconvs.insert(0,
            nn.Conv2DTranspose(up_in_ch, level_chs[i - 1],
                               kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype))

    # Input width is twice level -1 (presumably in_conv output concatenated
    # with the last up-conv output -- NOTE(review): confirm in forward()).
    self.out_conv = nn.Conv2D(level_chs[-1] * 2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)

    # 1x1 convolutions operating at the deepest level's width.
    self.center_out = nn.Conv2D(level_chs[deepest], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
    self.center_conv = nn.Conv2D(level_chs[deepest], level_chs[deepest], kernel_size=1, padding='VALID', dtype=conv_dtype)
def on_build(self, in_ch, out_ch):
    """Create a 3x3 transposed convolution plus FRNorm2D / TLU layers.

    Args:
        in_ch: number of channels fed to the transposed convolution.
        out_ch: number of channels it produces; the FRNorm2D and TLU layers
            are built with this same channel count.
    """
    produced_ch = out_ch
    self.conv = nn.Conv2DTranspose(in_ch, produced_ch, kernel_size=3, padding='SAME')
    # Both follow-up layers are sized to the transposed conv's output width.
    self.frn = nn.FRNorm2D(produced_ch)
    self.tlu = nn.TLU(produced_ch)
def on_build(self, in_ch, base_ch):
    """Create the layers of a BlurPool-downsampling encoder / transposed-conv decoder.

    Args:
        in_ch: number of channels of the input to ``features_0``.
        base_ch: base channel width. Encoder layers use 1x, 2x, 4x and 8x this
            width, ``conv5_up`` produces ``base_ch // 2`` channels, and
            ``out_conv`` maps ``base_ch`` channels down to 1.

    Attribute names and the order in which they are assigned are identical to
    the original layout. NOTE(review): they may determine layer/weight naming
    and ordering in ``nn.ModelBase``, so do not rename or reorder them.
    """
    b = base_ch

    def conv(cin, cout):
        # Every regular convolution here is 3x3 with 'SAME' padding.
        return nn.Conv2D(cin, cout, kernel_size=3, padding='SAME')

    def up(cin, cout):
        # Every transposed convolution here is 3x3 with 'SAME' padding.
        return nn.Conv2DTranspose(cin, cout, kernel_size=3, padding='SAME')

    def blur():
        return nn.BlurPool(filt_size=3)

    # --- encoder: conv stages, each closed by a BlurPool layer ---
    self.features_0 = conv(in_ch, b)
    self.blurpool_0 = blur()

    self.features_3 = conv(b, b * 2)
    self.blurpool_3 = blur()

    self.features_6 = conv(b * 2, b * 4)
    self.features_8 = conv(b * 4, b * 4)
    self.blurpool_8 = blur()

    self.features_11 = conv(b * 4, b * 8)
    self.features_13 = conv(b * 8, b * 8)
    self.blurpool_13 = blur()

    self.features_16 = conv(b * 8, b * 8)
    self.features_18 = conv(b * 8, b * 8)
    self.blurpool_18 = blur()

    self.conv_center = conv(b * 8, b * 8)

    # --- decoder: each convN input width equals convN_up's output width plus
    # an encoder stage width (12b = 4b + 8b, 6b = 2b + 4b, 3b = b + 2b,
    # b//2 + b), consistent with a skip concatenation in forward() --
    # NOTE(review): confirm there.
    self.conv1_up = up(b * 8, b * 4)
    self.conv1 = conv(b * 12, b * 8)

    self.conv2_up = up(b * 8, b * 4)
    self.conv2 = conv(b * 12, b * 8)

    self.conv3_up = up(b * 8, b * 2)
    self.conv3 = conv(b * 6, b * 4)

    self.conv4_up = up(b * 4, b)
    self.conv4 = conv(b * 3, b * 2)

    self.conv5_up = up(b * 2, b // 2)
    self.conv5 = conv(b // 2 + b, b)

    self.out_conv = conv(b, 1)