Example #1
0
 def forward(self, inp):
     """Map ``inp`` to an upscaled 4-D tensor.

     ``inp`` is passed through ``dense1`` then ``dense2``, reshaped with
     ``nn.reshape_4D`` using (``lowest_dense_res``, ``lowest_dense_res``,
     ``self.ae_out_ch``) -- presumably height, width, channels -- and finally
     fed to ``upscale1``. ``lowest_dense_res`` is a free name resolved outside
     this method.
     """
     projected = self.dense2(self.dense1(inp))
     grid = nn.reshape_4D(projected, lowest_dense_res, lowest_dense_res,
                          self.ae_out_ch)
     return self.upscale1(grid)
                def forward(self, inp):
                    """Map ``inp`` to an upscaled 4-D tensor.

                    Steps, in order:
                    ``dense_norm`` (only when ``'u'`` is in ``opts``), ``dense1``,
                    ``dense2``, ``nn.reshape_4D`` with (``lowest_dense_res``,
                    ``lowest_dense_res``, ``self.ae_out_ch``), a cast to
                    ``tf.float16`` (only when ``use_fp16`` is truthy), and finally
                    ``upscale1``. ``opts``, ``use_fp16`` and ``lowest_dense_res``
                    are free names resolved outside this method.
                    """
                    h = self.dense_norm(inp) if 'u' in opts else inp
                    h = self.dense2(self.dense1(h))
                    h = nn.reshape_4D(h, lowest_dense_res, lowest_dense_res,
                                      self.ae_out_ch)
                    if use_fp16:
                        # The upscale stage then receives a half-precision tensor.
                        h = tf.cast(h, tf.float16)
                    return self.upscale1(h)
                def forward(self, inp):
                    """Map ``inp`` to an upscaled 4-D tensor.

                    Applies ``dense_norm`` first when ``'u'`` is in ``opts``. Then
                    every layer of ``self.dense`` runs in order, each followed by
                    ``tf.nn.leaky_relu`` with negative slope 0.1. The result is
                    reshaped with ``nn.reshape_4D`` using (``lowest_dense_res``,
                    ``lowest_dense_res``, ``self.ae_out_ch``) and passed to
                    ``upscale1``. ``opts`` and ``lowest_dense_res`` are free names
                    resolved outside this method.
                    """
                    h = inp
                    if 'u' in opts:
                        h = self.dense_norm(h)

                    for layer in self.dense:
                        h = tf.nn.leaky_relu(layer(h), 0.1)

                    h = nn.reshape_4D(h, lowest_dense_res, lowest_dense_res,
                                      self.ae_out_ch)
                    return self.upscale1(h)
Example #4
0
            def forward(self, inp, stage):
                """Encode ``inp`` starting at level ``stage``.

                When ``stage in self.enc_blocks``, the levels ``stage`` down to 0
                are visited in order. At the first level the tensor is converted
                by ``self.from_rgbs[stage]``, and every visited level applies
                ``self.enc_blocks[level]``. Otherwise ``inp`` is used unchanged.
                The result is flattened, passed through ``dense_norm`` and
                reshaped with ``nn.reshape_4D`` to (1, 1, ``self.max_ch``).
                """
                levels = range(stage, -1, -1)
                h = inp
                # NOTE(review): this membership test depends only on ``stage``,
                # so it is evaluated once instead of on every level. If a
                # per-level guard was intended, it would need to test ``level``.
                if stage in self.enc_blocks:
                    for level in levels:
                        if level == stage:
                            h = self.from_rgbs[level](h)
                        h = self.enc_blocks[level](h)

                h = self.dense_norm(nn.flatten(h))
                return nn.reshape_4D(h, 1, 1, self.max_ch)
Example #5
0
            def forward(self, stage, inp, prev_inp=None, alpha=None):
                """Encode ``inp`` at level ``stage``, blending in ``prev_inp``.

                When ``stage in self.from_rgbs``, the levels ``stage`` down to 0
                are visited:

                * at ``stage`` the tensor is converted by ``self.from_rgbs[stage]``;
                * at ``stage - 1`` it is linearly blended with
                  ``self.from_rgbs[stage - 1](prev_inp)`` using weights ``alpha``
                  and ``1 - alpha``;
                * every non-zero level then applies ``self.enc_blocks[level]``.

                The result is flattened and passed through ``dense_norm``,
                ``ae_dense1`` and ``ae_dense2``. It is then reshaped with
                ``nn.reshape_4D`` to (``ae_res``, ``ae_res``, ``ae_ch``), which
                are free names resolved outside this method.

                NOTE(review): when ``stage >= 1`` and ``stage in self.from_rgbs``,
                level ``stage - 1`` is always visited. ``alpha`` and ``prev_inp``
                must therefore be supplied; with the ``None`` defaults the blend
                presumably fails.
                """
                levels = range(stage, -1, -1)
                h = inp
                if stage in self.from_rgbs:
                    for level in levels:
                        if level == stage:
                            h = self.from_rgbs[level](h)
                        elif level == stage - 1:
                            # Linear blend of the current path with the
                            # previous-level input, converted by its own from_rgb.
                            h = (h * alpha
                                 + self.from_rgbs[level](prev_inp) * (1 - alpha))

                        if level != 0:
                            # Level 0 gets no encoder block in this method.
                            h = self.enc_blocks[level](h)

                h = self.dense_norm(nn.flatten(h))
                h = self.ae_dense2(self.ae_dense1(h))
                return nn.reshape_4D(h, ae_res, ae_res, ae_ch)
