Пример #1
0
                def forward(self, inp):
                    """Compute the two outputs ``(x, m)`` of this module from ``inp``.

                    The ``x`` branch applies ``upscale0..2`` interleaved with
                    ``res0..2``. When ``'d'`` is in ``opts``, four output convs are
                    each passed through a sigmoid and ``nn.upsample2d``, then
                    multiplied by a tiled 2x2 one-hot pattern and summed, so every
                    position inside each 2x2 cell takes its value from exactly one
                    of the four heads. Otherwise a single ``out_conv`` + sigmoid
                    is used.

                    The ``m`` branch applies ``upscalem0..2`` (plus ``upscalem3``
                    when ``'d'`` is in ``opts``) followed by ``out_convm`` + sigmoid.

                    NOTE(review): ``opts``, ``resolution``, ``nn`` and ``tf`` come
                    from the enclosing scope, which is not visible here. The tiled
                    patterns span ``2 * (resolution // 2)`` along each spatial
                    axis, which presumably matches the upsampled head size; confirm.

                    Returns:
                        Tuple ``(x, m)`` of the two sigmoid-activated outputs.
                    """
                    # First branch: alternate upscale and residual stages.
                    feat = self.res0(self.upscale0(inp))
                    feat = self.res1(self.upscale1(feat))
                    feat = self.res2(self.upscale2(feat))

                    if 'd' not in opts:
                        out = tf.nn.sigmoid(self.out_conv(feat))
                    else:
                        # One sigmoid + upsample per output head, in head order.
                        heads = []
                        for conv in (self.out_conv, self.out_conv1,
                                     self.out_conv2, self.out_conv3):
                            head = tf.nn.sigmoid(conv(feat))
                            heads.append(nn.upsample2d(head))

                        # Repeat the 2x2 pattern along the two spatial axes only.
                        half = resolution // 2
                        if nn.data_format == "NHWC":
                            tile_cfg = (1, half, half, 1)
                        else:
                            tile_cfg = (1, 1, half, half)

                        row_axis = nn.conv2d_spatial_axes[0]
                        col_axis = nn.conv2d_spatial_axes[1]

                        def unit_cell(hot):
                            # Single-element tensor: 1 for the selected cell, else 0.
                            if hot:
                                return tf.ones((1, 1, 1, 1))
                            return tf.zeros((1, 1, 1, 1))

                        def tiled_one_hot(hot_row, hot_col):
                            # Build the 2x2 pattern row by row, then tile it.
                            rows = tuple(
                                tf.concat(
                                    tuple(unit_cell((r, c) == (hot_row, hot_col))
                                          for c in range(2)),
                                    axis=col_axis)
                                for r in range(2)
                            )
                            return tf.tile(tf.concat(rows, axis=row_axis), tile_cfg)

                        # Head k owns cell (k // 2, k % 2) of every 2x2 block.
                        masks = [tiled_one_hot(r, c)
                                 for r, c in ((0, 0), (0, 1), (1, 0), (1, 1))]

                        # Accumulate left to right, matching h0*m0 + h1*m1 + ...
                        out = heads[0] * masks[0]
                        for head, mask in zip(heads[1:], masks[1:]):
                            out = out + head * mask

                    # Second branch, fed from the same input.
                    m = self.upscalem2(self.upscalem1(self.upscalem0(inp)))
                    if 'd' in opts:
                        m = self.upscalem3(m)
                    m = tf.nn.sigmoid(self.out_convm(m))

                    return out, m
Пример #2
0
            def forward(self, x):
                """Apply two conv + leaky-ReLU stages, upsampling first if needed.

                The input goes through ``nn.upsample2d`` unless
                ``self.zero_level`` is truthy. ``conv1`` and ``conv2`` are then
                applied in turn, each followed by a leaky ReLU with slope 0.2.

                Returns:
                    The activated output of the second convolution.
                """
                slope = 0.2

                # The zero level works on the input as-is; other levels upsample.
                h = x if self.zero_level else nn.upsample2d(x)

                h = self.conv1(h)
                h = tf.nn.leaky_relu(h, slope)
                h = self.conv2(h)
                h = tf.nn.leaky_relu(h, slope)
                return h
Пример #3
0
            def forward(self, input):
                """Combine a full-resolution path with a pooled-and-upsampled path.

                ``b1`` is applied to the input directly. In parallel the input is
                average-pooled with a 2x2 window and stride 2, passed through
                ``b2``, ``b2_plus`` and ``b3``, and brought back up with
                ``nn.upsample2d``. The two results are added.

                NOTE(review): the pooling window ``[1, 2, 2, 1]`` pools over axes
                1 and 2, i.e. the spatial axes of an NHWC tensor; confirm the
                input layout with the caller.

                Returns:
                    Sum of the ``b1`` output and the upsampled low-resolution output.
                """
                pool_ksize = [1, 2, 2, 1]
                pool_strides = [1, 2, 2, 1]

                # Full-resolution path.
                skip = self.b1(input)

                # Low-resolution path: pool, process, then upsample back.
                pooled = tf.nn.avg_pool(input, pool_ksize, pool_strides, 'VALID')
                deep = self.b3(self.b2_plus(self.b2(pooled)))
                restored = nn.upsample2d(deep)

                return skip + restored
Пример #4
0
            def forward(self, stage, inp, alpha=None, inter=None):
                """Decode ``inp`` up to ``stage`` and emit that stage's output pair.

                Levels ``0..stage`` are visited in order; levels missing from
                ``self.to_rgbs`` are skipped entirely. For each visited level
                above 0 the matching ``self.dec_blocks`` entry is applied. At the
                final level ``self.to_rgbs[stage]`` produces a pair. When
                ``stage > 0`` that pair is blended with the upsampled output of
                ``self.to_rgbs[stage - 1]`` (taken before the final decoder
                block) as ``new * alpha + previous * (1 - alpha)``.

                Args:
                    stage: Index of the last level to run.
                    inp: Input fed to the first level.
                    alpha: Blend weight; only used when ``stage > 0``.
                    inter: Accepted for interface compatibility; not used here.

                Returns:
                    The ``(x, xm)`` pair from the final level when ``stage`` is in
                    ``self.to_rgbs``; otherwise the running value alone.
                """
                out = inp
                blending = stage > 0

                for lvl in range(stage + 1):
                    if lvl not in self.to_rgbs:
                        continue

                    is_final = lvl == stage

                    if is_final and blending:
                        # Render the previous level's output from the features
                        # as they are before this level's decoder block runs.
                        prev_out, prev_out_m = self.to_rgbs[lvl - 1](out)
                        prev_out = nn.upsample2d(prev_out)
                        prev_out_m = nn.upsample2d(prev_out_m)

                    if lvl != 0:
                        out = self.dec_blocks[lvl](out)

                    if not is_final:
                        continue

                    out, out_m = self.to_rgbs[lvl](out)
                    if blending:
                        out = out * alpha + prev_out * (1 - alpha)
                        out_m = out_m * alpha + prev_out_m * (1 - alpha)
                    return out, out_m

                return out