def _multiple_images(input_image, which_styles, output_dir):
  """Stylizes an image into a set of styles and writes them to disk.

  One JPEG named '<output_basename>_<style>.jpg' is written to `output_dir`
  for every entry of `which_styles`. This function also writes debugging
  output before stylizing:
  - it prints every graph op name and every model variable to stdout;
  - it writes the graph via `save_graph_to_file` to
    ./tmp/image_stylization/mine_freeze_graph.pb;
  - it writes a TensorBoard graph event to ./tmp/logs/.

  Args:
    input_image: image tensor that is tiled once per requested style along
        axis 0 (presumably a single image of shape [1, H, W, C] -- TODO
        confirm against the caller).
    which_styles: sequence of integer style indices; one output file is
        written per entry.
    output_dir: directory the JPEG files are written to.
  """
  with tf.Graph().as_default(), tf.Session() as sess:
    stylized_images = model.transform(
        tf.concat([input_image for _ in range(len(which_styles))], 0),
        normalizer_params={
            'labels': tf.constant(which_styles),
            'num_categories': FLAGS.num_styles,
            'center': True,
            'scale': True})
    _load_checkpoint(sess, FLAGS.checkpoint)

    # Debug dump of the graph. The local is named `graph_ops` rather than
    # `ops` so it is not confused with the `ops` module.
    graph_ops = sess.graph.get_operations()
    for op in graph_ops:
      print(op.name)

    print("Parameters")
    for v in slim.get_model_variables():
      print('name = {}, shape = {}'.format(v.name, v.get_shape()))

    save_graph_to_file(sess, sess.graph_def,
                       "./tmp/image_stylization/mine_freeze_graph.pb")
    writer = tf.summary.FileWriter("./tmp/logs/", graph=sess.graph)
    writer.close()

    # Build the JPEG encoder exactly once. Doing it after the export keeps the
    # exported graph identical to before. Feeding it through a placeholder
    # keeps the loop below from adding new ops to the graph on every
    # iteration.
    # Scaling by 255 assumes the transformer output lies in [0, 1] (TODO
    # confirm).
    image_ph = tf.placeholder(stylized_images.dtype, shape=[None, None, None])
    encode_op = tf.image.encode_jpeg(tf.cast(image_ph * 255, tf.uint8))

    stylized_batch = stylized_images.eval()
    for which, stylized_image in zip(which_styles, stylized_batch):
      generated_file = os.path.join(
          output_dir, '{}_{}.jpg'.format(FLAGS.output_basename, which))
      # Encode first and open the file only afterwards, so an encoding failure
      # does not leave an empty file behind.
      stylized_encode = sess.run(encode_op,
                                 feed_dict={image_ph: stylized_image})
      print(generated_file)
      with tf.gfile.GFile(generated_file, 'wb') as f:
        f.write(stylized_encode)
def _multiple_styles(input_image, which_styles, output_dir):
  """Stylizes image into a linear combination of styles and writes to disk.

  The style mixture weights come from `_style_mixture(which_styles, ...)` and
  are applied through `ops.weighted_instance_norm`. A single JPEG named
  '<output_basename>_<_describe_style(which_styles)>.jpg' is written to
  `output_dir`.

  Args:
    input_image: image tensor passed directly to `model.transform`. The result
        is squeezed on axis 0, so this is presumably a batch of one image --
        TODO confirm against the caller.
    which_styles: style specification understood by `_style_mixture` and
        `_describe_style` (both defined elsewhere in this file).
    output_dir: directory the JPEG file is written to.
  """
  with tf.Graph().as_default(), tf.Session() as sess:
    mixture = _style_mixture(which_styles, FLAGS.num_styles)
    stylized_images = model.transform(
        input_image,
        normalizer_fn=ops.weighted_instance_norm,
        normalizer_params={
            'weights': tf.constant(mixture),
            'num_categories': FLAGS.num_styles,
            'center': True,
            'scale': True})
    _load_checkpoint(sess, FLAGS.checkpoint)

    generated_file = os.path.join(
        output_dir,
        '%s_%s.jpg' % (FLAGS.output_basename, _describe_style(which_styles)))
    print(generated_file)

    # Scaling by 255 assumes the transformer output lies in [0, 1] (TODO
    # confirm). Squeezing axis 0 drops the batch dimension, because
    # encode_jpeg needs a 3-D image.
    stylized_image = tf.squeeze(tf.cast(stylized_images * 255, tf.uint8),
                                axis=0)
    # Encode before opening the output, so a failed run does not leave an
    # empty file behind.
    stylized_encode = sess.run(tf.image.encode_jpeg(stylized_image))
    with tf.gfile.GFile(generated_file, 'wb') as f:
      f.write(stylized_encode)

