Esempio n. 1
0
def g_block(x, out_channels, training, name):
    """Builds one upsampling residual block of the ResNet generator.

    Main path:  BN -> ReLU -> usample -> 3x3 conv -> BN -> ReLU -> 3x3 conv.
    Shortcut:   usample -> 1x1 conv.
    The two paths are summed. Both paths pass through `usample`, so the
    output's spatial size is whatever `usample` produces. The callers' shape
    comments (4 -> 8 -> 16 -> 32) suggest it doubles height and width.
    NOTE(review): confirm against `usample`.

    Args:
        x: The 4D input tensor.
        out_channels: Number of features in the output tensor.
        training: Whether batch normalization runs in training mode; it is
            forwarded as `train=` to both batch-norm layers.
        name: The variable scope name for the block.

    Returns:
        A `Tensor` representing the sum of the main and shortcut paths.
    """
    with tf.variable_scope(name):

        bn0 = ops.batch_norm(name='bn0')
        bn1 = ops.batch_norm(name='bn1')

        x_0 = x
        # Main (residual) path: pre-activation, upsample, two 3x3 convs.
        x = tf.nn.relu(bn0(x, train=training))
        x = usample(x)
        x = ops.conv2d(x, out_channels, 3, 3, 1, 1, name='conv1')
        x = tf.nn.relu(bn1(x, train=training))
        x = ops.conv2d(x, out_channels, 3, 3, 1, 1, name='conv2')

        # Shortcut path: upsample to the same spatial size, then a 1x1 conv
        # so the channel count matches `out_channels`.
        x_0 = usample(x_0)
        x_0 = ops.conv2d(x_0, out_channels, 1, 1, 1, 1, name='conv3')

        return x_0 + x
Esempio n. 2
0
def e_block(x, out_channels, training, name, downsample=True,
            act=tf.nn.relu):
    """Builds one pre-activation residual block of the ResNet encoder.

    The layout follows the SN-GAN discriminator block, with batch norm in
    place of spectral normalization.

    Main path:  BN -> act -> 3x3 conv -> BN -> act -> 3x3 conv, followed by
                `dsample_pool` when `downsample` is True.
    Shortcut:   identity, unless `downsample` is True or the channel count
                changes. In that case it is a 1x1 conv, followed by
                `dsample_pool` when downsampling.
    The two paths are summed.

    Args:
        x: The 4D input tensor; its last (static) dimension is read as the
            input channel count.
        out_channels: Number of features in the output tensor.
        training: Whether batch normalization runs in training mode; it is
            forwarded as `train=` to both batch-norm layers.
        name: The variable scope name for the block.
        downsample: If True, both paths go through `dsample_pool`, which
            reduces the spatial size. The callers' shape comments suggest it
            halves height and width (a factor of 4 in area).
            NOTE(review): confirm against `dsample_pool`.
            If False, the spatial size of the input tensor is unchanged.
        act: The activation function applied after each batch norm.

    Returns:
        A `Tensor` representing the sum of the main and shortcut paths.
    """
    with tf.variable_scope(name):

        bn0 = ops.batch_norm(name='bn0')
        bn1 = ops.batch_norm(name='bn1')

        input_channels = x.get_shape().as_list()[-1]
        x_0 = x
        # Main (residual) path.
        x = act(bn0(x, train=training))
        x = ops.conv2d(x, out_channels, 3, 3, 1, 1, name='conv1')
        x = act(bn1(x, train=training))
        x = ops.conv2d(x, out_channels, 3, 3, 1, 1, name='conv2')
        if downsample:
            x = dsample_pool(x, "e_dsample_1")
        # The shortcut needs a projection only when the shapes would
        # otherwise differ; otherwise it stays an identity.
        if downsample or input_channels != out_channels:
            x_0 = ops.conv2d(x_0, out_channels, 1, 1, 1, 1, name='conv3')
            if downsample:
                x_0 = dsample_pool(x_0, "e_dsample_2")
        return x_0 + x
