Exemplo n.º 1
0
    def discriminator(self, x, reuse=None):
        """Spectral-normalised convolutional discriminator with self-attention.

        :param x: input images (layout assumed to match the conv helpers in
            ``t`` -- TODO confirm against caller)
        :param reuse: forwarded to ``tf.variable_scope`` and to
            ``self.attention`` so weights can be shared between calls
        :return: output of a 1-unit spectral-normalised dense layer, with no
            activation applied
        """
        def _down(h, n_filters, idx):
            # 4x4 stride-2 SN conv, then LeakyReLU(0.1).
            h = t.conv2d_alt(h, n_filters, 4, 2, pad=1, sn=True, name='disc-conv2d-%d' % idx)
            return tf.nn.leaky_relu(h, alpha=0.1)

        with tf.variable_scope("discriminator", reuse=reuse):
            # NOTE(review): width comes from gf_dim (generator width) -- confirm intended.
            n_filters = self.gf_dim
            half = self.n_layer // 2

            x = _down(x, n_filters, 1)

            # First half of the stack: filter count doubles per layer.
            for layer in range(half):
                n_filters *= 2
                x = _down(x, n_filters, layer + 2)

            # Self-Attention Layer, inserted between the two halves.
            x = self.attention(x, n_filters, reuse=reuse)

            # Second half: numbering continues from the first half.
            for layer in range(half, self.n_layer):
                n_filters *= 2
                x = _down(x, n_filters, layer + 2)

            x = t.flatten(x)
            return t.dense_alt(x, 1, sn=True, name='disc-fc-1')
    def discriminator(self, x, reuse=None):
        """Residual spectral-normalised discriminator with self-attention.

        :param x: input images; NHWC layout is assumed (global sum pooling
            below sums over axes 1 and 2) -- TODO confirm against caller
        :param reuse: forwarded to ``tf.variable_scope`` so weights can be
            shared between calls
        :return: output of a 1-unit spectral-normalised dense layer, with no
            activation applied
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            f = 1 * self.channel

            x = self.res_block(x, f=f, scale_type="down", name="disc-res1")

            x = self.self_attention(x, f_=f)

            # Numbered disc-res2 .. disc-res5. Using `i + 1` here re-used the
            # name "disc-res1" from the block above, giving two blocks the same
            # name inside one variable scope.
            for i in range(4):
                f *= 2
                x = self.res_block(x, f=f, scale_type="down", name="disc-res%d" % (i + 2))

            x = self.res_block(x, f=f, scale_type="down", use_bn=False, name="disc-res6")
            x = tf.nn.relu(x)

            with tf.name_scope("global_sum_pooling"):
                # Sum over the spatial axes (H, W) -> (batch, channels).
                # Reducing over axis -1 instead collapsed the channels and
                # left a (batch, H, W) map for the dense layer.
                x = tf.reduce_sum(x, axis=[1, 2])

            x = t.dense_alt(x, 1, sn=True, name='disc-dense-last')
            return x
Exemplo n.º 3
0
    def generator(self, z, reuse=None, is_train=True):
        """Spectral-normalised upsampling generator with self-attention.

        :param z: noise, projected by a dense layer and reshaped to a 4x4 map
        :param reuse: forwarded to ``tf.variable_scope`` and to
            ``self.attention`` so weights can be shared between calls
        :param is_train: passed to every ``t.batch_norm`` layer
        :return: generated images, squashed into [-1, 1] by ``tanh``
        """
        def _up(h, n_filters, idx):
            # Upscale (nearest-neighbour resize + 5x5 conv, or a 4x4 stride-2
            # deconv), then batch-norm and ReLU; output has n_filters // 2 maps.
            out_filters = n_filters // 2
            if not self.up_sampling:
                h = t.deconv2d_alt(h, out_filters, 4, 2, sn=True, use_bias=False,
                                   name='gen-deconv2d-%d' % (idx + 1))
            else:
                h = t.up_sampling(h, interp=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
                h = t.conv2d_alt(h, out_filters, 5, 1, pad=2, sn=True, use_bias=False,
                                 name='gen-conv2d-%d' % (idx + 1))
            # Batch-norm layers are numbered from 0, unlike the convs.
            h = t.batch_norm(h, is_train=is_train, name='gen-bn-%d' % idx)
            return tf.nn.relu(h)

        with tf.variable_scope("generator", reuse=reuse):
            n_filters = self.gf_dim * 8
            half = self.n_layer // 2

            x = t.dense_alt(z, 4 * 4 * n_filters, sn=True, name='gen-fc-1')
            x = tf.reshape(x, (-1, 4, 4, n_filters))

            for layer in range(half):
                x = _up(x, n_filters, layer)
                n_filters //= 2

            # Self-Attention Layer, inserted between the two halves.
            x = self.attention(x, n_filters, reuse=reuse)

            for layer in range(half, self.n_layer):
                x = _up(x, n_filters, layer)
                n_filters //= 2

            x = t.conv2d_alt(x, self.channel, 5, 1, pad=2, sn=True,
                             name='gen-conv2d-%d' % (self.n_layer + 1))
            return tf.nn.tanh(x)
    def generator(self, z, reuse=None):
        """Residual spectral-normalised generator with self-attention.

        :param z: noise tensor, projected as a whole by the first dense layer
        :param reuse: forwarded to ``tf.variable_scope`` so weights can be
            shared between calls
        :return: generated images, squashed into [-1, 1] by ``tanh``
        """
        with tf.variable_scope("generator", reuse=reuse):
            # Project the full noise vector. A BigGAN-style `tf.split(z, 4)`
            # yields a Python *list*; none of the chunks is fed to the res
            # blocks here. Passing that list to the dense layer either fails or
            # (if auto-stacked) produces a (4, batch, d) tensor that the reshape
            # below would fold into a 4x larger batch.
            x = t.dense_alt(z, f=4 * 4 * 16 * self.channel, sn=True, use_bias=False, name="gen-dense-1")
            x = tf.nn.relu(x)

            f = 16 * self.channel
            x = tf.reshape(x, (-1, 4, 4, f))

            # gen-res1 .. gen-res4 with f = 16C, 8C, 4C, 2C (C = self.channel).
            for i in range(4):
                x = self.res_block(x, f=f, scale_type="up", name="gen-res%d" % (i + 1))
                f //= 2

            # f is now C, while the last block above was built with 2C.
            x = self.self_attention(x, f_=f * 2)

            # Previously also named "gen-res4", clashing with the loop's last block.
            x = self.res_block(x, f=1 * self.channel, scale_type="up", name="gen-res5")

            # NOTE(review): no is_train is passed, so t.batch_norm's default
            # mode applies -- confirm this is intended at inference time.
            x = t.batch_norm(x, name="gen-bn-last")
            x = tf.nn.relu(x)
            x = t.conv2d_alt(x, f=self.channel, k=3, sn=True, name="gen-conv2d-last")

            x = tf.nn.tanh(x)
            return x