# Notebook-derived demo: stylizes `image` with several styles in one batch and
# plots the results in a grid.
# NOTE(review): relies on module-level `image`, `checkpoint` and `plt`. None
# of these is defined in the visible portion of this file -- confirm they are
# set up before this runs.
num_styles = 7  # Number of images in checkpoint file. Do not change.

# Styles from checkpoint file to render. They are done in batch, so the more
# rendered, the longer it will take and the more memory will be used.
# These can be modified as you like. Here we take the first six styles. To
# pick six at random instead, call random.shuffle(styles) before slicing.
styles = list(range(num_styles))
which_styles = styles[0:6]
num_rendered = len(which_styles)
print(styles, which_styles)

with tf.Graph().as_default(), tf.Session() as sess:
    stylized_images = model.transform(
        tf.concat([image for _ in range(len(which_styles))], 0),
        normalizer_params={
            'labels': tf.constant(which_styles),
            'num_categories': num_styles,
            'center': True,
            'scale': True
        })
    model_saver = tf.train.Saver(tf.global_variables())
    model_saver.restore(sess, checkpoint)
    stylized_images = stylized_images.eval()

    # Plot the images, filling the grid column by column.
    counter = 0
    num_cols = 3
    num_rows = num_rendered // num_cols
    # squeeze=False keeps `axarr` 2-D even when there is only one row, so
    # that `axarr[row, col]` indexing always works.
    fig, axarr = plt.subplots(num_rows,
                              num_cols,
                              figsize=(25, 25),
                              squeeze=False)
    for col in range(num_cols):
        for row in range(num_rows):
            # Show each stylized image, labelled with its style index.
            axarr[row, col].imshow(stylized_images[counter])
            axarr[row, col].set_xlabel('Style %i' % which_styles[counter])
            counter += 1