Esempio n. 3
0
def generator_resnet_cifar(z, x_shape, dim=128, name='generator',
                           reuse=False, training=True):
    """Builds the ResNet generator used for CIFAR-sized images.

    `z` is projected by a linear layer to an (H/8, W/8, 2*dim) feature map.
    That map passes through three `g_block`s (each upsampling by 2, per the
    shape comments), then BN -> ReLU -> 3x3 conv -> sigmoid, and is finally
    flattened.

    Args:
        z: The latent-code tensor fed to `ops.linear`.
        x_shape: (height, width, channels) of the generated images. Height
            and width must be divisible by 8, the total upsampling factor of
            the three `g_block`s. For CIFAR, (32, 32, 3) gives a 4 x 4
            starting grid.
        dim: Base feature width; it is doubled internally (256 by default,
            as in the SN-GAN paper).
        name: The variable scope name.
        reuse: Passed to `tf.variable_scope` to reuse existing variables.
        training: Whether batch normalization runs in training mode.

    Returns:
        A `Tensor` of shape [-1, height * width * channels] with values in
        (0, 1) because of the final sigmoid.

    Raises:
        ValueError: If the height or width in `x_shape` is not divisible
            by 8.
    """
    dim = dim * 2  # 256 like sn-gan paper
    height, width, channels = int(x_shape[0]), int(x_shape[1]), int(x_shape[2])
    if height % 8 or width % 8:
        # Otherwise the generated size would not match x_shape and the final
        # reshape would silently regroup rows across the batch.
        raise ValueError(
            'x_shape height and width must be divisible by 8, got %r'
            % (x_shape,))
    x_dim = height * width * channels
    init_h, init_w = height // 8, width // 8  # 4 x 4 for CIFAR
    with tf.variable_scope(name, reuse=reuse):
        act0 = ops.linear(z, dim * init_h * init_w, scope='g_linear0')
        act0 = tf.reshape(act0, [-1, init_h, init_w, dim])
        act1 = g_block(act0, dim, training, 'g_block1')  # 8 * 8
        act2 = g_block(act1, dim, training, 'g_block2')  # 16 * 16
        act3 = g_block(act2, dim, training, 'g_block3')  # 32 * 32
        bn = ops.batch_norm(name='g_bn')
        act3 = tf.nn.relu(bn(act3, train=training))
        # One output feature per image channel (3 for RGB CIFAR).
        act4 = ops.conv2d(act3, channels, 3, 3, 1, 1, name='g_conv_last')
        out = tf.nn.sigmoid(act4)
        return tf.reshape(out, [-1, x_dim])
Esempio n. 4
0
def generator_resnet_stl10(z, x_shape, dim=64, name='generator',
                           reuse=False, training=True):
    """Builds the ResNet generator used for STL-10-sized images.

    `z` is projected by a linear layer to an (H/8, W/8, 8*dim) feature map.
    That map passes through three `g_block`s that upsample (per the shape
    comments) while halving the feature width (8*dim -> 4*dim -> 2*dim ->
    dim), then BN -> ReLU -> 3x3 conv -> sigmoid, and is finally flattened.

    Args:
        z: The latent-code tensor fed to `ops.linear`.
        x_shape: (height, width, channels) of the generated images. Height
            and width must be divisible by 8, the total upsampling factor of
            the three `g_block`s. For STL-10 at (48, 48, 3) this gives a
            6 x 6 starting grid.
        dim: Base feature width (the width of the last `g_block`).
        name: The variable scope name.
        reuse: Passed to `tf.variable_scope` to reuse existing variables.
        training: Whether batch normalization runs in training mode.

    Returns:
        A `Tensor` of shape [-1, height * width * channels] with values in
        (0, 1) because of the final sigmoid.

    Raises:
        ValueError: If the height or width in `x_shape` is not divisible
            by 8.
    """
    height, width, channels = int(x_shape[0]), int(x_shape[1]), int(x_shape[2])
    if height % 8 or width % 8:
        # Otherwise the generated size would not match x_shape and the final
        # reshape would silently regroup rows across the batch.
        raise ValueError(
            'x_shape height and width must be divisible by 8, got %r'
            % (x_shape,))
    x_dim = height * width * channels
    init_h, init_w = height // 8, width // 8  # 6 x 6 for 48 x 48 STL-10
    with tf.variable_scope(name, reuse=reuse):
        act0 = ops.linear(z, dim * 8 * init_h * init_w, scope='g_linear0')
        act0 = tf.reshape(act0, [-1, init_h, init_w, dim * 8])  # 6 * 6 * dim * 8
        act1 = g_block(act0, dim * 4, training,
                       'g_block1')  # 12 * 12 * dim * 4
        act2 = g_block(act1, dim * 2, training,
                       'g_block2')  # 24 * 24 * dim * 2
        act3 = g_block(act2, dim * 1, training,
                       'g_block3')  # 48 * 48 * dim * 1
        bn = ops.batch_norm(name='g_bn')
        act3 = tf.nn.relu(bn(act3, train=training))
        # One output feature per image channel (3 for RGB STL-10).
        act4 = ops.conv2d(act3, channels, 3, 3, 1, 1, name='g_conv_last')
        out = tf.nn.sigmoid(act4)
        return tf.reshape(out, [-1, x_dim])
