Esempio n. 1
0
    def decode(self, x, reuse=False, is_training=True):
        """Decode a latent batch back into feature maps.

        Two dense layers expand ``x`` to ``self.feature_shape``; conv blocks
        then upsample and refine it. The last block outputs ``self.channel``
        channels using ``self.last_conv_block_params``.

        Args:
            x: Latent tensor to decode.
            reuse: If True, reuse the variables of the ``Decoder`` scope.
            is_training: Forwarded to every ``conv_block`` call so the whole
                decoder runs in the same training/inference mode.

        Returns:
            The tensor produced by the final ``conv_block``.
        """
        with tf.variable_scope(self.name):
            with tf.variable_scope('Decoder') as vs:
                if reuse:
                    vs.reuse_variables()
                # NOTE(review): self.feature_shape must already be set here;
                # presumably encode() populates it first -- confirm with caller.
                flat_size = (self.feature_shape[0] * self.feature_shape[1] *
                             self.feature_shape[2])
                x = dense(x, 128, activation_='lrelu')
                x = dense(x, flat_size, activation_='lrelu')
                x = reshape(x, self.feature_shape)

                # Go from the widest filter count down to the narrowest.
                for i in reversed(range(2)):
                    filters = self.first_filters * (2**i)
                    x = conv_block(x,
                                   filters=filters,
                                   sampling=self.upsampling,
                                   is_training=is_training,
                                   **self.conv_block_params)
                    x = conv_block(x,
                                   filters=filters,
                                   sampling='same',
                                   is_training=is_training,
                                   **self.conv_block_params)

                # is_training is forwarded here too, so these blocks follow
                # the caller's mode instead of conv_block's own default.
                x = conv_block(x,
                               filters=self.first_filters,
                               sampling=self.upsampling,
                               is_training=is_training,
                               **self.conv_block_params)
                x = conv_block(x,
                               filters=self.channel,
                               sampling='same',
                               is_training=is_training,
                               **self.last_conv_block_params)
            return x
Esempio n. 2
0
    def __call__(self, x, reuse=False, is_training=True):
        """Run the dense stack on ``x`` and return softmax outputs.

        Each width in ``self.dense_units[:-1]`` gets a dense layer, the
        configured activation, optional normalization and optional dropout.
        The last entry of ``self.dense_units`` is the width of the softmax
        output layer.

        Args:
            x: Input tensor.
            reuse: Accepted for interface compatibility; not used here.
            is_training: Forwarded to normalization and dropout.

        Returns:
            Output of the final dense layer with softmax activation.

        Raises:
            ValueError: If ``self.dense_params['normalization']`` is not None,
                'batch' or 'layer'. This is checked only when a hidden layer
                is built.
        """
        for units in self.dense_units[:-1]:
            x = dense(x, units)
            x = activation(x, self.dense_params['activation_'])
            normalization = self.dense_params['normalization']
            if normalization == 'batch':
                x = batch_norm(x, is_training)
            elif normalization == 'layer':
                x = layer_norm(x, is_training)
            elif normalization is not None:
                raise ValueError(
                    "dense_params['normalization'] must be None, 'batch' or "
                    "'layer', got {!r}".format(normalization))

            if self.is_dropout:
                x = dropout(x, 0.5, is_training)
        return dense(x, self.dense_units[-1], activation_='softmax')
