def g_block(x, out_channels, is_training, name):
    """Builds the upsampling residual block used in the generator.

    The main path applies batch norm and ReLU to the input, upsamples it with
    `usample`, and then runs two `ops.conv2d` layers with a second batch norm
    and ReLU in between. The shortcut path upsamples the untouched input and
    projects it to `out_channels` with a third `ops.conv2d`. The two paths are
    summed.

    Args:
      x: The input tensor (presumably 4D -- confirm against callers).
      out_channels: Number of features in the output layer.
      is_training: Forwarded as `train=` to both batch-norm layers.
      name: The variable scope name for the block.

    Returns:
      A `Tensor` holding the sum of the shortcut and main paths.
    """
    with tf.variable_scope(name):
        norm_in = ops.batch_norm(name='bn0')
        norm_mid = ops.batch_norm(name='bn1')

        shortcut = x

        # Main path. The conv arguments `3, 3, 1, 1` are presumably kernel
        # height/width and strides -- confirm against ops.conv2d.
        main = tf.nn.relu(norm_in(x, train=is_training))
        main = usample(main)
        main = ops.conv2d(main, out_channels, 3, 3, 1, 1, name='conv1')
        main = tf.nn.relu(norm_mid(main, train=is_training))
        main = ops.conv2d(main, out_channels, 3, 3, 1, 1, name='conv2')

        # Shortcut path: upsample first, then project to out_channels.
        shortcut = ops.conv2d(
            usample(shortcut), out_channels, 1, 1, 1, 1, name='conv3')

        return shortcut + main
def optimized_block(x, out_channels, name, act=tf.nn.relu):
    """Builds the simplified residual block that always downsamples.

    Unlike `d_block`, no activation is applied before the first convolution
    and both paths are always passed through `dsample_conv`.

    Args:
      x: The input tensor (presumably 4D -- confirm against callers).
      out_channels: Number of features in the output layer.
      name: The variable scope name for the block.
      act: The activation function applied between the two main-path convs.

    Returns:
      A `Tensor` holding the sum of the main and shortcut paths.
    """
    with tf.variable_scope(name):
        shortcut = x

        # Main path: conv -> act -> conv -> downsample.
        main = ops.conv2d(x, out_channels, 3, 3, 1, 1, name='conv1')
        main = ops.conv2d(act(main), out_channels, 3, 3, 1, 1, name='conv2')
        main = dsample_conv(main, "o_dsample_1")

        # Shortcut path: downsample first, then project to out_channels.
        shortcut = ops.conv2d(
            dsample_conv(shortcut, "o_dsample_2"),
            out_channels, 1, 1, 1, 1, name='conv3')

        return main + shortcut
def d_block(x, out_channels, name, downsample=True, act=tf.nn.relu):
    """Builds the residual block used in the discriminator in SNGAN.

    The main path is act -> conv -> act -> conv, optionally followed by
    `dsample_conv`. The shortcut is the identity unless the block downsamples
    or the channel count changes, in which case it is projected by `conv3`
    (and then downsampled when `downsample` is set).

    Args:
      x: The input tensor; its last dimension is read as the channel count.
      out_channels: Number of features in the output layer.
      name: The variable scope name for the block.
      downsample: If True, both paths go through `dsample_conv`. If False,
        neither path is downsampled.
      act: The activation function used in the block.

    Returns:
      A `Tensor` holding the sum of the shortcut and main paths.
    """
    with tf.variable_scope(name):
        in_channels = x.get_shape().as_list()[-1]
        shortcut = x

        main = ops.conv2d(act(x), out_channels, 3, 3, 1, 1, name='conv1')
        main = ops.conv2d(act(main), out_channels, 3, 3, 1, 1, name='conv2')

        if downsample:
            main = dsample_conv(main, "d_dsample_1")
            # Shortcut: project first, then downsample.
            shortcut = ops.conv2d(
                shortcut, out_channels, 1, 1, 1, 1, name='conv3')
            shortcut = dsample_conv(shortcut, "d_dsample_2")
        elif in_channels != out_channels:
            # No resolution change, but the channel count must still match.
            shortcut = ops.conv2d(
                shortcut, out_channels, 1, 1, 1, 1, name='conv3')

        return shortcut + main
