def total_loss(content_inputs, style_inputs, stylized_inputs, content_weights,
               style_weights, total_variation_weight, reuse=False):
  """Computes the total loss function.

  The total loss function is composed of a content, a style and a total
  variation term.

  Args:
    content_inputs: Tensor. The content images.
    style_inputs: Tensor. The style images.
    stylized_inputs: Tensor. The stylized input images.
    content_weights: dict mapping layer names to their associated content loss
        weight. Keys that are missing from the dict won't have their content
        loss computed.
    style_weights: dict mapping layer names to their associated style loss
        weight. Keys that are missing from the dict won't have their style
        loss computed.
    total_variation_weight: float. Coefficient for the total variation part of
        the loss.
    reuse: bool. Whether the first VGG16 pass (on `content_inputs`) reuses
        existing model parameters. The style and stylized passes always reuse
        the parameters. Defaults to False.

  Returns:
    Tensor for the total loss, dict mapping loss names to losses. The dict
    contains 'total_loss' plus every entry returned by the content, style and
    total variation loss helpers.
  """
  # Propagate the content, style and stylized inputs through VGG16. Only the
  # first call may create variables; the other two share them (reuse=True).
  with tf.name_scope('content_endpoints'):
    content_end_points = vgg.vgg_16(content_inputs, reuse=reuse)
  with tf.name_scope('style_endpoints'):
    style_end_points = vgg.vgg_16(style_inputs, reuse=True)
  with tf.name_scope('stylized_endpoints'):
    stylized_end_points = vgg.vgg_16(stylized_inputs, reuse=True)

  # Compute the content loss
  with tf.name_scope('content_loss'):
    total_content_loss, content_loss_dict = content_loss(
        content_end_points, stylized_end_points, content_weights)

  # Compute the style loss
  with tf.name_scope('style_loss'):
    total_style_loss, style_loss_dict = style_loss(
        style_end_points, stylized_end_points, style_weights)

  # Compute the total variation loss
  with tf.name_scope('total_variation_loss'):
    tv_loss, total_variation_loss_dict = learning_utils.total_variation_loss(
        stylized_inputs, total_variation_weight)

  # Compute the total loss
  with tf.name_scope('total_loss'):
    loss = total_content_loss + total_style_loss + tv_loss

  # Merge the per-term breakdowns; later updates win on duplicate keys.
  loss_dict = {'total_loss': loss}
  loss_dict.update(content_loss_dict)
  loss_dict.update(style_loss_dict)
  loss_dict.update(total_variation_loss_dict)

  return loss, loss_dict
Exemple #2
0
def total_loss(content_inputs, style_inputs, stylized_inputs, content_weights,
               style_weights, total_variation_weight, reuse=False):
  """Computes the total loss function.

  The total loss function is composed of a content, a style and a total
  variation term.

  Args:
    content_inputs: Tensor. The content images.
    style_inputs: Tensor. The style images.
    stylized_inputs: Tensor. The stylized input images.
    content_weights: dict mapping layer names to their associated content loss
        weight. Keys that are missing from the dict won't have their content
        loss computed.
    style_weights: dict mapping layer names to their associated style loss
        weight. Keys that are missing from the dict won't have their style
        loss computed.
    total_variation_weight: float. Coefficient for the total variation part of
        the loss.
    reuse: bool. Whether the first VGG16 pass (on `content_inputs`) reuses
        existing model parameters. The style and stylized passes always reuse
        the parameters. Defaults to False.

  Returns:
    Tensor for the total loss, dict mapping loss names to losses. The dict
    contains 'total_loss' plus every entry returned by the content, style and
    total variation loss helpers.
  """
  # Propagate the content, style and stylized inputs through VGG16. Only the
  # first call may create variables; the other two share them (reuse=True).
  with tf.name_scope('content_endpoints'):
    content_end_points = vgg.vgg_16(content_inputs, reuse=reuse)
  with tf.name_scope('style_endpoints'):
    style_end_points = vgg.vgg_16(style_inputs, reuse=True)
  with tf.name_scope('stylized_endpoints'):
    stylized_end_points = vgg.vgg_16(stylized_inputs, reuse=True)

  # Compute the content loss
  with tf.name_scope('content_loss'):
    total_content_loss, content_loss_dict = content_loss(
        content_end_points, stylized_end_points, content_weights)

  # Compute the style loss
  with tf.name_scope('style_loss'):
    total_style_loss, style_loss_dict = style_loss(
        style_end_points, stylized_end_points, style_weights)

  # Compute the total variation loss
  with tf.name_scope('total_variation_loss'):
    tv_loss, total_variation_loss_dict = learning_utils.total_variation_loss(
        stylized_inputs, total_variation_weight)

  # Compute the total loss
  with tf.name_scope('total_loss'):
    loss = total_content_loss + total_style_loss + tv_loss

  # Merge the per-term breakdowns; later updates win on duplicate keys.
  loss_dict = {'total_loss': loss}
  loss_dict.update(content_loss_dict)
  loss_dict.update(style_loss_dict)
  loss_dict.update(total_variation_loss_dict)

  return loss, loss_dict