Esempio n. 3
0
def generator(batch_size, latent_size, args, reuse=False):
    """Adds generator nodes to the graph.

    Draws Gaussian noise, projects it to a 4x4 feature map, and applies
    four deconv2d layers. The result is flattened to 64 * 64 * 3 values
    per image.

    Args:
      batch_size: Integer, number of noise vectors to draw.
      latent_size: Integer, size of the noise vector; also scales the
        channel widths of every layer.
      args: Argparse structure. Unused here; kept for interface
        compatibility.
      reuse: Boolean, whether to reuse variables.

    Returns:
      Tensor reshaped to [-1, 64 * 64 * 3]. The last layer uses tanh.
    """
    output_dim = 64 * 64 * 3
    # Integer floor division; equivalent to int(latent_size / 2) for
    # non-negative sizes without a float round-trip.
    half_latent = latent_size // 2
    with arg_scope([dense, deconv2d],
                   reuse=reuse,
                   use_batch_norm=True,
                   activation=tf.nn.relu):
        z = tf.random_normal([batch_size, latent_size])
        y = dense(z, latent_size, 4 * 4 * 4 * latent_size, name='fc1')
        y = tf.reshape(y, [-1, 4, 4, 4 * latent_size])
        y = deconv2d(y, 4 * latent_size, 2 * latent_size, 5, 2, name='dc1')
        y = deconv2d(y, 2 * latent_size, latent_size, 5, 2, name='dc2')
        y = deconv2d(y, latent_size, half_latent, 5, 2, name='dc3')
        # Output layer: 3 channels, tanh instead of relu, no batch norm.
        y = deconv2d(y,
                     half_latent,
                     3,
                     5,
                     2,
                     name='dc4',
                     activation=tf.tanh,
                     use_batch_norm=False)
        y = tf.reshape(y, [-1, output_dim])
    return y
Esempio n. 4
0
def latent(x, batch_size, latent_size, reuse=False):
    """Adds latent nodes for sampling and reparameterization.

    Two dense heads on the flattened input produce a mean and a scale.
    The latent code is ``z = mean + scale * eps``, with ``eps`` drawn from
    a standard normal.

    Args:
    x: Tensor, input images.
    batch_size: Integer, number of noise rows to draw.
    latent_size: Integer, size of latent vector.
    reuse: Boolean, whether to reuse variables.

    Returns:
    Tuple (samples, z, z_mean, z_stddev), where samples is the noise used
    to build z.
    """
    # NOTE(review): assumes flatten(x) yields 32*4*4 features -- confirm.
    in_features = 32 * 4 * 4
    with arg_scope([dense], reuse=reuse):
        flattened = flatten(x)
        mean = dense(flattened, in_features, latent_size, name='d1')
        scale = dense(flattened, in_features, latent_size, name='d2')
        eps = tf.random_normal([batch_size, latent_size], 0, 1)
        code = mean + scale * eps
    return eps, code, mean, scale
Esempio n. 5
0
def latent(x, latent_size, reuse=False):
    """Add latent nodes to the graph.

    Flattens the encoder output and projects it to a latent vector with a
    single dense layer named 'd1'.

    Args:
    x: Tensor, output from encoder.
    latent_size: Integer, size of latent vector.
    reuse: Boolean, whether to reuse variables.

    Returns:
    Tensor produced by the dense projection.
    """
    # NOTE(review): assumes flatten(x) yields 32*4*4 features -- confirm.
    in_features = 32 * 4 * 4
    with arg_scope([dense], reuse=reuse):
        flattened = flatten(x)
        code = dense(flattened, in_features, latent_size, name='d1')
    return code