Example #6
0
 def forward(self, inp):
     """Apply ``dense2`` to ``inp`` and reshape with ``nn.reshape_4D``.

     The reshape uses (``lowest_dense_res``, ``lowest_dense_res``,
     ``self.ae_out_ch``). ``lowest_dense_res`` is a free name resolved outside
     this method.
     """
     out = self.dense2(inp)
     return nn.reshape_4D(out, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
Example #7
0
 def forward(self, inp):
     """Apply ``dense2`` to ``inp`` and reshape with ``nn.reshape_4D``.

     The reshape uses (``inter_res``, ``inter_res``, ``inter_dims``), which are
     free names resolved outside this method.
     """
     out = self.dense2(inp)
     return nn.reshape_4D(out, inter_res, inter_res, inter_dims)
Example #8
0
    def forward(self, inp):
        """Run the encoder/decoder network; return ``(logits, sigmoid(logits))``.

        Encoder: six stages of convolutions (``conv01`` .. ``conv53``). The
        output of each stage is kept as a skip tensor before ``bp0`` .. ``bp5``
        is applied.
        Bottleneck: ``nn.flatten``, ``dense1``, ``dense2``, then
        ``nn.reshape_4D`` to (4, 4, ``self.base_ch * 8``).
        Decoder: for stages 5 down to 0, the tensor is upsampled with ``upN``,
        concatenated with the matching skip tensor along
        ``nn.conv2d_ch_axis``, and refined by the ``uconvN*`` layers.
        ``out_conv`` produces the logits.
        """
        def chain(t, *layers):
            # Apply ``layers`` to ``t`` from left to right.
            for layer in layers:
                t = layer(t)
            return t

        def join(t, skip):
            # Concatenate the upsampled tensor with its skip tensor.
            return tf.concat([t, skip], axis=nn.conv2d_ch_axis)

        # Encoder: remember each stage's output for the skip connections.
        s0 = chain(inp, self.conv01, self.conv02)
        s1 = chain(self.bp0(s0), self.conv11, self.conv12)
        s2 = chain(self.bp1(s1), self.conv21, self.conv22)
        s3 = chain(self.bp2(s2), self.conv31, self.conv32, self.conv33)
        s4 = chain(self.bp3(s3), self.conv41, self.conv42, self.conv43)
        s5 = chain(self.bp4(s4), self.conv51, self.conv52, self.conv53)

        # Bottleneck.
        h = nn.flatten(self.bp5(s5))
        h = self.dense2(self.dense1(h))
        h = nn.reshape_4D(h, 4, 4, self.base_ch * 8)

        # Decoder: upsample, merge with the matching skip, refine.
        h = chain(join(self.up5(h), s5), self.uconv53, self.uconv52, self.uconv51)
        h = chain(join(self.up4(h), s4), self.uconv43, self.uconv42, self.uconv41)
        h = chain(join(self.up3(h), s3), self.uconv33, self.uconv32, self.uconv31)
        h = chain(join(self.up2(h), s2), self.uconv22, self.uconv21)
        h = chain(join(self.up1(h), s1), self.uconv12, self.uconv11)
        h = chain(join(self.up0(h), s0), self.uconv02, self.uconv01)

        logits = self.out_conv(h)
        return logits, tf.nn.sigmoid(logits)
Example #9
0
    def forward(self, inp, pretrain=False):
        """Run the encoder/decoder network; return ``(logits, sigmoid(logits))``.

        Encoder: six stages of convolutions (``conv01`` .. ``conv53``). The
        output of each stage is kept as a skip tensor before ``bp0`` .. ``bp5``
        is applied.
        Bottleneck: ``nn.flatten``, ``dense1``, ``dense2``, then
        ``nn.reshape_4D`` to (4, 4, ``self.base_ch * 8``).
        Decoder: for stages 5 down to 0, the tensor is upsampled with ``upN``,
        concatenated with the matching skip tensor along
        ``nn.conv2d_ch_axis``, and refined by the ``uconvN*`` layers.
        ``out_conv`` produces the logits.

        When ``pretrain`` is truthy, each skip tensor is replaced by
        ``tf.zeros_like`` of itself right before concatenation. The decoder
        then receives no information through the skip connections.
        """
        def chain(t, *layers):
            # Apply ``layers`` to ``t`` from left to right.
            for layer in layers:
                t = layer(t)
            return t

        def join(t, skip):
            # Concatenate the upsampled tensor with its (possibly zeroed) skip.
            if pretrain:
                skip = tf.zeros_like(skip)
            return tf.concat([t, skip], axis=nn.conv2d_ch_axis)

        # Encoder: remember each stage's output for the skip connections.
        s0 = chain(inp, self.conv01, self.conv02)
        s1 = chain(self.bp0(s0), self.conv11, self.conv12)
        s2 = chain(self.bp1(s1), self.conv21, self.conv22)
        s3 = chain(self.bp2(s2), self.conv31, self.conv32, self.conv33)
        s4 = chain(self.bp3(s3), self.conv41, self.conv42, self.conv43)
        s5 = chain(self.bp4(s4), self.conv51, self.conv52, self.conv53)

        # Bottleneck.
        h = nn.flatten(self.bp5(s5))
        h = self.dense2(self.dense1(h))
        h = nn.reshape_4D(h, 4, 4, self.base_ch * 8)

        # Decoder: upsample, merge with the matching skip, refine.
        h = chain(join(self.up5(h), s5), self.uconv53, self.uconv52, self.uconv51)
        h = chain(join(self.up4(h), s4), self.uconv43, self.uconv42, self.uconv41)
        h = chain(join(self.up3(h), s3), self.uconv33, self.uconv32, self.uconv31)
        h = chain(join(self.up2(h), s2), self.uconv22, self.uconv21)
        h = chain(join(self.up1(h), s1), self.uconv12, self.uconv11)
        h = chain(join(self.up0(h), s0), self.uconv02, self.uconv01)

        logits = self.out_conv(h)
        return logits, tf.nn.sigmoid(logits)