def total_loss(content_inputs, style_inputs, stylized_inputs, content_weights,
               style_weights, total_variation_weight, reuse=False):
  """Computes the total loss function.

  The total loss is the sum of a content term, a style term and a total
  variation term.

  Args:
    content_inputs: Tensor. The input (content) images.
    style_inputs: Tensor. The style images.
    stylized_inputs: Tensor. The stylized input images.
    content_weights: dict mapping layer names to their associated content loss
        weight. Keys that are missing from the dict won't have their content
        loss computed.
    style_weights: dict mapping layer names to their associated style loss
        weight. Keys that are missing from the dict won't have their style
        loss computed.
    total_variation_weight: float. Coefficient for the total variation part of
        the loss.
    reuse: bool. Passed as `reuse` to the first VGG16 pass (over
        `content_inputs`) only. Defaults to False. The style and stylized
        passes are always built with `reuse=True`, so they share the model
        parameters of the first pass.

  Returns:
    Tensor for the total loss, dict mapping loss names to losses. The dict
    holds the total loss under 'total_loss' plus every entry of the content,
    style and total variation loss dicts. The dicts are merged in that order,
    so on a key collision the later entry wins.
  """
  # Propagate the content, style and stylized images through VGG16. Only the
  # first call honors `reuse`; the other two always reuse its parameters.
  with tf.name_scope('content_endpoints'):
    content_end_points = vgg.vgg_16(content_inputs, reuse=reuse)
  with tf.name_scope('style_endpoints'):
    style_end_points = vgg.vgg_16(style_inputs, reuse=True)
  with tf.name_scope('stylized_endpoints'):
    stylized_end_points = vgg.vgg_16(stylized_inputs, reuse=True)

  # Content loss: compares content and stylized activations.
  with tf.name_scope('content_loss'):
    total_content_loss, content_loss_dict = content_loss(
        content_end_points, stylized_end_points, content_weights)

  # Style loss: compares style and stylized activations.
  with tf.name_scope('style_loss'):
    total_style_loss, style_loss_dict = style_loss(
        style_end_points, stylized_end_points, style_weights)

  # Total variation loss is computed on the stylized images only.
  with tf.name_scope('total_variation_loss'):
    tv_loss, total_variation_loss_dict = learning_utils.total_variation_loss(
        stylized_inputs, total_variation_weight)

  # Compute the total loss
  with tf.name_scope('total_loss'):
    loss = total_content_loss + total_style_loss + tv_loss

  # Merge order matters: later dicts override earlier keys on collision.
  loss_dict = {'total_loss': loss}
  loss_dict.update(content_loss_dict)
  loss_dict.update(style_loss_dict)
  loss_dict.update(total_variation_loss_dict)

  return loss, loss_dict
예제 #2
0
def total_loss(content_inputs, style_inputs, stylized_inputs, content_weights,
               style_weights, total_variation_weight, reuse=False):
  """Computes the total loss function.

  The total loss is the sum of a content term, a style term and a total
  variation term.

  Args:
    content_inputs: Tensor. The input (content) images.
    style_inputs: Tensor. The style images.
    stylized_inputs: Tensor. The stylized input images.
    content_weights: dict mapping layer names to their associated content loss
        weight. Keys that are missing from the dict won't have their content
        loss computed.
    style_weights: dict mapping layer names to their associated style loss
        weight. Keys that are missing from the dict won't have their style
        loss computed.
    total_variation_weight: float. Coefficient for the total variation part of
        the loss.
    reuse: bool. Passed as `reuse` to the first VGG16 pass (over
        `content_inputs`) only. Defaults to False. The style and stylized
        passes are always built with `reuse=True`, so they share the model
        parameters of the first pass.

  Returns:
    Tensor for the total loss, dict mapping loss names to losses. The dict
    holds the total loss under 'total_loss' plus every entry of the content,
    style and total variation loss dicts. The dicts are merged in that order,
    so on a key collision the later entry wins.
  """
  # Propagate the content, style and stylized images through VGG16. Only the
  # first call honors `reuse`; the other two always reuse its parameters.
  with tf.name_scope('content_endpoints'):
    content_end_points = vgg.vgg_16(content_inputs, reuse=reuse)
  with tf.name_scope('style_endpoints'):
    style_end_points = vgg.vgg_16(style_inputs, reuse=True)
  with tf.name_scope('stylized_endpoints'):
    stylized_end_points = vgg.vgg_16(stylized_inputs, reuse=True)

  # Content loss: compares content and stylized activations.
  with tf.name_scope('content_loss'):
    total_content_loss, content_loss_dict = content_loss(
        content_end_points, stylized_end_points, content_weights)

  # Style loss: compares style and stylized activations.
  with tf.name_scope('style_loss'):
    total_style_loss, style_loss_dict = style_loss(
        style_end_points, stylized_end_points, style_weights)

  # Total variation loss is computed on the stylized images only.
  with tf.name_scope('total_variation_loss'):
    tv_loss, total_variation_loss_dict = learning_utils.total_variation_loss(
        stylized_inputs, total_variation_weight)

  # Compute the total loss
  with tf.name_scope('total_loss'):
    loss = total_content_loss + total_style_loss + tv_loss

  # Merge order matters: later dicts override earlier keys on collision.
  loss_dict = {'total_loss': loss}
  loss_dict.update(content_loss_dict)
  loss_dict.update(style_loss_dict)
  loss_dict.update(total_variation_loss_dict)

  return loss, loss_dict