Example #1
    def discriminator(self, x, reuse=None):
        """
        :param x: images
        :param y: labels
        :param reuse: re-usable
        :return: classification, probability (fake or real), network
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            f = self.gf_dim  # base filter count, doubled after each down-sampling block

            # 4x4 stride-2 convs with spectral norm halve H and W at each step
            x = t.conv2d_alt(x, f, 4, 2, pad=1, sn=True, name='disc-conv2d-1')
            x = tf.nn.leaky_relu(x, alpha=0.1)

            for i in range(self.n_layer // 2):
                x = t.conv2d_alt(x, f * 2, 4, 2, pad=1, sn=True, name='disc-conv2d-%d' % (i + 2))
                x = tf.nn.leaky_relu(x, alpha=0.1)

                f *= 2

            # Self-Attention Layer
            x = self.attention(x, f, reuse=reuse)

            for i in range(self.n_layer // 2, self.n_layer):
                x = t.conv2d_alt(x, f * 2, 4, 2, pad=1, sn=True, name='disc-conv2d-%d' % (i + 2))
                x = tf.nn.leaky_relu(x, alpha=0.1)

                f *= 2

            x = t.flatten(x)

            # Single-unit dense head producing the raw logits
            x = t.dense_alt(x, 1, sn=True, name='disc-fc-1')
            return x
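
The head returns raw logits rather than probabilities. Spectrally normalized discriminators like this one are commonly trained with a hinge objective; below is a minimal sketch of that pairing (the loss choice is an assumption, not something shown in the example):

import tensorflow as tf

def d_hinge_loss(d_real_logits, d_fake_logits):
    # Push real logits above +1 and fake logits below -1
    return (tf.reduce_mean(tf.nn.relu(1.0 - d_real_logits)) +
            tf.reduce_mean(tf.nn.relu(1.0 + d_fake_logits)))

def g_hinge_loss(d_fake_logits):
    # The generator simply raises the discriminator's score on fakes
    return -tf.reduce_mean(d_fake_logits)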
Example #2
    @staticmethod
    def attention(x, f_, reuse=None):
        # SAGAN self-attention; @staticmethod assumed, since call sites pass the
        # feature map as the first argument (self.attention(x, f, reuse=reuse)).
        # f/g are 1x1-conv query/key projections with f_//8 channels; h is the
        # full-width value projection.
        with tf.variable_scope("attention", reuse=reuse):
            f = t.conv2d_alt(x,
                             f_ // 8,
                             1,
                             1,
                             sn=True,
                             name='attention-conv2d-f')
            g = t.conv2d_alt(x,
                             f_ // 8,
                             1,
                             1,
                             sn=True,
                             name='attention-conv2d-g')
            h = t.conv2d_alt(x, f_, 1, 1, sn=True, name='attention-conv2d-h')  # value: full f_ channels

            # Collapse the spatial grid: [B, H, W, C] -> [B, H*W, C]
            f, g, h = t.hw_flatten(f), t.hw_flatten(g), t.hw_flatten(h)

            s = tf.matmul(g, f, transpose_b=True)  # [B, N, N] pairwise similarities
            attention_map = tf.nn.softmax(s, axis=-1, name='attention_map')

            # Attention-weighted sum of values, restored to the input's shape
            # (note: x.get_shape() requires a fully defined static shape)
            o = tf.reshape(tf.matmul(attention_map, h), shape=x.get_shape())
            gamma = tf.get_variable('gamma',
                                    shape=[1],
                                    initializer=tf.zeros_initializer())

            x = gamma * o + x  # learnable residual gate, starts at 0 (pure identity)
            return x
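
The matmul shapes are easy to lose track of; here is a minimal NumPy sketch of the same algebra, assuming an input of shape [B, H, W, C] and N = H*W flattened positions:

import numpy as np

B, H, W, C = 2, 8, 8, 64
N = H * W

f = np.random.randn(B, N, C // 8)  # query projection, flattened
g = np.random.randn(B, N, C // 8)  # key projection, flattened
h = np.random.randn(B, N, C)       # value projection, flattened

s = g @ f.transpose(0, 2, 1)        # [B, N, N] similarities
s -= s.max(axis=-1, keepdims=True)  # stabilized softmax over the last axis
attn = np.exp(s) / np.exp(s).sum(axis=-1, keepdims=True)
o = (attn @ h).reshape(B, H, W, C)  # [B, N, C] -> back to the spatial grid
assert o.shape == (B, H, W, C)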
Example #3
    def generator(self, z, reuse=None, is_train=True):
        """
        :param z: noise
        :param y: image label
        :param reuse: re-usable
        :param is_train: trainable
        :return: prob
        """
        with tf.variable_scope("generator", reuse=reuse):
            f = self.gf_dim * 8  # start wide; halved after each up-sampling block

            # Project the latent vector and reshape it into a 4x4 feature map
            x = t.dense_alt(z, 4 * 4 * f, sn=True, name='gen-fc-1')

            x = tf.reshape(x, (-1, 4, 4, f))

            for i in range(self.n_layer // 2):
                if self.up_sampling:
                    # nearest-neighbor resize + conv (avoids checkerboard artifacts)
                    x = t.up_sampling(x, interp=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
                    x = t.conv2d_alt(x, f // 2, 5, 1, pad=2, sn=True, use_bias=False, name='gen-conv2d-%d' % (i + 1))
                else:
                    # stride-2 transposed conv
                    x = t.deconv2d_alt(x, f // 2, 4, 2, sn=True, use_bias=False, name='gen-deconv2d-%d' % (i + 1))

                x = t.batch_norm(x, is_train=is_train, name='gen-bn-%d' % i)
                x = tf.nn.relu(x)

                f //= 2

            # Self-Attention Layer
            x = self.attention(x, f, reuse=reuse)

            for i in range(self.n_layer // 2, self.n_layer):
                if self.up_sampling:
                    x = t.up_sampling(x, interp=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
                    x = t.conv2d_alt(x, f // 2, 5, 1, pad=2, sn=True, use_bias=False, name='gen-conv2d-%d' % (i + 1))
                else:
                    x = t.deconv2d_alt(x, f // 2, 4, 2, sn=True, use_bias=False, name='gen-deconv2d-%d' % (i + 1))

                x = t.batch_norm(x, is_train=is_train, name='gen-bn-%d' % i)
                x = tf.nn.relu(x)

                f //= 2

            # Final conv down to the image channel count, squashed to [-1, 1]
            x = t.conv2d_alt(x, self.channel, 5, 1, pad=2, sn=True, name='gen-conv2d-%d' % (self.n_layer + 1))
            x = tf.nn.tanh(x)
            return x
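
Since the dense layer reshapes to 4x4 and every block doubles the spatial size, the output resolution is fixed by n_layer; a quick sanity check of that bookkeeping:

def output_size(n_layer, base=4):
    # Each up-sampling block (resize+conv or stride-2 deconv) doubles H and W
    return base * 2 ** n_layer

assert output_size(4) == 64    # n_layer=4 -> 64x64 images
assert output_size(6) == 256   # n_layer=6 -> 256x256 images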
Example #4
    @staticmethod
    def res_block(x, f, scale_type, use_bn=True, name=""):
        with tf.variable_scope("res_block-%s" % name):
            assert scale_type in ["up", "down"]
            scale_up = scale_type == "up"

            ssc = x  # shortcut branch

            x = t.batch_norm(x, name="bn-1") if use_bn else x
            x = tf.nn.relu(x)
            x = t.conv2d_alt(x, f, sn=True, name="conv2d-1")

            x = t.batch_norm(x, name="bn-2") if use_bn else x
            x = tf.nn.relu(x)

            if not scale_up:
                x = t.conv2d_alt(x, f, sn=True, name="conv2d-2")
                # strides is a required argument of tf.layers.average_pooling2d
                x = tf.layers.average_pooling2d(x, pool_size=(2, 2), strides=2)
                # scale/project the shortcut so the addition below type-checks
                ssc = t.conv2d_alt(ssc, f, 1, 1, sn=True, name="conv2d-ssc")
                ssc = tf.layers.average_pooling2d(ssc, pool_size=(2, 2), strides=2)
            else:
                x = t.deconv2d_alt(x, f, sn=True, name="up-sampling")
                ssc = t.deconv2d_alt(ssc, f, sn=True, name="up-sampling-ssc")

            return x + ssc
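
The residual addition only type-checks when the shortcut is scaled and projected alongside the main path; here is a NumPy sketch of the down-scaling case, modelling the 1x1 projection as a plain channel matmul:

import numpy as np

def avg_pool_2x(x):
    # [B, H, W, C] -> [B, H/2, W/2, C] by averaging 2x2 windows
    b, hh, ww, c = x.shape
    return x.reshape(b, hh // 2, 2, ww // 2, 2, c).mean(axis=(2, 4))

x = np.random.randn(2, 8, 8, 32)
w = np.random.randn(32, 64) / np.sqrt(32)  # 1x1 conv == per-pixel projection

main = avg_pool_2x(x @ w)  # conv -> pool, as in the block's main path
ssc = avg_pool_2x(x) @ w   # shortcut scaled and projected the same way
out = main + ssc           # shapes agree: [2, 4, 4, 64]
assert out.shape == (2, 4, 4, 64)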
Example #5
    def generator(self, z, reuse=None):
        """
        :param z: noise
        :param reuse: re-usable
        :return: prob
        """
        with tf.variable_scope("generator", reuse=reuse):
            # BigGAN-style latent split into four chunks, e.g. [None, 32] * 4
            zs = tf.split(z, num_or_size_splits=4, axis=-1)

            # Linear projection of the first chunk (tf.split returns a list, so
            # feeding it straight to the dense layer would fail; the remaining
            # chunks are presumably meant to condition the per-block batch norms)
            x = t.dense_alt(zs[0], f=4 * 4 * 16 * self.channel, sn=True, use_bias=False, name="gen-dense-1")
            x = tf.nn.relu(x)

            x = tf.reshape(x, (-1, 4, 4, 16 * self.channel))

            res = x

            f = 16 * self.channel
            for i in range(4):
                res = self.res_block(res,
                                     f=f,
                                     scale_type="up",
                                     name="gen-res%d" % (i + 1))
                f //= 2

            x = self.self_attention(res, f_=f * 2)  # attention on the 2*channel-wide features

            # "gen-res4" would collide with the loop's last scope; use a unique name
            x = self.res_block(x, f=1 * self.channel, scale_type="up", name="gen-res5")

            x = t.batch_norm(x, name="gen-bn-last")
            x = tf.nn.relu(x)
            x = t.conv2d_alt(x, f=self.channel, k=3, sn=True, name="gen-conv2d-last")

            x = tf.nn.tanh(x)
            return x
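
For reference, the latent split at the top of the generator yields equal chunks along the feature axis; a small sketch of what tf.split produces (NumPy's np.split behaves the same way here):

import numpy as np

z = np.random.randn(3, 128)       # a batch of three 128-dim latents
chunks = np.split(z, 4, axis=-1)  # four [3, 32] chunks
assert [c.shape for c in chunks] == [(3, 32)] * 4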