Example #1
0
def transform(input_,
              normalizer_fn=None,
              normalizer_params=None,
              reuse=False,
              trainable=True,
              is_training=True,
              alpha=1.0):
  """Builds the image transformer network and returns its stylized output.

  The graph is created under the 'transformer' variable scope in three
  stages: 'contract' (three convolutions), 'residual' (five residual blocks)
  and 'expand' (three upsampling layers, the last one with a sigmoid
  activation). While the 'contract' stage is built, the slim.conv2d defaults
  are overridden to normalize with slim.batch_norm; the 'residual' and
  'expand' stages are built with `normalizer_fn` / `normalizer_params` as the
  slim.conv2d defaults.

  Args:
    input_: Tensor. Batch of input (content) images.
    normalizer_fn: normalization layer function used for style normalization.
    normalizer_params: dict of parameters passed to the style normalization
      op.
    reuse: bool. Whether to reuse the variables of the 'transformer' scope.
      Defaults to False.
    trainable: bool. Whether the created parameters are marked as trainable.
    is_training: bool. Forwarded as `is_training` to slim.batch_norm.
    alpha: float. Width multiplier applied to the filter counts (32, 64 and
      128) of the 'contract' and 'expand' layers; each product is truncated
      with int(). Defaults to 1.0, which gives the hyper-parameters used in
      the published paper.

  Returns:
    Tensor. The output of the transformer network.
  """
  # Layer widths after applying the width multiplier.
  narrow_width = int(alpha * 32)
  medium_width = int(alpha * 64)
  wide_width = int(alpha * 128)

  residual_names = (
      'residual1', 'residual2', 'residual3', 'residual4', 'residual5')

  with tf.variable_scope('transformer', reuse=reuse):
    # slim.conv2d defaults shared by the whole network: the caller-supplied
    # (style) normalizer is used unless a nested scope overrides it.
    style_conv_defaults = dict(
        activation_fn=tf.nn.relu,
        normalizer_fn=normalizer_fn,
        normalizer_params=normalizer_params,
        weights_initializer=tf.random_normal_initializer(0.0, 0.01),
        biases_initializer=tf.constant_initializer(0.0),
        trainable=trainable)
    with slim.arg_scope([slim.conv2d], **style_conv_defaults):
      # Only the 'contract' stage swaps the normalizer for batch norm.
      contract_conv_defaults = dict(
          normalizer_fn=slim.batch_norm,
          normalizer_params=None,
          trainable=trainable)
      batch_norm_defaults = dict(is_training=is_training, trainable=trainable)
      with slim.arg_scope([slim.conv2d], **contract_conv_defaults):
        with slim.arg_scope([slim.batch_norm], **batch_norm_defaults):
          with tf.variable_scope('contract'):
            net = model_util.conv2d(input_, 9, 1, narrow_width, 'conv1')
            net = model_util.conv2d(net, 3, 2, medium_width, 'conv2')
            net = model_util.conv2d(net, 3, 2, wide_width, 'conv3')

      with tf.variable_scope('residual'):
        for block_name in residual_names:
          net = model_util.residual_block(net, 3, block_name)

      with tf.variable_scope('expand'):
        net = model_util.upsampling(net, 3, 2, medium_width, 'conv1')
        net = model_util.upsampling(net, 3, 2, narrow_width, 'conv2')
        # The final layer is not scaled by alpha and uses a sigmoid
        # activation instead of the default.
        stylized = model_util.upsampling(
            net, 9, 1, 3, 'conv3', activation_fn=tf.nn.sigmoid)

  return stylized
Example #2
0
def transform(input_, normalizer_fn=None, normalizer_params=None,
              reuse=False, trainable=True, is_training=True):
  """Builds the image transformer network and returns its stylized output.

  The graph is created under the 'transformer' variable scope in three
  stages: 'contract' (three convolutions), 'residual' (five residual blocks)
  and 'expand' (three upsampling layers, the last one with a sigmoid
  activation). While the 'contract' stage is built, the slim.conv2d defaults
  are overridden to normalize with slim.batch_norm; the 'residual' and
  'expand' stages are built with `normalizer_fn` / `normalizer_params` as the
  slim.conv2d defaults.

  Args:
    input_: Tensor. Batch of input (content) images.
    normalizer_fn: normalization layer function used for style normalization.
    normalizer_params: dict of parameters passed to the style normalization
        op.
    reuse: bool. Whether to reuse the variables of the 'transformer' scope.
        Defaults to False.
    trainable: bool. Whether the created parameters are marked as trainable.
    is_training: bool. Forwarded as `is_training` to slim.batch_norm.

  Returns:
    Tensor. The output of the transformer network.
  """
  residual_names = (
      'residual1', 'residual2', 'residual3', 'residual4', 'residual5')

  with tf.variable_scope('transformer', reuse=reuse):
    # slim.conv2d defaults shared by the whole network: the caller-supplied
    # (style) normalizer is used unless a nested scope overrides it.
    style_conv_defaults = dict(
        activation_fn=tf.nn.relu,
        normalizer_fn=normalizer_fn,
        normalizer_params=normalizer_params,
        weights_initializer=tf.random_normal_initializer(0.0, 0.01),
        biases_initializer=tf.constant_initializer(0.0),
        trainable=trainable)
    with slim.arg_scope([slim.conv2d], **style_conv_defaults):
      # Only the 'contract' stage swaps the normalizer for batch norm.
      contract_conv_defaults = dict(
          normalizer_fn=slim.batch_norm,
          normalizer_params=None,
          trainable=trainable)
      batch_norm_defaults = dict(is_training=is_training, trainable=trainable)
      with slim.arg_scope([slim.conv2d], **contract_conv_defaults):
        with slim.arg_scope([slim.batch_norm], **batch_norm_defaults):
          with tf.variable_scope('contract'):
            net = model_util.conv2d(input_, 9, 1, 32, 'conv1')
            net = model_util.conv2d(net, 3, 2, 64, 'conv2')
            net = model_util.conv2d(net, 3, 2, 128, 'conv3')

      with tf.variable_scope('residual'):
        for block_name in residual_names:
          net = model_util.residual_block(net, 3, block_name)

      with tf.variable_scope('expand'):
        net = model_util.upsampling(net, 3, 2, 64, 'conv1')
        net = model_util.upsampling(net, 3, 2, 32, 'conv2')
        # The final layer uses a sigmoid activation instead of the default.
        stylized = model_util.upsampling(
            net, 9, 1, 3, 'conv3', activation_fn=tf.nn.sigmoid)

  return stylized