Esempio n. 5
0
def encoder_resnet_stl10(x, x_shape, z_dim=128, dim=64,
                         name='encoder', reuse=False, training=True):
    """Builds the ResNet encoder used for STL-10-sized images.

    The flat input is reshaped to images and goes through a 3x3 conv. Three
    downsampling `e_block`s follow, with feature width dim*2 -> dim*4 ->
    dim*8. Then come BN -> leaky ReLU (`lrelu`), a flatten, and a linear
    projection to `z_dim`.

    Args:
        x: The input tensor; it is reshaped to
            [-1, x_shape[0], x_shape[1], x_shape[2]].
        x_shape: (height, width, channels) of the input images. Height and
            width must be divisible by 8, the total downsampling factor of
            the three `e_block`s. For STL-10 at 48 x 48 the final grid is
            6 x 6.
        z_dim: Output size of the final linear layer.
        dim: Base feature width.
        name: The variable scope name.
        reuse: Passed to `tf.variable_scope` to reuse existing variables.
        training: Whether batch normalization runs in training mode.

    Returns:
        The output `Tensor` of the final `ops.linear` layer (`z_dim` units).

    Raises:
        ValueError: If the height or width in `x_shape` is not divisible
            by 8.
    """
    height, width, channels = int(x_shape[0]), int(x_shape[1]), int(x_shape[2])
    if height % 8 or width % 8:
        # A hard-coded flatten size that disagrees with the real feature map
        # would silently regroup rows across the batch.
        raise ValueError(
            'x_shape height and width must be divisible by 8, got %r'
            % (x_shape,))
    act = lrelu
    with tf.variable_scope(name, reuse=reuse):
        image = tf.reshape(x, [-1, height, width, channels])
        image = ops.conv2d(image, dim, 3, 3, 1, 1,
                           name='e_conv0')  # 48 * 48 * dim
        act0 = e_block(image, dim * 2, training=training,
                       name='e_block1', act=act)  # 24 * 24 * dim * 2
        act1 = e_block(act0, dim * 4, training,
                       name='e_block2', act=act)  # 12 * 12 * dim * 4
        act2 = e_block(act1, dim * 8, training,
                       name='e_block3', act=act)  # 6 * 6 * dim * 8
        bn = ops.batch_norm(name='e_bn')
        act2 = act(bn(act2, train=training))
        # Flatten size derived from x_shape (6 * 6 * dim * 8 for 48 x 48).
        flat_dim = (height // 8) * (width // 8) * dim * 8
        act2 = tf.reshape(act2, [-1, flat_dim])
        out = ops.linear(act2, z_dim)
        return out
Esempio n. 6
0
def encoder_resnet_cifar(x, x_shape, z_dim=128, dim=128,
                         name='encoder', reuse=False, training=True):
    """Builds the ResNet encoder used for CIFAR-sized images.

    The flat input is reshaped to images and goes through a 3x3 conv. Three
    downsampling `e_block`s follow at constant width 2*dim. Then come
    BN -> leaky ReLU (`lrelu`), a flatten, and a linear projection to
    `z_dim`.

    Args:
        x: The input tensor; it is reshaped to
            [-1, x_shape[0], x_shape[1], x_shape[2]].
        x_shape: (height, width, channels) of the input images. Height and
            width must be divisible by 8, the total downsampling factor of
            the three `e_block`s. For CIFAR at 32 x 32 the final grid is
            4 x 4.
        z_dim: Output size of the final linear layer.
        dim: Base feature width; it is doubled internally (256 by default,
            as in the SN-GAN paper).
        name: The variable scope name.
        reuse: Passed to `tf.variable_scope` to reuse existing variables.
        training: Whether batch normalization runs in training mode.

    Returns:
        The output `Tensor` of the final `ops.linear` layer (`z_dim` units).

    Raises:
        ValueError: If the height or width in `x_shape` is not divisible
            by 8.
    """
    dim = dim * 2  # 256 like sn-gan paper
    height, width, channels = int(x_shape[0]), int(x_shape[1]), int(x_shape[2])
    if height % 8 or width % 8:
        # A hard-coded flatten size that disagrees with the real feature map
        # would silently regroup rows across the batch.
        raise ValueError(
            'x_shape height and width must be divisible by 8, got %r'
            % (x_shape,))
    act = lrelu
    with tf.variable_scope(name, reuse=reuse):
        image = tf.reshape(x, [-1, height, width, channels])
        image = ops.conv2d(image, dim, 3, 3, 1, 1,
                           name='e_conv0')  # 32 * 32

        act0 = e_block(image, dim, training=training,
                       name='e_block1', act=act)  # 16 * 16
        act1 = e_block(act0, dim, training,
                       name='e_block2', act=act)  # 8 * 8
        act2 = e_block(act1, dim, training,
                       name='e_block3', act=act)  # 4 * 4
        bn = ops.batch_norm(name='e_bn')
        act2 = act(bn(act2, train=training))
        # Flatten size derived from x_shape (4 * 4 * dim for 32 x 32).
        flat_dim = (height // 8) * (width // 8) * dim
        act2 = tf.reshape(act2, [-1, flat_dim])
        out = ops.linear(act2, z_dim)
        return out