示例#1
0
 def forward(self, inp):
     """Encode ``inp`` into a flat feature vector.

     When ``self.is_hd`` is truthy, the outputs of the four parallel branches
     ``down1``..``down4`` (each applied to ``inp``) are flattened individually
     and concatenated along the last axis. Otherwise only ``down1`` is used.
     """
     if not self.is_hd:
         return nn.flatten(self.down1(inp))
     branches = (self.down1, self.down2, self.down3, self.down4)
     # Each branch is run and flattened before the next one starts.
     flat_parts = [nn.flatten(branch(inp)) for branch in branches]
     return tf.concat(flat_parts, -1)
 def forward(self, x):
     """Apply ``down1`` to ``x`` and return the flattened result.

     ``use_fp16`` is a name resolved outside this method. When it is truthy,
     the input is cast to float16 before ``down1`` and the flattened output is
     cast back to float32.
     """
     inp = tf.cast(x, tf.float16) if use_fp16 else x
     flat = nn.flatten(self.down1(inp))
     return tf.cast(flat, tf.float32) if use_fp16 else flat
示例#3
0
            def forward(self, inp):
                """Run ``inp`` through five sub-layer stages and flatten the result.

                The stages hold 2, 2, 3, 3 and 3 sub-layers (``down11`` .. ``down53``).
                Each stage is followed by ``nn.max_pool``.
                """
                stages = (
                    (self.down11, self.down12),
                    (self.down21, self.down22),
                    (self.down31, self.down32, self.down33),
                    (self.down41, self.down42, self.down43),
                    (self.down51, self.down52, self.down53),
                )
                feat = inp
                for stage_layers in stages:
                    for layer in stage_layers:
                        feat = layer(feat)
                    feat = nn.max_pool(feat)
                return nn.flatten(feat)
示例#4
0
 def forward(self, inp):
     """Encode ``inp`` with the down/res stack, then project it with ``dense1``.

     Layers run in the order down1, res1, down2, down3, down4, down5, res5.
     The result is flattened, pixel-normalized over the last axis
     (``nn.pixel_norm(..., axes=-1)``) and passed to ``self.dense1``.
     """
     stack = (self.down1, self.res1, self.down2, self.down3,
              self.down4, self.down5, self.res5)
     h = inp
     for layer in stack:
         h = layer(h)
     h = nn.pixel_norm(nn.flatten(h), axes=-1)
     return self.dense1(h)
示例#5
0
            def forward(self, inp, stage):
                """Progressively encode ``inp`` starting at level ``stage``.

                If ``stage`` is non-negative and is a key of ``self.enc_blocks``, the
                input is first mapped by ``self.from_rgbs[stage]``. It is then passed
                through ``enc_blocks[stage]``, ``enc_blocks[stage - 1]``, ...,
                ``enc_blocks[0]``. Otherwise the input is left untouched.
                The result is flattened, passed through ``self.dense_norm`` and
                reshaped with ``nn.reshape_4D(..., 1, 1, self.max_ch)``.
                """
                levels = range(stage, -1, -1)
                h = inp
                # The membership test depends only on ``stage``, so it is evaluated
                # once. ``from_rgbs`` only ever applies at the first (top) level.
                if levels and stage in self.enc_blocks:
                    h = self.from_rgbs[stage](h)
                    for level in levels:
                        h = self.enc_blocks[level](h)
                h = nn.flatten(h)
                h = self.dense_norm(h)
                return nn.reshape_4D(h, 1, 1, self.max_ch)
示例#6
0
 def forward(self, x):
     """Encode ``x`` with the down/res stack, then project it with ``dense1``.

     ``use_fp16`` is a name resolved outside this method. When it is truthy,
     the stack down1, res1, down2..down5, res5 runs on a float16 cast of the
     input, and the result is cast back to float32 before flattening. The
     flattened features are pixel-normalized over the last axis and passed to
     ``self.dense1``.
     """
     h = tf.cast(x, tf.float16) if use_fp16 else x
     for layer in (self.down1, self.res1, self.down2, self.down3,
                   self.down4, self.down5, self.res5):
         h = layer(h)
     if use_fp16:
         h = tf.cast(h, tf.float32)
     h = nn.flatten(h)
     h = nn.pixel_norm(h, axes=-1)
     return self.dense1(h)