def main(unused_argv=None):
    """Trains the style transfer network on a dataset of style images.

    Content images come from `image_utils.imagenet_inputs`. Style labels and
    style Gram matrices come from `image_utils.style_image_inputs`. Only the
    VGG16 variables are restored, from `vgg.checkpoint_file()`. All other
    variables start from their initializers, and `create_train_op` is called
    without `variables_to_train`, so every trainable variable is optimized.

    NOTE(review): this file defines `main` twice. The later definition rebinds
    the name, so this one is shadowed and cannot be reached through `main`.
    Rename or delete it if it is still needed.

    Args:
      unused_argv: ignored.

    Raises:
      ValueError: if the number of style coefficients differs from
          `FLAGS.num_styles`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default():
        # Force all input processing onto CPU in order to reserve the GPU for the
        # forward inference and back-propagation.
        device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               worker_device=device)):
            inputs, _ = image_utils.imagenet_inputs(FLAGS.batch_size,
                                                    FLAGS.image_size)
            # Load style images and select one at random (for each graph execution, a
            # new random selection occurs)
            _, style_labels, style_gram_matrices = image_utils.style_image_inputs(
                os.path.expanduser(FLAGS.style_dataset_file),
                batch_size=FLAGS.batch_size,
                image_size=FLAGS.image_size,
                square_crop=True,
                shuffle=True)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and weight flags
            num_styles = FLAGS.num_styles
            if FLAGS.style_coefficients is None:
                style_coefficients = [1.0 for _ in range(num_styles)]
            else:
                style_coefficients = ast.literal_eval(FLAGS.style_coefficients)
            if len(style_coefficients) != num_styles:
                raise ValueError(
                    'number of style coefficients differs from number of styles'
                )
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Rescale style weights dynamically based on the current style image.
            # `.items()` works on both Python 2 and 3, unlike `.iteritems()`.
            style_coefficient = tf.gather(tf.constant(style_coefficients),
                                          style_labels)
            style_weights = {key: style_coefficient * value
                             for key, value in style_weights.items()}

            # Define the model
            stylized_inputs = model.transform(inputs,
                                              normalizer_params={
                                                  'labels': style_labels,
                                                  'num_categories': num_styles,
                                                  'center': True,
                                                  'scale': True
                                              })

            # Compute losses.
            total_loss, loss_dict = learning.total_loss(
                inputs, stylized_inputs, style_gram_matrices, content_weights,
                style_weights)
            # Example values observed while debugging (batch_size=12):
            #   inputs: Tensor("batch_processing/Reshape_4:0",
            #       shape=(12, 256, 256, 3), dtype=float32) -- content images
            #   stylized_inputs: Tensor("transformer/expand/conv3/conv/Sigmoid:0",
            #       shape=(12, ?, ?, 3), dtype=float32) -- pastiche images
            #   style_gram_matrices: dict of float32 tensors keyed by
            #       'vgg_16/conv1'..'vgg_16/conv5' and 'vgg_16/pool1'..'vgg_16/pool5',
            #       shaped (12, C, C) with C = 64, 128, 256, 512, 512 for blocks 1-5
            #   content_weights: {'vgg_16/conv3': 1.0}
            #   style_weights: dict of (12,) float32 tensors keyed by
            #       'vgg_16/conv1'..'vgg_16/conv4'
            for key, value in loss_dict.items():
                tf.summary.scalar(key, value)

            # Set up training
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            # Function to restore VGG16 parameters
            # TODO(iansimon): This is ugly, but assign_from_checkpoint_fn doesn't
            # exist yet.
            saver = tf.train.Saver(slim.get_variables('vgg_16'))

            def init_fn(session):
                saver.restore(session, vgg.checkpoint_file())

            # Run training
            slim.learning.train(train_op=train_op,
                                logdir=os.path.expanduser(FLAGS.train_dir),
                                log_every_n_steps=FLAGS.log_steps,
                                master=FLAGS.master,
                                is_chief=FLAGS.task == 0,
                                number_of_steps=FLAGS.train_steps,
                                init_fn=init_fn,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
def main(unused_argv=None):
  """Fine-tunes only the InstanceNorm variables of a pre-trained transformer.

  Variables are initialized as follows:
  - VGG16 variables are restored from `vgg.checkpoint_file()`;
  - transformer variables whose name does not contain 'InstanceNorm' are
    restored from `FLAGS.checkpoint`;
  - the InstanceNorm variables are not restored, and they are the only
    variables handed to the optimizer.

  Args:
    unused_argv: ignored.

  Raises:
    ValueError: if the number of style coefficients differs from
        `FLAGS.num_styles`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    # Force all input processing onto CPU in order to reserve the GPU for the
    # forward inference and back-propagation.
    device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks,
                                                  worker_device=device)):
      inputs, _ = image_utils.imagenet_inputs(FLAGS.batch_size,
                                              FLAGS.image_size)
      # Load style images and select one at random (for each graph execution, a
      # new random selection occurs)
      _, style_labels, style_gram_matrices = image_utils.style_image_inputs(
          os.path.expanduser(FLAGS.style_dataset_file),
          batch_size=FLAGS.batch_size, image_size=FLAGS.image_size,
          square_crop=True, shuffle=True)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Process style and weight flags
      num_styles = FLAGS.num_styles
      if FLAGS.style_coefficients is None:
        style_coefficients = [1.0 for _ in range(num_styles)]
      else:
        style_coefficients = ast.literal_eval(FLAGS.style_coefficients)
      if len(style_coefficients) != num_styles:
        raise ValueError(
            'number of style coefficients differs from number of styles')
      content_weights = ast.literal_eval(FLAGS.content_weights)
      style_weights = ast.literal_eval(FLAGS.style_weights)

      # Rescale style weights dynamically based on the current style image.
      # `.items()` works on both Python 2 and 3, unlike `.iteritems()`.
      style_coefficient = tf.gather(
          tf.constant(style_coefficients), style_labels)
      style_weights = {key: style_coefficient * value
                       for key, value in style_weights.items()}

      # Define the model
      stylized_inputs = model.transform(
          inputs,
          normalizer_params={
              'labels': style_labels,
              'num_categories': num_styles,
              'center': True,
              'scale': True})

      # Compute losses.
      total_loss, loss_dict = learning.total_loss(
          inputs, stylized_inputs, style_gram_matrices, content_weights,
          style_weights)
      for key, value in loss_dict.items():
        tf.summary.scalar(key, value)

      # Split the transformer variables once: the InstanceNorm variables are
      # trained, and everything else is restored from the checkpoint.
      transformer_vars = slim.get_variables('transformer')
      instance_norm_vars = [var for var in transformer_vars
                            if 'InstanceNorm' in var.name]
      other_vars = [var for var in transformer_vars
                    if 'InstanceNorm' not in var.name]

      # Function to restore VGG16 parameters.
      # TODO(iansimon): This is ugly, but assign_from_checkpoint_fn doesn't
      # exist yet.
      saver_vgg = tf.train.Saver(slim.get_variables('vgg_16'))
      def init_fn_vgg(session):
        saver_vgg.restore(session, vgg.checkpoint_file())

      # Function to restore N-styles parameters.
      # TODO(iansimon): This is ugly, but assign_from_checkpoint_fn doesn't
      # exist yet.
      saver_n_styles = tf.train.Saver(other_vars)
      def init_fn_n_styles(session):
        saver_n_styles.restore(session, os.path.expanduser(FLAGS.checkpoint))

      def init_fn(session):
        init_fn_vgg(session)
        init_fn_n_styles(session)

      # Set up training. Only the InstanceNorm variables are updated.
      optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
      train_op = slim.learning.create_train_op(
          total_loss, optimizer, clip_gradient_norm=FLAGS.clip_gradient_norm,
          variables_to_train=instance_norm_vars, summarize_gradients=False)

      # Run training.
      slim.learning.train(
          train_op=train_op,
          logdir=os.path.expanduser(FLAGS.train_dir),
          log_every_n_steps=FLAGS.log_steps,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.train_steps,
          init_fn=init_fn,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)