def main(unused_argv):
  """Serializes every configured style image into one TFRecord file.

  For each file returned by `_parse_style_files(FLAGS.style_files)`, one
  `tf.train.Example` is written whose features are:
    * 'label'     -- the file's zero-based index in the parsed list;
    * 'image_raw' -- the image loaded by `image_utils.load_np_image` and
                     re-encoded as JPEG bytes;
    * one float feature per end point returned by
      `learning.precompute_gram_matrices(..., final_endpoint='pool5')`,
      holding that end point's flattened matrix (only when
      FLAGS.compute_gram_matrices is set).

  NOTE(review): this module defines `main` a second time further down; that
  later definition shadows this one at import time. Consider removing one.

  Args:
    unused_argv: Leftover command-line arguments; ignored.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  parsed_styles = _parse_style_files(os.path.expanduser(FLAGS.style_files))
  record_path = os.path.expanduser(FLAGS.output_file)

  with tf.python_io.TFRecordWriter(record_path) as record_writer:
    for label_id, path in enumerate(parsed_styles):
      tf.logging.info('Processing style file %s: %s' % (label_id, path))

      # Re-encode the decoded image as JPEG so the record stores compact bytes.
      np_image = image_utils.load_np_image(path)
      jpeg_buffer = io.BytesIO()
      skimage.io.imsave(jpeg_buffer, np_image, format='JPEG')
      jpeg_buffer.seek(0)
      features = {
          'label': _int64_feature(label_id),
          'image_raw': _bytes_feature(jpeg_buffer.getvalue()),
      }

      if FLAGS.compute_gram_matrices:
        # A fresh graph per image keeps ops from accumulating across files.
        with tf.Graph().as_default():
          batched_image = tf.expand_dims(tf.to_float(np_image), 0)
          # 'pool5' rather than 'fc8': the fully-connected layers sit too deep
          # in the network to be useful for style, and are costly to store.
          gram_by_layer = learning.precompute_gram_matrices(
              batched_image, final_endpoint='pool5')
          for layer_name, gram in gram_by_layer.items():
            features[layer_name] = _float_feature(gram.flatten().tolist())

      example = tf.train.Example(features=tf.train.Features(feature=features))
      record_writer.write(example.SerializeToString())

  tf.logging.info('Output TFRecord file is saved at %s' % record_path)
def main(unused_argv):
  """Writes every style image (and optionally its Gram matrices) to a TFRecord.

  For each file returned by `_parse_style_files(FLAGS.style_files)`, one
  `tf.train.Example` is written to FLAGS.output_file (with `~` expanded).
  Its features are:
    * 'label'     -- the file's zero-based index in the parsed list;
    * 'image_raw' -- the image loaded by `image_utils.load_np_image` and
                     re-encoded as JPEG bytes;
    * one float feature per end point returned by
      `learning.precompute_gram_matrices(..., final_endpoint='pool5')`,
      holding that end point's flattened matrix (only when
      FLAGS.compute_gram_matrices is set).

  NOTE(review): this module defines `main` twice; being the later definition,
  this is the one bound at import time. Consider removing the other.

  Args:
    unused_argv: Leftover command-line arguments; ignored.
  """
  # scipy.misc.imsave was removed in SciPy 1.2; encode with scikit-image, as
  # the other main() in this module does. Imported locally so this function
  # does not depend on the module-level import list.
  import skimage.io  # pylint: disable=g-import-not-at-top

  tf.logging.set_verbosity(tf.logging.INFO)
  style_files = _parse_style_files(os.path.expanduser(FLAGS.style_files))
  output_file = os.path.expanduser(FLAGS.output_file)

  with tf.python_io.TFRecordWriter(output_file) as writer:
    for style_label, style_file in enumerate(style_files):
      tf.logging.info('Processing style file %s: %s', style_label, style_file)
      feature = {'label': _int64_feature(style_label)}

      style_image = image_utils.load_np_image(style_file)
      buf = io.BytesIO()
      skimage.io.imsave(buf, style_image, format='JPEG')
      # getvalue() returns the whole buffer regardless of the stream position,
      # so no seek is needed before reading it back.
      feature['image_raw'] = _bytes_feature(buf.getvalue())

      if FLAGS.compute_gram_matrices:
        # A fresh graph per image keeps ops from accumulating across files.
        with tf.Graph().as_default():
          style_end_points = learning.precompute_gram_matrices(
              tf.expand_dims(tf.to_float(style_image), 0),
              # We use 'pool5' instead of 'fc8' because a) fully-connected
              # layers are already too deep in the network to be useful for
              # style and b) they're quite expensive to store.
              final_endpoint='pool5')
          # .items() instead of the Python-2-only .iteritems().
          for name, matrix in style_end_points.items():
            feature[name] = _float_feature(matrix.flatten().tolist())

      example = tf.train.Example(features=tf.train.Features(feature=feature))
      writer.write(example.SerializeToString())

  tf.logging.info('Output TFRecord file is saved at %s', output_file)