示例#7
0
            def forward(self, stage, inp, prev_inp=None, alpha=None):
                """Progressively encode ``inp`` starting at level ``stage``.

                When ``stage`` is a key of ``self.from_rgbs``:
                  * ``self.from_rgbs[stage]`` maps ``inp`` into feature space;
                  * for every level from ``stage`` down to 1,
                    ``self.enc_blocks[level]`` is applied (level 0 gets no
                    encoder block here);
                  * fade-in: before ``enc_blocks[stage - 1]`` runs, the features are
                    blended as ``x * alpha + from_rgbs[stage - 1](prev_inp) * (1 - alpha)``.

                ``alpha=None`` skips the blend, which is equivalent to ``alpha == 1``.
                This keeps the default arguments usable for ``stage >= 1`` instead of
                failing on ``x * None``. When ``alpha`` is given, ``prev_inp`` must be
                given too. NOTE(review): ``prev_inp`` is presumably the input at the
                next-lower resolution; confirm against callers.

                The features are then flattened, passed through ``dense_norm``,
                ``ae_dense1`` and ``ae_dense2``, and reshaped with
                ``nn.reshape_4D(x, ae_res, ae_res, ae_ch)``. ``ae_res`` and ``ae_ch``
                are names resolved outside this method.
                """
                x = inp

                # Loop-invariant: whether ``stage`` has a from-RGB layer at all.
                if stage in self.from_rgbs:
                    for level in range(stage, -1, -1):
                        if level == stage:
                            x = self.from_rgbs[level](x)
                        elif level == stage - 1 and alpha is not None:
                            x = x * alpha + self.from_rgbs[level](prev_inp) * (1 - alpha)

                        if level != 0:
                            x = self.enc_blocks[level](x)

                x = nn.flatten(x)
                x = self.dense_norm(x)
                x = self.ae_dense1(x)
                x = self.ae_dense2(x)
                x = nn.reshape_4D(x, ae_res, ae_res, ae_ch)

                return x
示例#8
0
                def forward(self, x):
                    """Encode ``x``; the layer stack and normalization depend on ``opts``.

                    ``use_fp16`` and ``opts`` are names resolved outside this method.
                    * With ``'t'`` in ``opts`` the stack is down1, res1, down2..down5,
                      res5; otherwise only down1 is applied.
                    * With ``'u'`` in ``opts`` the flattened output is pixel-normalized
                      over the last axis.
                    * When ``use_fp16`` is truthy, the work runs on a float16 cast of
                      the input, and the result is cast back to float32 at the end.
                    """
                    if use_fp16:
                        x = tf.cast(x, tf.float16)

                    if 't' in opts:
                        stack = (self.down1, self.res1, self.down2, self.down3,
                                 self.down4, self.down5, self.res5)
                    else:
                        stack = (self.down1,)
                    for layer in stack:
                        x = layer(x)

                    x = nn.flatten(x)
                    if 'u' in opts:
                        x = nn.pixel_norm(x, axes=-1)

                    return tf.cast(x, tf.float32) if use_fp16 else x
示例#9
0
 def forward(self, inp):
     """Apply ``self.down1`` to ``inp`` and return the flattened result."""
     downsampled = self.down1(inp)
     return nn.flatten(downsampled)
示例#10
0
 def forward(self, inp):
     """Return ``nn.flatten(self.down1(inp))``."""
     features = self.down1(inp)
     flat = nn.flatten(features)
     return flat
