def discriminator(x,
                  progress,
                  num_filters_fn,
                  resolution_schedule,
                  num_blocks=None,
                  kernel_size=3,
                  simple_arch=False,
                  scope='progressive_gan_discriminator',
                  reuse=None):
  """Discriminator network for the progressive GAN model.

  Args:
    x: A `Tensor` of NHWC format representing images of size `resolution`.
    progress: A scalar float `Tensor` of training progress.
    num_filters_fn: A function that maps `block_id` to # of filters for the
      block.
    resolution_schedule: An object of `ResolutionSchedule`.
    num_blocks: An integer of number of blocks. None means maximum number of
      blocks, i.e. `resolution_schedule.num_resolutions`. Defaults to None.
    kernel_size: An integer of convolution kernel size.
    simple_arch: Bool, use a simple architecture.
    scope: A string or variable scope.
    reuse: Whether to reuse `scope`. Defaults to None which means to inherit
      the reuse option of the parent scope.

  Returns:
    A `Tensor` of model output and a dictionary of model end points.
  """
  he_init = tf_slim.variance_scaling_initializer()

  if num_blocks is None:
    num_blocks = resolution_schedule.num_resolutions

  def _conv2d(scope, x, kernel_size, filters, padding='SAME'):
    return layers.custom_conv2d(
        x=x,
        filters=filters,
        kernel_size=kernel_size,
        padding=padding,
        activation=tf.nn.leaky_relu,
        he_initializer_slope=0.0,
        scope=scope)

  def _from_rgb(x, block_id):
    return _conv2d('from_rgb', x, 1, num_filters_fn(block_id))

  if resolution_schedule.scale_mode == 'H':
    strides = (resolution_schedule.scale_base, 1)
  else:
    strides = (resolution_schedule.scale_base,
               resolution_schedule.scale_base)

  end_points = {}

  with tf.variable_scope(scope, reuse=reuse):
    x0 = x
    end_points['rgb'] = x0

    # Pass 1: build a from-RGB entry point at every resolution in the
    # schedule, together with its fade-in weight alpha.
    lods = []
    for block_id in range(num_blocks, 0, -1):
      with tf.variable_scope(block_name(block_id)):
        scale = resolution_schedule.scale_factor(block_id)
        lod = resolution_schedule.downscale(x0, scale)
        end_points['downscaled_rgb_{}'.format(block_id)] = lod
        if simple_arch:
          lod = tf.layers.conv2d(
              lod,
              num_filters_fn(block_id),
              kernel_size=1,
              padding='SAME',
              name='from_rgb',
              kernel_initializer=he_init)
          lod = tf.nn.relu(lod)
        else:
          lod = _from_rgb(lod, block_id)
        # alpha_i is used to replace lod_select.
        alpha = _discriminator_alpha(block_id, progress)
        end_points['alpha_{}'.format(block_id)] = alpha
      lods.append((lod, alpha))

    # Pass 2: run the convolutional trunk from the highest resolution down,
    # blending each block's output with the matching from-RGB entry point.
    lods_iter = iter(lods)
    x, _ = next(lods_iter)
    for block_id in range(num_blocks, 1, -1):
      with tf.variable_scope(block_name(block_id)):
        if simple_arch:
          x = tf.layers.conv2d(
              x,
              num_filters_fn(block_id - 1),
              strides=strides,
              kernel_size=kernel_size,
              padding='SAME',
              name='conv',
              kernel_initializer=he_init)
          x = tf.nn.relu(x)
        else:
          x = _conv2d('conv0', x, kernel_size, num_filters_fn(block_id))
          x = _conv2d('conv1', x, kernel_size, num_filters_fn(block_id - 1))
          x = resolution_schedule.downscale(x, resolution_schedule.scale_base)
        lod, alpha = next(lods_iter)
        x = alpha * lod + (1.0 - alpha) * x

    with tf.variable_scope(block_name(1)):
      # Concatenate the minibatch mean stddev statistic as an extra feature.
      x = layers.scalar_concat(x, layers.minibatch_mean_stddev(x))
      if simple_arch:
        x = tf.reshape(x, [tf.shape(x)[0], -1])  # flatten
        x = tf.layers.dense(x, num_filters_fn(0), name='last_conv',
                            kernel_initializer=he_init)
        x = tf.reshape(x, [tf.shape(x)[0], 1, 1, num_filters_fn(0)])
        x = tf.nn.relu(x)
      else:
        x = _conv2d('conv0', x, kernel_size, num_filters_fn(1))
        x = _conv2d('conv1', x, resolution_schedule.start_resolutions,
                    num_filters_fn(0), 'VALID')
      end_points['last_conv'] = x
      if simple_arch:
        logits = tf.layers.dense(x, 1, name='logits',
                                 kernel_initializer=he_init)
      else:
        logits = layers.custom_dense(x=x, units=1, scope='logits')
      end_points['logits'] = logits

  return logits, end_points
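

# `_discriminator_alpha` is called above but defined elsewhere in this file.
# Below is a minimal sketch consistent with how it is used: for block
# `block_id`, alpha stays at 1.0 while progress <= block_id - 1, decays
# linearly to 0.0 over the next unit of progress, and stays at 0.0 after
# that, so the downscaled-RGB lane is gradually replaced by the learned
# trunk. Treat this as an assumption about the helper, not necessarily its
# exact definition.
def _discriminator_alpha(block_id, progress):
  """Sketch: fade-in weight for block `block_id` at training `progress`."""
  # Assumed schedule: clip(block_id - progress, 0, 1) realizes the piecewise
  # linear fade described above.
  return tf.clip_by_value(block_id - progress, 0.0, 1.0)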
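

# Hypothetical usage sketch, not code from this repository. It assumes a
# `ResolutionSchedule(start_resolutions=..., scale_base=..., num_resolutions=...)`
# constructor matching the schedule object used above, TF1-style graph mode,
# and an illustrative filter schedule; adjust all values for a real model.
def _example_discriminator_usage():
  """Builds discriminator logits for 32x32 RGB images (4 * 2**(4 - 1))."""

  def num_filters_fn(block_id):
    # Illustrative: more filters at the lower-resolution (later) blocks.
    return min(512, 64 * 2**(4 - block_id))

  schedule = ResolutionSchedule(
      start_resolutions=(4, 4), scale_base=2, num_resolutions=4)
  images = tf.placeholder(tf.float32, [None, 32, 32, 3])
  progress = tf.placeholder(tf.float32, [])  # scalar training progress
  logits, end_points = discriminator(images, progress, num_filters_fn,
                                     schedule)
  return logits, end_points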