def precompute_gram_matrices(image, final_endpoint='fc8'):
    """Pre-computes the Gram matrices on a given image.

    Builds VGG-16 on `image` up to `final_endpoint`, restores the `vgg_16`
    variables from `vgg.checkpoint_file()`, and evaluates the Gram matrix of
    every returned endpoint in a single `session.run` call.

    Args:
      image: 4-D tensor. Input (batch of) image(s).
      final_endpoint: str, name of the final layer to compute Gram matrices for.
          Defaults to 'fc8'.

    Returns:
      dict mapping layer names to their corresponding Gram matrices (the values
      produced by `session.run`).
    """
    with tf.Session() as session:
        end_points = vgg.vgg_16(image, final_endpoint=final_endpoint)
        tf.train.Saver(slim.get_variables('vgg_16')).restore(
            session, vgg.checkpoint_file())
        # Build every Gram op first and fetch them together: this is one
        # forward pass through VGG instead of one per endpoint, and all Gram
        # matrices are guaranteed to come from the same input batch (separate
        # `.eval()` calls could each pull a new batch if `image` is fed by an
        # input pipeline). `.items()` also works on Python 3, unlike
        # `.iteritems()`.
        gram_ops = {
            key: gram_matrix(value) for key, value in end_points.items()
        }
        return session.run(gram_ops)
# Example #2
# 0
def main(unused_argv=None):
  """Builds the stylization training graph from FLAGS and runs training.

  Content images come from `image_utils.imagenet_inputs` and style images from
  `image_utils.arbitrary_style_image_inputs`; both input pipelines are placed
  on a CPU device. The model, losses, summaries and train op are placed with
  `tf.train.replica_device_setter(FLAGS.ps_tasks)`. The VGG-16 and InceptionV3
  sub-networks are restored from their checkpoints through `init_fn` before
  `slim.learning.train` starts.

  Args:
    unused_argv: Ignored. Presumably accepted for app-runner compatibility;
      confirm against the caller.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    # Forces all input processing onto CPU in order to reserve the GPU for the
    # forward inference and back-propagation.
    device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
    with tf.device(
        tf.train.replica_device_setter(FLAGS.ps_tasks, worker_device=device)):
      # Loads content images.
      content_inputs_, _ = image_utils.imagenet_inputs(FLAGS.batch_size,
                                                       FLAGS.image_size)

      # Loads style images.
      [style_inputs_, _,
       style_inputs_orig_] = image_utils.arbitrary_style_image_inputs(
           FLAGS.style_dataset_file,
           batch_size=FLAGS.batch_size,
           image_size=FLAGS.image_size,
           shuffle=True,
           center_crop=FLAGS.center_crop,
           augment_style_images=FLAGS.augment_style_images,
           random_style_image_size=FLAGS.random_style_image_size)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Process style and content weight flags. `ast.literal_eval` only accepts
      # Python literals, so flag strings cannot execute arbitrary code.
      content_weights = ast.literal_eval(FLAGS.content_weights)
      style_weights = ast.literal_eval(FLAGS.style_weights)

      # Define the model
      stylized_images, total_loss, loss_dict, _ = build_model.build_model(
          content_inputs_,
          style_inputs_,
          trainable=True,
          is_training=True,
          inception_end_point='Mixed_6e',
          style_prediction_bottleneck=100,
          adds_losses=True,
          content_weights=content_weights,
          style_weights=style_weights,
          total_variation_weight=FLAGS.total_variation_weight)

      # Adding scalar summaries to the tensorboard. `.items()` works on both
      # Python 2 and Python 3; `.iteritems()` does not exist on Python 3 dicts.
      for key, value in loss_dict.items():
        tf.summary.scalar(key, value)

      # Adding Image summaries to the tensorboard (at most 3 images each).
      tf.summary.image('image/0_content_inputs', content_inputs_, 3)
      tf.summary.image('image/1_style_inputs_orig', style_inputs_orig_, 3)
      tf.summary.image('image/2_style_inputs_aug', style_inputs_, 3)
      tf.summary.image('image/3_stylized_images', stylized_images, 3)

      # Set up training
      optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
      train_op = slim.learning.create_train_op(
          total_loss,
          optimizer,
          clip_gradient_norm=FLAGS.clip_gradient_norm,
          summarize_gradients=False)

      # Function to restore VGG16 parameters.
      init_fn_vgg = slim.assign_from_checkpoint_fn(vgg.checkpoint_file(),
                                                   slim.get_variables('vgg_16'))

      # Function to restore Inception_v3 parameters. The checkpoint is keyed by
      # op name, so map each model variable's op name to the variable.
      inception_variables_dict = {
          var.op.name: var
          for var in slim.get_model_variables('InceptionV3')
      }
      init_fn_inception = slim.assign_from_checkpoint_fn(
          FLAGS.inception_v3_checkpoint, inception_variables_dict)

      # Function to restore VGG16 and Inception_v3 parameters.
      def init_sub_networks(session):
        init_fn_vgg(session)
        init_fn_inception(session)

      # Run training
      slim.learning.train(
          train_op=train_op,
          logdir=os.path.expanduser(FLAGS.train_dir),
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.train_steps,
          init_fn=init_sub_networks,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
 def init_fn(session):
     """Restores variables via `saver` from `vgg.checkpoint_file()`.

     Args:
       session: the session to restore into; forwarded to `saver.restore`.
     """
     # NOTE(review): `saver` is not defined in the visible code; presumably a
     # tf.train.Saver bound in an enclosing scope. Confirm.
     saver.restore(session, vgg.checkpoint_file())