Пример #1
0
    def discriminator(self, image, t_text_embedding):
        """Build (or reuse) the text-conditioned discriminator graph.

        The image passes through four spectrally normalised convolutions
        (batch-norm on the last three). The text embedding is projected to
        ``t_dim`` features and tiled over the spatial grid of the last
        feature map. It is then concatenated channel-wise with that map and
        fused by a 1x1 convolution. A final linear layer produces a single
        score per example.

        All variables live under the ``"discriminator"`` variable scope with
        ``tf.AUTO_REUSE``, so calling this method again (e.g. for real and
        fake images) shares the same weights.

        Args:
            image: image batch fed to ``ops.conv2d_sn``; presumably NHWC
                ``(batch, H, W, C)``. TODO: confirm. The original size
                comments (32/16/8/4) suggest a 64x64 input.
            t_text_embedding: text embedding batch fed to ``ops.linear``;
                presumably ``(batch, emb_dim)``. TODO: confirm.

        Returns:
            ``(h4, h4)``: the same raw (un-squashed) score tensor twice.
            NOTE(review): no sigmoid is applied to the first element,
            presumably because the loss consumes logits directly. Confirm
            against callers.
        """
        # Passed to every spectral-norm conv. Presumably ops.conv2d_sn
        # registers its power-iteration update ops in this collection
        # instead of running them inline. Confirm in ops.
        update_collection = tf.GraphKeys.UPDATE_OPS
        with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
            h0 = ops.lrelu(
                ops.conv2d_sn(image,
                              self.options['df_dim'],
                              spectral_normed=True,
                              update_collection=update_collection,
                              name='d_h0_conv'))  #32
            h1 = ops.lrelu(
                self.d_bn1(
                    ops.conv2d_sn(h0,
                                  self.options['df_dim'] * 2,
                                  spectral_normed=True,
                                  update_collection=update_collection,
                                  name='d_h1_conv')))  #16
            h2 = ops.lrelu(
                self.d_bn2(
                    ops.conv2d_sn(h1,
                                  self.options['df_dim'] * 4,
                                  spectral_normed=True,
                                  update_collection=update_collection,
                                  name='d_h2_conv')))  #8
            h3 = ops.lrelu(
                self.d_bn3(
                    ops.conv2d_sn(h2,
                                  self.options['df_dim'] * 8,
                                  spectral_normed=True,
                                  update_collection=update_collection,
                                  name='d_h3_conv')))  #4

            # Add the text embedding to the network: project it, then turn
            # it into a rank-4 tensor (batch, 1, 1, t_dim) so it can be tiled
            # spatially.
            reduced_text_embeddings = ops.lrelu(
                ops.linear(t_text_embedding, self.options['t_dim'],
                           'd_embedding'))
            reduced_text_embeddings = tf.expand_dims(reduced_text_embeddings,
                                                     1)
            reduced_text_embeddings = tf.expand_dims(reduced_text_embeddings,
                                                     2)

            # Tile over h3's actual spatial grid rather than a hard-coded 4x4.
            # A hard-coded size only matches one input resolution and makes
            # the concat below fail for any other. Prefer static dims (keeps
            # static shape info for the layers below). Fall back to the
            # runtime shape if they are unknown at graph-build time.
            h3_height, h3_width = h3.get_shape().as_list()[1:3]
            if h3_height is None or h3_width is None:
                h3_runtime_shape = tf.shape(h3)
                h3_height, h3_width = h3_runtime_shape[1], h3_runtime_shape[2]
            tiled_embeddings = tf.tile(reduced_text_embeddings,
                                       [1, h3_height, h3_width, 1],
                                       name='tiled_embeddings')

            # Fuse image features and text features along the channel axis.
            h3_concat = tf.concat([h3, tiled_embeddings], 3, name='h3_concat')
            h3_new = ops.lrelu(
                self.d_bn4(
                    ops.conv2d_sn(h3_concat,
                                  self.options['df_dim'] * 8,
                                  1,
                                  1,
                                  1,
                                  1,
                                  spectral_normed=True,
                                  update_collection=update_collection,
                                  name='d_h3_conv_new')))  #4

            # Flatten and score. The reshape pins the leading dimension to
            # options['batch_size'], so the fed batch must have exactly that
            # size.
            h4 = ops.linear(
                tf.reshape(h3_new, [self.options['batch_size'], -1]), 1,
                'd_h3_lin')

        return h4, h4
Пример #2
0
    def discriminator(self, image, t_text_embedding):
        """Build (or reuse) the projection-style text-conditioned discriminator.

        Image features from a stack of spectrally normalised convolutions are
        flattened and linearly projected to a ``t_dim`` vector. The text
        embedding is projected to the same width. The output score is the
        per-example inner product of the two projections plus an
        unconditional linear score of the image projection.

        Variables are created under the ``"discriminator"`` scope with
        ``tf.AUTO_REUSE``, so repeated calls share weights.

        Args:
            image: image batch for ``ops.conv2d_sn``; presumably NHWC.
                TODO: confirm.
            t_text_embedding: text embedding batch for ``ops.linear``;
                presumably ``(batch, emb_dim)``. TODO: confirm.

        Returns:
            ``(logit, logit)``: the same raw score tensor twice; no sigmoid
            is applied. NOTE(review): confirm callers expect a logit in both
            positions.
        """
        # Presumably ops.conv2d_sn collects its spectral-norm update ops here
        # rather than applying them inline. Confirm in ops.
        update_collection = tf.GraphKeys.UPDATE_OPS

        def sn_conv(inputs, out_dim, layer_name, *geometry):
            # Every conv in this network is spectrally normalised and shares
            # the update collection. Extra positional arguments are forwarded
            # unchanged to ops.conv2d_sn.
            return ops.conv2d_sn(inputs,
                                 out_dim,
                                 *geometry,
                                 spectral_normed=True,
                                 update_collection=update_collection,
                                 name=layer_name)

        with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
            base_dim = self.options['df_dim']

            # Trailing numbers are the spatial sizes noted by the original
            # author.
            features = ops.lrelu(sn_conv(image, base_dim, 'd_h0_conv'))  # 32
            features = ops.lrelu(
                self.d_bn1(sn_conv(features, base_dim * 2, 'd_h1_conv')))  # 16
            features = ops.lrelu(
                self.d_bn2(sn_conv(features, base_dim * 4, 'd_h2_conv')))  # 8
            features = ops.lrelu(
                self.d_bn3(sn_conv(features, base_dim * 8, 'd_h3_conv')))  # 4
            features = ops.lrelu(
                self.d_bn4(
                    sn_conv(features, base_dim * 8, 'd_h3_conv_new',
                            1, 1, 1, 1)))  # 4

            # Flatten. The leading dimension is pinned to
            # options['batch_size'].
            flat_features = tf.reshape(features,
                                       [self.options['batch_size'], -1])
            image_projection = ops.linear(flat_features,
                                          self.options['t_dim'],
                                          'd_h3_embedding')

            # Projection of the conditioning text into the same space.
            text_projection = ops.linear(t_text_embedding,
                                         self.options['t_dim'],
                                         'd_embedding')

            # Unconditional scalar score of the image projection.
            unconditional_score = ops.linear(image_projection, 1,
                                             'd_scalar_output')

            # Conditional term: row-wise inner product, kept as a column.
            match_score = tf.reduce_sum(
                tf.multiply(text_projection, image_projection),
                1,
                keepdims=True)
            logit = match_score + unconditional_score

        return logit, logit