Exemple #3
0
def total_loss(inputs,
               stylized_inputs,
               style_gram_matrices,
               content_weights,
               style_weights,
               reuse=False):
  """Computes the total loss function.

  The total loss function is composed of a content and a style term. No
  total variation term is computed here.

  Args:
    inputs: Tensor. The input images.
    stylized_inputs: Tensor. The stylized input images.
    style_gram_matrices: dict mapping layer names to their corresponding
        Gram matrices.
    content_weights: dict mapping layer names to their associated content loss
        weight. Keys that are missing from the dict won't have their content
        loss computed.
    style_weights: dict mapping layer names to their associated style loss
        weight. Keys that are missing from the dict won't have their style
        loss computed.
    reuse: bool. Whether the first VGG16 pass (on `inputs`) reuses existing
        model parameters. The stylized pass always reuses them. Defaults to
        False.

  Returns:
    Tensor for the total loss, dict mapping loss names to losses. The dict
    contains 'total_loss' plus every entry returned by the content and style
    loss helpers.
  """
  # Propagate the input and its stylized version through VGG16. The second
  # call shares the variables created (or reused) by the first.
  end_points = vgg.vgg_16(inputs, reuse=reuse)
  stylized_end_points = vgg.vgg_16(stylized_inputs, reuse=True)

  # Compute the content loss
  total_content_loss, content_loss_dict = content_loss(
      end_points, stylized_end_points, content_weights)

  # Compute the style loss against the provided Gram matrices rather than
  # against a VGG pass over a style image.
  total_style_loss, style_loss_dict = style_loss(
      style_gram_matrices, stylized_end_points, style_weights)

  # Compute the total loss
  loss = total_content_loss + total_style_loss

  # Merge the per-term breakdowns; later updates win on duplicate keys.
  loss_dict = {'total_loss': loss}
  loss_dict.update(content_loss_dict)
  loss_dict.update(style_loss_dict)

  return loss, loss_dict
Exemple #4
0
def total_loss(inputs, stylized_inputs, style_gram_matrices, content_weights,
               style_weights, reuse=False):
  """Computes the total loss function.

  The total loss function is composed of a content and a style term. No
  total variation term is computed here.

  Args:
    inputs: Tensor. The input images.
    stylized_inputs: Tensor. The stylized input images.
    style_gram_matrices: dict mapping layer names to their corresponding
        Gram matrices.
    content_weights: dict mapping layer names to their associated content loss
        weight. Keys that are missing from the dict won't have their content
        loss computed.
    style_weights: dict mapping layer names to their associated style loss
        weight. Keys that are missing from the dict won't have their style
        loss computed.
    reuse: bool. Whether the first VGG16 pass (on `inputs`) reuses existing
        model parameters. The stylized pass always reuses them. Defaults to
        False.

  Returns:
    Tensor for the total loss, dict mapping loss names to losses. The dict
    contains 'total_loss' plus every entry returned by the content and style
    loss helpers.
  """
  # Propagate the input and its stylized version through VGG16. The second
  # call shares the variables created (or reused) by the first.
  end_points = vgg.vgg_16(inputs, reuse=reuse)
  stylized_end_points = vgg.vgg_16(stylized_inputs, reuse=True)

  # Compute the content loss
  total_content_loss, content_loss_dict = content_loss(
      end_points, stylized_end_points, content_weights)

  # Compute the style loss against the provided Gram matrices rather than
  # against a VGG pass over a style image.
  total_style_loss, style_loss_dict = style_loss(
      style_gram_matrices, stylized_end_points, style_weights)

  # Compute the total loss
  loss = total_content_loss + total_style_loss

  # Merge the per-term breakdowns; later updates win on duplicate keys.
  loss_dict = {'total_loss': loss}
  loss_dict.update(content_loss_dict)
  loss_dict.update(style_loss_dict)

  return loss, loss_dict
def precompute_gram_matrices(image, final_endpoint='fc8'):
  """Pre-computes the Gram matrices on a given image.

  Builds VGG16 up to `final_endpoint` in the default graph, restores its
  weights from `vgg.checkpoint_file()`, and evaluates the Gram matrix of every
  endpoint.

  Args:
    image: 4-D tensor. Input (batch of) image(s).
    final_endpoint: str, name of the final layer to compute Gram matrices for.
        Defaults to 'fc8'.

  Returns:
    dict mapping layer names to their corresponding Gram matrices, as the
    concrete values returned by `Session.run`.
  """
  with tf.Session() as session:
    end_points = vgg.vgg_16(image, final_endpoint=final_endpoint)
    tf.train.Saver(slim.get_variables('vgg_16')).restore(
        session, vgg.checkpoint_file())
    # Build every Gram-matrix op first and fetch them all in one run. Calling
    # .eval() per tensor issues a separate Session.run per layer, recomputing
    # the shared VGG activations each time.
    gram_ops = {key: gram_matrix(value) for key, value in end_points.items()}
    return session.run(gram_ops)
Exemple #6
0
def precompute_gram_matrices(image, final_endpoint='fc8'):
  """Pre-computes the Gram matrices on a given image.

  Builds VGG16 up to `final_endpoint` in the default graph, restores its
  weights from `vgg.checkpoint_file()`, and evaluates the Gram matrix of every
  endpoint.

  Args:
    image: 4-D tensor. Input (batch of) image(s).
    final_endpoint: str, name of the final layer to compute Gram matrices for.
        Defaults to 'fc8'.

  Returns:
    dict mapping layer names to their corresponding Gram matrices, as the
    concrete values returned by `Session.run`.
  """
  with tf.Session() as session:
    end_points = vgg.vgg_16(image, final_endpoint=final_endpoint)
    tf.train.Saver(slim.get_variables('vgg_16')).restore(
        session, vgg.checkpoint_file())
    # .items() works on both Python 2 and 3 (.iteritems() is Python 2 only).
    # All Gram-matrix ops are fetched in a single run instead of one
    # Session.run per layer, so the VGG activations are computed only once.
    gram_ops = {key: _gram_matrix(value) for key, value in end_points.items()}
    return session.run(gram_ops)