Esempio n. 6
0
def discriminator(x, args, reuse=False):
    """Adds discriminator nodes to the graph.

    Reshapes the input to 64x64x3 images and applies three strided 5x5
    convolutions with leaky ReLU. The result is flattened per image and
    a dense layer maps each image to a single score.

    Args:
    x: Tensor, input (anything reshapeable to [-1, 64, 64, 3]).
    args: Argparse structure; reads args.model and args.latent_size.
    reuse: Boolean, whether to reuse variables.

    Returns:
    1-D tensor with one score per input image. For 'wgan'/'iwgan' the
    score has no output activation; otherwise it goes through a sigmoid.

    Raises:
    ValueError: If the conv feature size is not statically known.
    """
    # Batch norm is disabled for 'iwgan' and enabled otherwise; c1 always
    # skips it.
    use_bn = args.model != 'iwgan'
    final_activation = (None if args.model in ['wgan', 'iwgan']
                        else tf.nn.sigmoid)
    with arg_scope([conv2d],
                   use_batch_norm=use_bn,
                   activation=lrelu,
                   reuse=reuse):
        x = tf.reshape(x, [-1, 64, 64, 3])
        x = conv2d(x,
                   3,
                   args.latent_size,
                   5,
                   2,
                   name='c1',
                   use_batch_norm=False)
        x = conv2d(x, args.latent_size, args.latent_size * 2, 5, 2, name='c2')
        x = conv2d(x,
                   args.latent_size * 2,
                   args.latent_size * 4,
                   5,
                   2,
                   name='c3')
        # A fixed 4 * 4 * 4 * latent_size flatten assumes a 4x4 map. Three
        # stride-2 convs on 64x64 leave more than that (e.g. 8x8 with SAME
        # padding), which would split one image over several rows and emit
        # several scores per image. Use the real per-image feature size.
        flat_dim = x.get_shape()[1:].num_elements()
        if flat_dim is None:
            raise ValueError(
                'discriminator needs a statically known conv output shape, '
                'got {}'.format(x.get_shape()))
        x = tf.reshape(x, [-1, flat_dim])
        x = dense(x,
                  flat_dim,
                  1,
                  use_batch_norm=False,
                  activation=final_activation,
                  name='fc2',
                  reuse=reuse)
        x = tf.reshape(x, [-1])
    return x
Esempio n. 7
0
    def encode(self, x, reuse=False, is_training=True):
        """Encode ``x`` into a vector of width ``self.latent_dim``.

        Runs three stages of conv blocks. Each stage uses a 'same' block
        then a 'down' block, and the filter count doubles per stage,
        starting at ``self.first_filters``. The result is flattened and
        passed through two dense layers.

        Side effect: stores the pre-flatten feature shape (batch axis
        dropped) in ``self.feature_shape``.

        Args:
            x: Input tensor.
            reuse: If True, reuse the variables of the ``Encoder`` scope.
            is_training: Forwarded to every ``conv_block`` call.

        Returns:
            Output of the final dense layer.
        """
        with tf.variable_scope(self.name):
            with tf.variable_scope('Encoder') as scope:
                if reuse:
                    scope.reuse_variables()

                for stage in range(3):
                    width = self.first_filters * 2 ** stage
                    # Same call order as before: 'same' first, then 'down'.
                    for mode in ('same', 'down'):
                        x = conv_block(x,
                                       filters=width,
                                       sampling=mode,
                                       is_training=is_training,
                                       **self.conv_block_params)
                self.feature_shape = x.get_shape().as_list()[1:]

                x = flatten(x)
                x = dense(x, 128, activation_='lrelu')
                x = dense(x, self.latent_dim)
            return x
Esempio n. 8
0
def decoder(x, latent_size, reuse=False):
    """Adds decoder nodes to the graph.

    Projects the code to a 4x4x32 map and applies two 1x1 conv2d layers,
    then four deconv2d layers. The last deconv uses tanh and 3 channels;
    every other layer uses relu.

    Args:
      x: Tensor, encoded image representation.
      latent_size: Integer, size of latent vector.
      reuse: Boolean, whether to reuse variables.

    Returns:
      Output tensor of the final deconv2d layer.
    """
    # (in_channels, out_channels, layer_name), applied in order.
    conv_specs = ((32, 96, 'c1'), (96, 256, 'c2'))
    deconv_specs = ((256, 256, 'dc1'), (256, 128, 'dc2'), (128, 64, 'dc3'))
    with arg_scope([dense, conv2d, deconv2d],
                   reuse=reuse,
                   activation=tf.nn.relu):
        h = dense(x, latent_size, 32 * 4 * 4, name='d1')
        h = tf.reshape(h, [-1, 4, 4, 32])  # un-flatten
        for c_in, c_out, label in conv_specs:
            h = conv2d(h, c_in, c_out, 1, name=label)
        for c_in, c_out, label in deconv_specs:
            h = deconv2d(h, c_in, c_out, 5, 2, name=label)
        h = deconv2d(h, 64, 3, 5, 2, name='dc4', activation=tf.nn.tanh)
    return h