def on_build(self, in_ch, out_ch):
    """Create the sublayers of a single 3x3 conv -> FRNorm2D -> TLU unit.

    Args:
        in_ch: number of input channels fed to the convolution.
        out_ch: number of channels produced by the convolution; the
            normalization and activation layers are sized to match it.

    Only layer objects are created here; how they are chained is decided
    by the forward pass, which is not part of this method.
    """
    channels = out_ch
    # 3x3 kernel with 'SAME' padding; use_bias is left at the nn.Conv2D
    # default (not visible here).
    self.conv = nn.Conv2D(in_ch, channels, kernel_size=3, padding='SAME')
    # Built and assigned left-to-right: FRNorm2D first, then TLU.
    self.frn, self.tlu = nn.FRNorm2D(channels), nn.TLU(channels)
def on_build(self, in_ch, base_ch, out_ch=None):
    """Create the sublayers of an 8-stage convolutional encoder.

    Stage widths are base_ch * (1, 1, 2, 2, 4, 4, 8, 8). Stages 2, 4 and 6
    use stride 2; all others use stride 1. Every stage has a conv and a TLU.
    Every stage except the first also has an FRNorm2D; its slot in
    ``self.frns`` is None.

    Args:
        in_ch: channels of the encoder input.
        base_ch: base channel width that the stage multipliers scale.
        out_ch: if not None, a 1x1 'VALID' conv mapping base_ch*8 -> out_ch
            is created as ``self.out_conv``; otherwise ``self.out_conv`` is
            None.
    """
    widths = [base_ch * mult for mult in (1, 1, 2, 2, 4, 4, 8, 8)]
    strides = (1, 1, 2, 1, 2, 1, 2, 1)
    # Each stage consumes the previous stage's width; stage 0 consumes in_ch.
    sources = [in_ch] + widths[:-1]

    convs = []
    for idx, (src, dst, stride) in enumerate(zip(sources, widths, strides)):
        if idx == 0:
            # Stem: 7x7 kernel and no explicit use_bias, so it gets the
            # nn.Conv2D default. No FRNorm2D follows this stage.
            convs.append(nn.Conv2D(src, dst, kernel_size=7, strides=stride, padding='SAME'))
        else:
            convs.append(nn.Conv2D(src, dst, kernel_size=3, strides=stride, use_bias=False, padding='SAME'))
    self.convs = convs

    # The stem has no normalization layer; its slot is a None placeholder so
    # the list stays index-aligned with convs and tlus.
    self.frns = [None] + [nn.FRNorm2D(width) for width in widths[1:]]
    self.tlus = [nn.TLU(width) for width in widths]

    self.out_conv = (
        nn.Conv2D(widths[-1], out_ch, kernel_size=1, strides=1, use_bias=False, padding='VALID')
        if out_ch is not None
        else None
    )
def on_build(self, lmrks_ch, base_ch, out_ch):
    """Create the sublayers of a 7-stage convolutional decoder.

    The first stage takes base_ch*8 + lmrks_ch channels. Stages flagged in
    ``self.use_upscale`` emit 4x the channels that the next stage consumes.
    NOTE(review): this presumably pairs with a depth-to-space /
    pixel-shuffle step in the forward pass, which is not visible here.
    Confirm before relying on it.

    Every stage is a 3x3, stride-1, 'SAME'-padded conv (use_bias=False),
    followed by an FRNorm2D and a TLU sized to that conv's output width. A
    final 3x3 conv maps base_ch -> out_ch.

    Args:
        lmrks_ch: extra input channels added to the first stage's input.
        base_ch: base channel width that the stage widths are built from.
        out_ch: channels produced by ``self.out_conv``.
    """
    # (input channels, output channels, upscale flag) for each stage.
    stages = [
        (base_ch*8 + lmrks_ch, base_ch*8,   False),
        (base_ch*8,            base_ch*8*4, True),
        (base_ch*8,            base_ch*4,   False),
        (base_ch*4,            base_ch*4*4, True),
        (base_ch*4,            base_ch*2,   False),
        (base_ch*2,            base_ch*2*4, True),
        (base_ch*2,            base_ch,     False),
    ]

    self.convs = [
        nn.Conv2D(src, dst, kernel_size=3, strides=1, use_bias=False, padding='SAME')
        for src, dst, _ in stages
    ]
    self.frns = [nn.FRNorm2D(dst) for _, dst, _ in stages]
    self.tlus = [nn.TLU(dst) for _, dst, _ in stages]
    self.use_upscale = [upscale for _, _, upscale in stages]

    # Output head; use_bias is left at the nn.Conv2D default.
    self.out_conv = nn.Conv2D(base_ch, out_ch, kernel_size=3, strides=1, padding='SAME')