示例#11
0
    def forward(self, inp):
        """Encoder/decoder pass with skip connections; returns ``(logits, sigmoid(logits))``.

        Encoder: six stages (``conv0*`` .. ``conv5*``). The output of each
        stage's last layer is kept as a skip tensor, and then the stage's
        ``bp*`` layer is applied.

        Bottleneck: flatten, ``dense1``, ``dense2``, then
        ``nn.reshape_4D(..., 4, 4, self.base_ch * 8)``.

        Decoder: six stages, deepest first. Each stage applies ``up*``, feeds
        the result concatenated (on ``nn.conv2d_ch_axis``) with the matching
        skip tensor into its first ``uconv*`` layer, then runs the remaining
        ``uconv*`` layers of that stage. ``out_conv`` produces the logits.
        """
        encoder = (
            ((self.conv01, self.conv02), self.bp0),
            ((self.conv11, self.conv12), self.bp1),
            ((self.conv21, self.conv22), self.bp2),
            ((self.conv31, self.conv32, self.conv33), self.bp3),
            ((self.conv41, self.conv42, self.conv43), self.bp4),
            ((self.conv51, self.conv52, self.conv53), self.bp5),
        )

        h = inp
        skips = []
        for stage_layers, reducer in encoder:
            for layer in stage_layers:
                h = layer(h)
            skips.append(h)  # skip tensor is taken before the bp* layer
            h = reducer(h)

        h = nn.flatten(h)
        h = self.dense1(h)
        h = self.dense2(h)
        h = nn.reshape_4D(h, 4, 4, self.base_ch * 8)

        decoder = (
            (self.up5, (self.uconv53, self.uconv52, self.uconv51)),
            (self.up4, (self.uconv43, self.uconv42, self.uconv41)),
            (self.up3, (self.uconv33, self.uconv32, self.uconv31)),
            (self.up2, (self.uconv22, self.uconv21)),
            (self.up1, (self.uconv12, self.uconv11)),
            (self.up0, (self.uconv02, self.uconv01)),
        )
        # Decoder stages pair with the skips deepest-first (x5 .. x0).
        for (upsample, (merge, *rest)), skip in zip(decoder, reversed(skips)):
            h = upsample(h)
            h = merge(tf.concat([h, skip], axis=nn.conv2d_ch_axis))
            for layer in rest:
                h = layer(h)

        logits = self.out_conv(h)
        return logits, tf.nn.sigmoid(logits)
示例#12
0
    def forward(self, inp, pretrain=False):
        """Encoder/decoder pass with skip connections; returns ``(logits, sigmoid(logits))``.

        Encoder: six stages (``conv0*`` .. ``conv5*``). The output of each
        stage's last layer is kept as a skip tensor, and then the stage's
        ``bp*`` layer is applied.

        Bottleneck: flatten, ``dense1``, ``dense2``, then
        ``nn.reshape_4D(..., 4, 4, self.base_ch * 8)``.

        Decoder: six stages, deepest first. Each stage applies ``up*``, feeds
        the result concatenated (on ``nn.conv2d_ch_axis``) with the matching
        skip tensor into its first ``uconv*`` layer, then runs the remaining
        ``uconv*`` layers. When ``pretrain`` is truthy, every skip tensor is
        replaced by ``tf.zeros_like`` of itself before concatenation.
        ``out_conv`` produces the logits.
        """
        encoder = (
            ((self.conv01, self.conv02), self.bp0),
            ((self.conv11, self.conv12), self.bp1),
            ((self.conv21, self.conv22), self.bp2),
            ((self.conv31, self.conv32, self.conv33), self.bp3),
            ((self.conv41, self.conv42, self.conv43), self.bp4),
            ((self.conv51, self.conv52, self.conv53), self.bp5),
        )

        h = inp
        skips = []
        for stage_layers, reducer in encoder:
            for layer in stage_layers:
                h = layer(h)
            skips.append(h)  # skip tensor is taken before the bp* layer
            h = reducer(h)

        h = nn.flatten(h)
        h = self.dense1(h)
        h = self.dense2(h)
        h = nn.reshape_4D(h, 4, 4, self.base_ch * 8)

        decoder = (
            (self.up5, (self.uconv53, self.uconv52, self.uconv51)),
            (self.up4, (self.uconv43, self.uconv42, self.uconv41)),
            (self.up3, (self.uconv33, self.uconv32, self.uconv31)),
            (self.up2, (self.uconv22, self.uconv21)),
            (self.up1, (self.uconv12, self.uconv11)),
            (self.up0, (self.uconv02, self.uconv01)),
        )
        # Decoder stages pair with the skips deepest-first (x5 .. x0).
        for (upsample, (merge, *rest)), skip in zip(decoder, reversed(skips)):
            h = upsample(h)
            if pretrain:
                # Blank the skip path; presumably so pretraining relies on the
                # bottleneck alone. Confirm intent with the training code.
                skip = tf.zeros_like(skip)
            h = merge(tf.concat([h, skip], axis=nn.conv2d_ch_axis))
            for layer in rest:
                h = layer(h)

        logits = self.out_conv(h)
        return logits, tf.nn.sigmoid(logits)