def generator_resnet_stl10(z, x_shape, dim=64,
                           num_classes=None, labels=None,
                           name='generator', reuse=False,
                           is_training=True):
    """Builds the ResNet generator with a 6x6 seed feature map.

    Args:
      z: The latent input fed to the first linear layer.
      x_shape: Indexable with at least three entries; their product is the
        size of each flattened output row.
      dim: Base channel count; the stages use dim*8, dim*4, dim*2 and dim.
      num_classes: Number of classes, or None for the unconditional model.
      labels: Class labels, or None. They are squeezed before use. The
        conditional blocks are used only when both `num_classes` and `labels`
        are given.
      name: The variable scope name for the generator.
      reuse: Forwarded to `tf.variable_scope`.
      is_training: Forwarded to the blocks and the final batch norm.

    Returns:
      The sigmoid output reshaped to `[-1, x_shape[0]*x_shape[1]*x_shape[2]]`.
    """
    if labels is not None:
        labels = tf.squeeze(labels)
    flat_size = x_shape[0] * x_shape[1] * x_shape[2]
    conditional = not (num_classes is None or labels is None)

    with tf.variable_scope(name, reuse=reuse):
        net = ops.linear(z, dim * 8 * 6 * 6, scope='g_linear0')
        net = tf.reshape(net, [-1, 6, 6, dim * 8])  # 6 * 6 * dim * 8

        # Output widths per stage; spatial size noted as 12 -> 24 -> 48.
        stages = (
            (dim * 4, 'g_block1'),  # 12 * 12 * dim * 4
            (dim * 2, 'g_block2'),  # 24 * 24 * dim * 2
            (dim * 1, 'g_block3'),  # 48 * 48 * dim * 1
        )
        for width, scope in stages:
            if conditional:
                net = g_block_cond(net, width, num_classes, labels,
                                   is_training, scope)
            else:
                net = g_block(net, width, is_training, scope)

        # NOTE(review): in the conditional case the batch norm is built with
        # num_classes but later called without labels -- confirm this matches
        # the signature of ops.batch_norm.
        if conditional:
            final_norm = ops.batch_norm(num_classes, name='g_bn')
        else:
            final_norm = ops.batch_norm(name='g_bn')

        net = tf.nn.relu(final_norm(net, is_training))
        logits = ops.conv2d(net, 3, 3, 3, 1, 1, name='g_conv_last')
        image = tf.nn.sigmoid(logits)
        return tf.reshape(image, [-1, flat_size])
def encoder_resnet_cifar(x, x_shape, z_dim=128, dim=128,
                         num_classes=None, labels=None,
                         name='encoder',
                         update_collection=None,
                         reuse=False, is_training=True):
    """Builds the ResNet encoder whose final feature map is 4x4.

    Args:
      x: The input batch; it is reshaped to
        `[-1, x_shape[0], x_shape[1], x_shape[2]]`.
      x_shape: Indexable with at least three entries giving the image shape.
      z_dim: Size of the output code produced by the final linear layer.
      dim: Base channel count; it is doubled internally before use.
      num_classes: Number of classes, or None for the unconditional model.
      labels: Class labels, or None. They are squeezed before use. The
        conditional blocks are used only when both `num_classes` and `labels`
        are given.
      name: The variable scope name for the encoder.
      update_collection: Unused; kept so existing callers keep working.
      reuse: Forwarded to `tf.variable_scope`.
      is_training: Forwarded to the blocks and the final batch norm.

    Returns:
      The output of the final linear layer with `z_dim` units.
    """
    if labels is not None:
        labels = tf.squeeze(labels)
    dim = dim * 2  # 256 like sn-gan paper
    act = lrelu
    is_conditional = num_classes is not None and labels is not None

    with tf.variable_scope(name, reuse=reuse):
        image = tf.reshape(x, [-1, x_shape[0], x_shape[1], x_shape[2]])
        image = ops.conv2d(image, dim, 3, 3, 1, 1,
                           name='e_conv0')  # 32 * 32

        # A single branch keeps the two variants side by side; the creation
        # order (block1, block2, block3, batch norm) is the same in both.
        if is_conditional:
            act0 = e_block_cond(image, dim,
                                num_classes=num_classes,
                                labels=labels,
                                is_training=is_training,
                                name='e_block1', act=act)  # 16 * 16
            act1 = e_block_cond(act0, dim,
                                num_classes, labels,
                                is_training=is_training,
                                name='e_block2', act=act)  # 8 * 8
            act2 = e_block_cond(act1, dim,
                                num_classes, labels,
                                is_training=is_training,
                                name='e_block3', act=act)  # 4 * 4
            # NOTE(review): this batch norm is built with num_classes but is
            # called below without labels -- confirm against ops.batch_norm.
            bn = ops.batch_norm(num_classes, name='e_bn')
        else:
            act0 = e_block(image, dim, is_training=is_training,
                           name='e_block1', act=act)  # 16 * 16
            act1 = e_block(act0, dim, is_training,
                           name='e_block2', act=act)  # 8 * 8
            act2 = e_block(act1, dim, is_training,
                           name='e_block3', act=act)  # 4 * 4
            bn = ops.batch_norm(name='e_bn')

        act2 = act(bn(act2, is_training))
        # Flatten the 4 * 4 * dim feature map for the linear head.
        act2 = tf.reshape(act2, [-1, 4 * 4 * dim])
        out = ops.linear(act2, z_dim)
        return out
