def main(unused_argv):
  """Builds a TFRecord of style images and (optionally) their Gram matrices.

  For every style file matched by FLAGS.style_files, writes one serialized
  tf.train.Example to FLAGS.output_file containing:
    - 'label': the integer index of the style file (int64 feature),
    - 'image_raw': the image re-encoded as JPEG bytes,
    - one float feature per VGG endpoint with the flattened Gram matrix,
      when FLAGS.compute_gram_matrices is set.

  Args:
    unused_argv: Unused command-line arguments (absl/tf.app convention).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  style_files = _parse_style_files(os.path.expanduser(FLAGS.style_files))
  with tf.python_io.TFRecordWriter(
      os.path.expanduser(FLAGS.output_file)) as writer:
    for style_label, style_file in enumerate(style_files):
      tf.logging.info(
          'Processing style file %s: %s' % (style_label, style_file))
      feature = {'label': _int64_feature(style_label)}

      style_image = image_utils.load_np_image(style_file)
      # Re-encode as JPEG into an in-memory buffer so the raw bytes can be
      # stored directly in the Example.
      buf = io.BytesIO()
      skimage.io.imsave(buf, style_image, format='JPEG')
      buf.seek(0)
      feature['image_raw'] = _bytes_feature(buf.getvalue())

      if FLAGS.compute_gram_matrices:
        # A fresh graph per image keeps the precomputation self-contained.
        with tf.Graph().as_default():
          style_end_points = learning.precompute_gram_matrices(
              tf.expand_dims(tf.to_float(style_image), 0),
              # We use 'pool5' instead of 'fc8' because a) fully-connected
              # layers are already too deep in the network to be useful for
              # style and b) they're quite expensive to store.
              final_endpoint='pool5')
          # Iterate items directly rather than keys plus a per-key lookup.
          for name, matrix in style_end_points.items():
            feature[name] = _float_feature(matrix.flatten().tolist())

      example = tf.train.Example(features=tf.train.Features(feature=feature))
      writer.write(example.SerializeToString())
  tf.logging.info('Output TFRecord file is saved at %s' % os.path.expanduser(
      FLAGS.output_file))
def main(unused_argv):
  """Builds a TFRecord of style images and (optionally) their Gram matrices.

  For every style file matched by FLAGS.style_files, writes one serialized
  tf.train.Example to FLAGS.output_file containing:
    - 'label': the integer index of the style file (int64 feature),
    - 'image_raw': the image re-encoded as JPEG bytes,
    - one float feature per VGG endpoint with the flattened Gram matrix,
      when FLAGS.compute_gram_matrices is set.

  Args:
    unused_argv: Unused command-line arguments (absl/tf.app convention).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  style_files = _parse_style_files(os.path.expanduser(FLAGS.style_files))
  with tf.python_io.TFRecordWriter(
      os.path.expanduser(FLAGS.output_file)) as writer:
    for style_label, style_file in enumerate(style_files):
      tf.logging.info(
          'Processing style file %s: %s' % (style_label, style_file))
      feature = {'label': _int64_feature(style_label)}

      style_image = image_utils.load_np_image(style_file)
      # Re-encode as JPEG into an in-memory buffer so the raw bytes can be
      # stored directly in the Example. NOTE: scipy.misc.imsave was removed
      # in SciPy 1.2, so we write through skimage.io.imsave instead.
      buf = io.BytesIO()
      skimage.io.imsave(buf, style_image, format='JPEG')
      buf.seek(0)
      feature['image_raw'] = _bytes_feature(buf.getvalue())

      if FLAGS.compute_gram_matrices:
        # A fresh graph per image keeps the precomputation self-contained.
        with tf.Graph().as_default():
          style_end_points = learning.precompute_gram_matrices(
              tf.expand_dims(tf.to_float(style_image), 0),
              # We use 'pool5' instead of 'fc8' because a) fully-connected
              # layers are already too deep in the network to be useful for
              # style and b) they're quite expensive to store.
              final_endpoint='pool5')
          # dict.iteritems() is Python 2-only; .items() works on Python 3.
          for name, matrix in style_end_points.items():
            feature[name] = _float_feature(matrix.flatten().tolist())

      example = tf.train.Example(features=tf.train.Features(feature=feature))
      writer.write(example.SerializeToString())
  tf.logging.info('Output TFRecord file is saved at %s' % os.path.expanduser(
      FLAGS.output_file))