def encoder_resnet_stl10(x, x_shape, z_dim=128, dim=64,
                         num_classes=None, labels=None,
                         name='encoder',
                         update_collection=None,
                         reuse=False, is_training=True):
    """Builds the ResNet encoder whose final feature map is 6x6.

    Args:
      x: The input batch; it is reshaped to
        `[-1, x_shape[0], x_shape[1], x_shape[2]]`.
      x_shape: Indexable with at least three entries giving the image shape.
      z_dim: Size of the output code produced by the final linear layer.
      dim: Base channel count; the stages use dim, dim*2, dim*4 and dim*8.
      num_classes: Number of classes, or None for the unconditional model.
      labels: Class labels, or None. They are squeezed before use. The
        conditional blocks are used only when both `num_classes` and `labels`
        are given.
      name: The variable scope name for the encoder.
      update_collection: Not used by this function.
      reuse: Forwarded to `tf.variable_scope`.
      is_training: Forwarded to the blocks and the final batch norm.

    Returns:
      The output of the final linear layer with `z_dim` units.
    """
    if labels is not None:
        labels = tf.squeeze(labels)
    act = lrelu
    unconditional = num_classes is None or labels is None

    with tf.variable_scope(name, reuse=reuse):
        stem = tf.reshape(x, [-1, x_shape[0], x_shape[1], x_shape[2]])
        stem = ops.conv2d(stem, dim, 3, 3, 1, 1,
                          name='e_conv0')  # 48 * 48 * dim

        if unconditional:
            feat1 = e_block(stem, dim * 2, is_training=is_training,
                            name='e_block1', act=act)  # 24 * 24 * dim * 2
            feat2 = e_block(feat1, dim * 4, is_training,
                            name='e_block2', act=act)  # 12 * 12 * dim * 4
            feat3 = e_block(feat2, dim * 8, is_training,
                            name='e_block3', act=act)  # 6 * 6 * dim * 8
            final_norm = ops.batch_norm(name='e_bn')
        else:
            feat1 = e_block_cond(stem, dim * 2,
                                 num_classes=num_classes,
                                 labels=labels,
                                 is_training=is_training,
                                 name='e_block1', act=act)  # 24 * 24 * dim * 2
            feat2 = e_block_cond(feat1, dim * 4,
                                 num_classes, labels,
                                 is_training=is_training,
                                 name='e_block2', act=act)  # 12 * 12 * dim * 4
            feat3 = e_block_cond(feat2, dim * 8,
                                 num_classes, labels,
                                 is_training=is_training,
                                 name='e_block3', act=act)  # 6 * 6 * dim * 8
            final_norm = ops.batch_norm(num_classes, name='e_bn')

        feat3 = act(final_norm(feat3, is_training))
        # Flatten the 6 * 6 * (dim * 8) feature map for the linear head.
        flat = tf.reshape(feat3, [-1, 6 * 6 * dim * 8])
        return ops.linear(flat, z_dim)
def generator_resnet_cifar(z, x_shape, dim=128,
                           num_classes=None, labels=None,
                           name='generator', reuse=False,
                           is_training=True):
    """Builds the ResNet generator with a 4x4 seed feature map.

    Args:
      z: The latent input fed to the first linear layer.
      x_shape: Indexable with at least three entries; their product is the
        size of each flattened output row.
      dim: Base channel count; it is doubled internally and used unchanged
        by every block.
      num_classes: Number of classes, or None for the unconditional model.
      labels: Class labels, or None. They are squeezed before use. The
        conditional blocks are used only when both `num_classes` and `labels`
        are given.
      name: The variable scope name for the generator.
      reuse: Forwarded to `tf.variable_scope`.
      is_training: Forwarded to the blocks and the final batch norm.

    Returns:
      The sigmoid output reshaped to `[-1, x_shape[0]*x_shape[1]*x_shape[2]]`.
    """
    if labels is not None:
        labels = tf.squeeze(labels)
    dim = dim * 2  # 256 like sn-gan paper
    flat_size = x_shape[0] * x_shape[1] * x_shape[2]
    conditional = not (num_classes is None or labels is None)

    with tf.variable_scope(name, reuse=reuse):
        net = ops.linear(z, dim * 4 * 4, scope='g_linear0')
        net = tf.reshape(net, [-1, 4, 4, dim])

        # Spatial size noted as 8 * 8, 16 * 16, 32 * 32 after each block.
        for scope in ('g_block1', 'g_block2', 'g_block3'):
            if conditional:
                net = g_block_cond(net, dim, num_classes, labels,
                                   is_training, scope)
            else:
                net = g_block(net, dim, is_training, scope)

        if conditional:
            final_norm = ops.batch_norm(num_classes, name='g_bn')
        else:
            final_norm = ops.batch_norm(name='g_bn')

        net = tf.nn.relu(final_norm(net, is_training))
        logits = ops.conv2d(net, 3, 3, 3, 1, 1, name='g_conv_last')
        image = tf.nn.sigmoid(logits)
        return tf.reshape(image, [-1, flat_size])
def dsample_conv(x, name='dsample'):
    """Downsamples the image by a factor of 2.

    Applies a single `ops.conv2d` that keeps the channel count of `x`
    (read from its last dimension). The arguments `1, 1, 2, 2` are
    presumably a 1x1 kernel with stride 2 -- confirm against ops.conv2d.

    Args:
      x: The input tensor; its static shape must expose the last dimension.
      name: The name passed to `ops.conv2d`.

    Returns:
      The output tensor of the convolution.
    """
    channels = x.get_shape().as_list()[-1]
    return ops.conv2d(x, channels, 1, 1, 2, 2, name=name)