def generate_pb(ckpt_dir, filename, device_t, batch_shape, sess):
    """Freeze the style-transfer network and write it as ``<filename>.pb``.

    The network is rebuilt with ``transform.net`` on a fresh graph, weights are
    restored from ``ckpt_dir``, variables are folded into constants and the
    resulting GraphDef is serialized to disk.

    Args:
        ckpt_dir: Either a directory containing a checkpoint state file, or a
            checkpoint path that ``Saver.restore`` accepts directly.
        filename: Base name (without extension) of the ``.pb`` file to write.
        device_t: Context manager entered together with the graph; it is used
            as-is in the ``with`` statement.
        batch_shape: Shape of the ``img_placeholder`` input tensor.
        sess: Ignored. The function always opens its own session; the
            parameter is kept so existing callers keep working.

    Raises:
        Exception: If ``ckpt_dir`` is a directory that holds no checkpoint.
    """
    graph = tf.Graph()
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True

    with graph.as_default(), device_t, \
            tf.Session(config=soft_config) as session:
        img_placeholder = tf.placeholder(tf.float32,
                                         shape=batch_shape,
                                         name='img_placeholder')
        # Called for its side effect of adding the network to the graph.
        transform.net(img_placeholder)
        saver = tf.train.Saver()

        if os.path.isdir(ckpt_dir):
            ckpt = tf.train.get_checkpoint_state(ckpt_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                raise Exception("No checkpoint found...")
            restore_path = ckpt.model_checkpoint_path
            out_dir = ckpt_dir
        else:
            # ckpt_dir is a checkpoint path, not a directory: write the .pb
            # next to it instead of "inside" a path that is not a directory.
            restore_path = ckpt_dir
            out_dir = os.path.dirname(ckpt_dir)

        saver.restore(session, restore_path)
        # NOTE(review): 'add_37' is a hard-coded output node name; presumably
        # the last op created by transform.net -- confirm against the model.
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            session, session.graph_def, ['add_37'])
        with open(os.path.join(out_dir, filename + '.pb'), 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())
def main():
    """Restore a trained checkpoint and re-save it as ``<out_dir>/ncs``.

    The network is rebuilt for a single image of the size given on the command
    line, with the input tensor named ``input`` and the output tensor named
    ``output``; both tensor names are printed after saving.

    Raises:
        Exception: If ``options.checkpoint_dir`` is a directory without a
            checkpoint.
    """
    options = parser.parse_args()

    with tf.Graph().as_default():
        input_shape = (1, options.height, options.width, 3)
        img_placeholder = tf.placeholder(tf.float32,
                                         shape=input_shape,
                                         name='input')
        output = tf.identity(transform.net(img_placeholder), name='output')

        with tf.Session() as sess:
            for init_op in (tf.global_variables_initializer(),
                            tf.local_variables_initializer()):
                sess.run(init_op)

            saver = tf.train.Saver(tf.global_variables())

            # Resolve what to restore: the latest checkpoint of a directory,
            # or the given path itself.
            restore_from = options.checkpoint_dir
            if os.path.isdir(options.checkpoint_dir):
                state = tf.train.get_checkpoint_state(options.checkpoint_dir)
                if not (state and state.model_checkpoint_path):
                    raise Exception("No checkpoint found...")
                restore_from = state.model_checkpoint_path
            saver.restore(sess, restore_from)

            out_path = os.path.join(options.out_dir, 'ncs')
            saver.save(sess, out_path)

            print("Saved to:", out_path + '.meta')
            print('Input tensor name:', img_placeholder.name)
            print('Output tensor name:', output.name)
def ffwd_video(path_in, path_out, checkpoint_dir, device_t='/gpu:0',
               batch_size=4):
    """Stylize every frame of a video and write the result to ``path_out``.

    Frames are read from ``path_in``, pushed through ``transform.net`` in
    batches of ``batch_size`` and written with an ffmpeg writer (libx264,
    2000k bitrate), using ``path_in`` as the audio source.

    Args:
        path_in: Path of the source video.
        path_out: Path of the video file to create.
        checkpoint_dir: Checkpoint directory, or a checkpoint path accepted by
            ``Saver.restore``.
        device_t: Device string passed to ``graph.device``.
        batch_size: Number of frames evaluated per ``sess.run`` call.

    Raises:
        Exception: If ``checkpoint_dir`` is a directory without a checkpoint.
    """
    video_clip = VideoFileClip(path_in, audio=False)
    video_writer = ffmpeg_writer.FFMPEG_VideoWriter(path_out,
                                                    video_clip.size,
                                                    video_clip.fps,
                                                    codec="libx264",
                                                    preset="medium",
                                                    bitrate="2000k",
                                                    audiofile=path_in,
                                                    threads=None,
                                                    ffmpeg_params=None)

    # The writer owns an ffmpeg subprocess; make sure it is closed even when
    # graph construction, checkpoint restore or inference fails.
    try:
        g = tf.Graph()
        soft_config = tf.ConfigProto(allow_soft_placement=True)
        soft_config.gpu_options.allow_growth = True
        with g.as_default(), g.device(device_t), \
                tf.Session(config=soft_config) as sess:
            # video_clip.size is indexed as (width, height) here.
            batch_shape = (batch_size, video_clip.size[1],
                           video_clip.size[0], 3)
            img_placeholder = tf.placeholder(tf.float32,
                                             shape=batch_shape,
                                             name='img_placeholder')
            preds = transform.net(img_placeholder)
            saver = tf.train.Saver()

            if os.path.isdir(checkpoint_dir):
                ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    raise Exception("No checkpoint found...")
            else:
                saver.restore(sess, checkpoint_dir)

            X = np.zeros(batch_shape, dtype=np.float32)

            def style_and_write(count):
                """Run the network on X and write its first `count` frames."""
                # Pad an incomplete batch by repeating the last real frame.
                for i in range(count, batch_size):
                    X[i] = X[count - 1]
                _preds = sess.run(preds, feed_dict={img_placeholder: X})
                for i in range(count):
                    video_writer.write_frame(
                        np.clip(_preds[i], 0, 255).astype(np.uint8))

            frame_count = 0  # Number of frames currently buffered in X.
            for frame in video_clip.iter_frames():
                X[frame_count] = frame
                frame_count += 1
                if frame_count == batch_size:
                    style_and_write(frame_count)
                    frame_count = 0

            # Flush the trailing, partially filled batch.
            if frame_count != 0:
                style_and_write(frame_count)
    finally:
        video_writer.close()
def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
    """Stylize the images at ``data_in`` and save them to ``paths_out``.

    Images are processed in full batches; whatever does not fill a final batch
    is handled by a recursive call with ``batch_size=1``. All images must have
    the same shape as the first one.

    Args:
        data_in: Sequence of input image paths (read with ``get_img``).
        paths_out: Sequence of output paths (written with ``save_img``).
        checkpoint_dir: Checkpoint directory, or a checkpoint path accepted by
            ``Saver.restore``.
        device_t: Device string passed to ``graph.device``.
        batch_size: Upper bound on images per ``sess.run`` call.

    Raises:
        Exception: If ``checkpoint_dir`` is a directory without a checkpoint.
    """
    # Only path inputs are supported here; the array-input variant was
    # disabled (commented out) in the original code.
    img_shape = get_img(data_in[0]).shape

    graph = tf.Graph()
    batch_size = min(len(paths_out), batch_size)
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True

    with graph.as_default(), graph.device(device_t), \
            tf.Session(config=soft_config) as sess:
        batch_shape = (batch_size, ) + img_shape
        img_placeholder = tf.placeholder(tf.float32,
                                         shape=batch_shape,
                                         name='img_placeholder')
        preds = transform.net(img_placeholder)
        saver = tf.train.Saver()

        restore_from = checkpoint_dir
        if os.path.isdir(checkpoint_dir):
            state = tf.train.get_checkpoint_state(checkpoint_dir)
            if not (state and state.model_checkpoint_path):
                raise Exception("No checkpoint found...")
            restore_from = state.model_checkpoint_path
        saver.restore(sess, restore_from)

        full_batches = len(paths_out) // batch_size
        for batch_idx in range(full_batches):
            start = batch_idx * batch_size
            stop = start + batch_size

            X = np.zeros(batch_shape, dtype=np.float32)
            for slot, path_in in enumerate(data_in[start:stop]):
                img = get_img(path_in)
                assert img.shape == img_shape
                X[slot] = img

            _preds = sess.run(preds, feed_dict={img_placeholder: X})
            for slot, path_out in enumerate(paths_out[start:stop]):
                save_img(path_out, _preds[slot])

        done = full_batches * batch_size
        remaining_in = data_in[done:]
        remaining_out = paths_out[done:]

    # Leftover images are evaluated one at a time on a new graph.
    if len(remaining_in) > 0:
        ffwd(remaining_in,
             remaining_out,
             checkpoint_dir,
             device_t=device_t,
             batch_size=1)
def ffwd(content, network_path):
    """Run the trained network on ``content`` and return the first result.

    Args:
        content: Input batch; its ``.shape`` defines the placeholder shape and
            its values are divided by 255 before entering ``net``.
        network_path: Directory holding the checkpoint state to restore.

    Returns:
        The first element (index 0) of the network output for ``content``.

    Raises:
        Exception: If no checkpoint is found in ``network_path``.
    """
    # Build on a private graph: building on the shared default graph makes a
    # second call create renamed variables that the checkpoint cannot restore.
    with tf.Graph().as_default(), tf.Session() as sess:
        content_placeholder = tf.placeholder(tf.float32,
                                             shape=content.shape,
                                             name='content_placeholder')
        network = net(content_placeholder / 255.0)  # scale the image to 0-1

        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(network_path)
        # get_checkpoint_state returns None when the directory has no
        # checkpoint file; fail with a clear message instead of an
        # AttributeError on None.
        if not (ckpt and ckpt.model_checkpoint_path):
            raise Exception("No checkpoint found...")
        saver.restore(sess, ckpt.model_checkpoint_path)

        prediction = sess.run(network,
                              feed_dict={content_placeholder: content})
        return prediction[0]
def main():
    """Freeze a checkpoint to ``.pb`` and convert it to Core ML models.

    Steps, all driven by command-line options:
      1. Rebuild ``transform.net`` for 256x256x3 inputs on the CPU and restore
         the checkpoint given by ``options.checkpoint_dir``.
      2. Write the frozen graph as ``<file_name>.pb``.
      3. Convert it to ``<file_name>.mlmodel``, then save
         ``<file_name>_output.mlmodel`` with the output marked as an image and
         pass that file to ``convert_flexible_coremodel``.

    Raises:
        Exception: If ``options.checkpoint_dir`` is a directory without a
            checkpoint.
    """
    parser = build_parser()
    options = parser.parse_args()
    ckpt_dir = options.checkpoint_dir
    filename = options.file_name

    g = tf.Graph()
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True
    with g.as_default(), tf.device('/cpu:0'), \
            tf.Session(config=soft_config) as sess:
        # NOTE(review): `batch_size` is not defined inside this function;
        # presumably a module-level value -- confirm, otherwise this raises
        # NameError.
        img_placeholder = tf.placeholder(tf.float32,
                                         shape=(batch_size, 256, 256, 3),
                                         name='img_placeholder')
        # Called for its side effect of adding the network to the graph.
        transform.net(img_placeholder)
        saver = tf.train.Saver()

        if os.path.isdir(ckpt_dir):
            ckpt = tf.train.get_checkpoint_state(ckpt_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                raise Exception("No checkpoint found...")
            restore_path = ckpt.model_checkpoint_path
            out_dir = ckpt_dir
        else:
            # ckpt_dir is a checkpoint path here, so outputs go next to it
            # rather than under a path that is not a directory.
            restore_path = ckpt_dir
            out_dir = os.path.dirname(ckpt_dir)

        saver.restore(sess, restore_path)
        # NOTE(review): 'add_37' is a hard-coded output node name; presumably
        # the last op created by transform.net -- confirm against the model.
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['add_37'])

        base_path = os.path.join(out_dir, filename)
        with open(base_path + '.pb', 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())

    tf_converter.convert(tf_model_path=base_path + '.pb',
                         mlmodel_path=base_path + '.mlmodel',
                         output_feature_names=['add_37:0'],
                         image_input_names=['img_placeholder__0'])
    model = coremltools.models.MLModel(base_path + '.mlmodel')

    # lin_quant_model = quantize_weights(model, 8, "linear")
    # lin_quant_model.save(ckpt_dir + '/' + filename + '.mlmodel')

    spec = model.get_spec()
    convert_multiarray_output_to_image(spec, 'add_37__0', is_bgr=False)
    new_model = coremltools.models.MLModel(spec)
    new_model.save(base_path + '_output.mlmodel')

    convert_flexible_coremodel(base_path + '_output.mlmodel',
                               'img_placeholder__0', 'add_37__0')
def __init__(self, checkpoint_dir, img_shape, device_t):
    """Create a session, build the network for one image and restore weights.

    Args:
        checkpoint_dir: Checkpoint directory, or a checkpoint path accepted
            by ``Saver.restore``.
        img_shape: Shape tuple of a single image; a leading batch dimension
            of 1 is prepended for the placeholder.
        device_t: Device string the network ops are placed on.

    Raises:
        Exception: If ``checkpoint_dir`` is a directory without a checkpoint.
    """
    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.allow_growth = True
    # session_config.log_device_placement = True
    self.sess = tf.Session(config=session_config)

    self.img_placeholder = tf.placeholder(tf.float32,
                                          shape=(1, ) + img_shape,
                                          name='img_placeholder')

    # The placeholder is divided by 255 before it enters the network.
    with tf.device(device_t):
        self.preds = transform.net(self.img_placeholder / 255.)

    saver = tf.train.Saver()
    restore_from = checkpoint_dir
    if os.path.isdir(checkpoint_dir):
        state = tf.train.get_checkpoint_state(checkpoint_dir)
        if not (state and state.model_checkpoint_path):
            raise Exception("No checkpoint found...")
        restore_from = state.model_checkpoint_path
    saver.restore(self.sess, restore_from)
def get_output(image_as_array, checkpoint_dir):
    """Stylize a single image on the CPU and return the network output.

    Args:
        image_as_array: Image array; it is copied into a float32 batch of
            size one.
        checkpoint_dir: Checkpoint directory, or a checkpoint path accepted by
            ``Saver.restore``.

    Returns:
        The first (and only) element of the network output batch.

    Raises:
        Exception: If ``checkpoint_dir`` is a directory without a checkpoint.
    """
    graph = tf.Graph()
    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.allow_growth = True

    with graph.as_default(), graph.device("cpu:0"), \
            tf.Session(config=session_config) as sess:
        batch_shape = (1,) + image_as_array.shape
        batch = np.zeros(batch_shape, dtype=np.float32)
        batch[0, :, :, :] = image_as_array

        img_placeholder = tf.placeholder(tf.float32,
                                         shape=batch_shape,
                                         name='img_placeholder')
        preds = transform.net(img_placeholder)

        saver = tf.train.Saver()
        restore_from = checkpoint_dir
        if os.path.isdir(checkpoint_dir):
            state = tf.train.get_checkpoint_state(checkpoint_dir)
            if not (state and state.model_checkpoint_path):
                raise Exception("No checkpoint found...")
            restore_from = state.model_checkpoint_path
        saver.restore(sess, restore_from)

        stylized = sess.run(preds, feed_dict={img_placeholder: batch})
        return stylized[0]
def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
    """Stylize images and save them to ``paths_out``, reporting elapsed time.

    ``data_in`` is either a sequence of image paths (read with ``get_img``) or
    an indexable collection of already-loaded images whose first axis matches
    ``paths_out``. Full batches are evaluated first; leftovers are handled by
    a recursive call with ``batch_size=1``.

    Args:
        data_in: Image paths, or an array-like of images.
        paths_out: Output paths, one per input image; must be non-empty.
        checkpoint_dir: Checkpoint directory, or a checkpoint path accepted by
            ``Saver.restore``.
        device_t: Device string passed to ``graph.device``.
        batch_size: Upper bound on images per ``sess.run`` call.

    Raises:
        AssertionError: If the input and output counts differ, or if images
            read from paths have different shapes.
        Exception: If ``checkpoint_dir`` is a directory without a checkpoint.
    """
    start_ffwd = time.time()
    assert len(paths_out) > 0
    is_paths = isinstance(data_in[0], str)
    if is_paths:
        assert len(data_in) == len(paths_out)
        img_shape = get_img(data_in[0]).shape
    else:
        # The original used `data_in.size[0]` (an int is not subscriptable)
        # and `X[0].shape` (X is not bound yet); both failed for array input.
        assert len(data_in) == len(paths_out)
        img_shape = data_in[0].shape

    g = tf.Graph()
    batch_size = min(len(paths_out), batch_size)
    soft_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True
    with g.as_default(), g.device(device_t), \
            tf.compat.v1.Session(config=soft_config) as sess:
        batch_shape = (batch_size, ) + img_shape
        img_placeholder = tf.compat.v1.placeholder(tf.float32,
                                                   shape=batch_shape,
                                                   name='img_placeholder')

        preds = transform.net(img_placeholder)
        saver = tf.compat.v1.train.Saver()
        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception("No checkpoint found...")
        else:
            saver.restore(sess, checkpoint_dir)

        num_iters = len(paths_out) // batch_size
        for i in range(num_iters):
            pos = i * batch_size
            curr_batch_out = paths_out[pos:pos + batch_size]
            if is_paths:
                curr_batch_in = data_in[pos:pos + batch_size]
                X = np.zeros(batch_shape, dtype=np.float32)
                for j, path_in in enumerate(curr_batch_in):
                    img = get_img(path_in)
                    assert img.shape == img_shape, \
                        'Images have different dimensions. ' + \
                        'Resize images or use --allow-different-dimensions.'
                    X[j] = img
            else:
                X = data_in[pos:pos + batch_size]

            _preds = sess.run(preds, feed_dict={img_placeholder: X})
            for j, path_out in enumerate(curr_batch_out):
                save_img(path_out, _preds[j])

        remaining_in = data_in[num_iters * batch_size:]
        remaining_out = paths_out[num_iters * batch_size:]

    # Leftover images are evaluated one at a time on a new graph.
    if len(remaining_in) > 0:
        ffwd(remaining_in,
             remaining_out,
             checkpoint_dir,
             device_t=device_t,
             batch_size=1)

    time_needed = time.time() - start_ffwd
    print("ffwd function worked {:.4} seconds, file={} shape={}".format(
        time_needed, 0, img_shape))
def from_pipe(opts):
    """Stylize a video by streaming raw frames through ffmpeg subprocesses.

    ``ffprobe`` supplies the frame size and rate, one ``ffmpeg`` process
    decodes ``opts.in_path`` to raw RGB24 frames on a pipe, the frames are run
    through ``transform.net`` in batches of ``opts.batch_size`` and a second
    ``ffmpeg`` process encodes the results to ``opts.out`` with libx264.

    Args:
        opts: Options object providing ``in_path``, ``out``, ``device``,
            ``batch_size`` and ``checkpoint``.

    Raises:
        Exception: If ``opts.checkpoint`` is a directory without a checkpoint.
    """
    # Local import: parses ffprobe's "num/den" frame rate without eval().
    from fractions import Fraction

    command = [
        "ffprobe", '-v', "quiet", '-print_format', 'json', '-show_streams',
        opts.in_path
    ]
    info = json.loads(str(subprocess.check_output(command), encoding='utf8'))
    streams = info["streams"]
    # Prefer the first stream that identifies itself as video; fall back to
    # stream 0, which is what was always used before.
    video = next((s for s in streams if s.get("codec_type") == "video"),
                 streams[0])
    width = int(video["width"])
    height = int(video["height"])
    fps = round(Fraction(video["r_frame_rate"]))

    command = [
        "ffmpeg", '-loglevel', "quiet", '-i', opts.in_path, '-f', 'image2pipe',
        '-pix_fmt', 'rgb24', '-vcodec', 'rawvideo', '-'
    ]
    pipe_in = subprocess.Popen(command,
                               stdout=subprocess.PIPE,
                               bufsize=10**9,
                               stdin=None,
                               stderr=None)

    command = [
        "ffmpeg",
        '-loglevel', "info",
        '-y',  # (optional) overwrite output file if it exists
        '-f', 'rawvideo',
        '-vcodec', 'rawvideo',
        '-s', str(width) + 'x' + str(height),  # size of one frame
        '-pix_fmt', 'rgb24',
        '-r', str(fps),  # frames per second
        '-i', '-',  # The input comes from a pipe
        '-an',  # Tells FFMPEG not to expect any audio
        '-c:v', 'libx264',
        '-preset', 'slow',
        '-crf', '18',
        opts.out
    ]
    pipe_out = subprocess.Popen(command,
                                stdin=subprocess.PIPE,
                                stdout=None,
                                stderr=None)

    write_failed = False
    try:
        g = tf.Graph()
        soft_config = tf.ConfigProto(allow_soft_placement=True)
        soft_config.gpu_options.allow_growth = True
        with g.as_default(), g.device(opts.device), \
                tf.Session(config=soft_config) as sess:
            batch_shape = (opts.batch_size, height, width, 3)
            img_placeholder = tf.placeholder(tf.float32,
                                             shape=batch_shape,
                                             name='img_placeholder')
            preds = transform.net(img_placeholder)
            saver = tf.train.Saver()

            if os.path.isdir(opts.checkpoint):
                ckpt = tf.train.get_checkpoint_state(opts.checkpoint)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    raise Exception("No checkpoint found...")
            else:
                saver.restore(sess, opts.checkpoint)

            X = np.zeros(batch_shape, dtype=np.float32)
            nbytes = 3 * width * height

            while True:
                # Fill X with up to batch_size frames; a short read marks the
                # end of the decoded stream.
                count = 0
                while count < opts.batch_size:
                    raw_image = pipe_in.stdout.read(nbytes)
                    if len(raw_image) != nbytes:
                        break
                    image = np.frombuffer(raw_image, dtype=np.uint8)
                    X[count] = image.reshape((height, width, 3))
                    count += 1

                if count == 0:
                    break

                # Pad a partial final batch with its last real frame. The
                # previous approach rebuilt the network for the smaller batch,
                # which created new variables that were never restored from
                # the checkpoint.
                X[count:] = X[count - 1]

                _preds = sess.run(preds, feed_dict={img_placeholder: X})
                try:
                    for i in range(count):
                        img = np.clip(_preds[i], 0, 255).astype(np.uint8)
                        pipe_out.stdin.write(img.tobytes())
                except IOError as err:
                    # The encoder's stderr is not captured (stderr=None), so
                    # only the Python-side error can be reported here.
                    print("FFMPEG encountered an error while writing "
                          "file: %s" % err)
                    write_failed = True
                    break

                if count < opts.batch_size:
                    break  # That was the final, partial batch.
    finally:
        if write_failed:
            pipe_out.terminate()
        # Stop the decoder if it is still running (early exit or error).
        if pipe_in.poll() is None:
            pipe_in.terminate()
        pipe_in.stdout.close()
        try:
            # Closing stdin signals end-of-input to the encoder.
            pipe_out.stdin.close()
        except OSError:
            pass  # Encoder already gone; nothing left to flush.
        # Wait so the output file is finalized before returning.
        pipe_out.wait()
        pipe_in.wait()
import numpy as np
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

# Loads Checkpoint from `ckpt` folder, converts it to a
# TensorFlow SavedModel, ready to serve.
# NOTE(review): this fragment is cut off after the run step; the SavedModel
# export itself is not visible here.
g = tf.Graph()
soft_config = tf.ConfigProto(allow_soft_placement=True)
soft_config.gpu_options.allow_growth = True
with g.as_default(), tf.Session(config=soft_config) as sess:
    # Network input: a single 256x256 image with 3 channels.
    img_placeholder = tf.placeholder(tf.float32,
                                     shape=(1, 256, 256, 3),
                                     name='img_placeholder')
    preds = transform.net(img_placeholder)
    saver = tf.train.Saver()
    # NOTE(review): `ckpt` is None when the 'ckpt' folder holds no checkpoint,
    # in which case the next line fails with AttributeError.
    ckpt = tf.train.get_checkpoint_state('ckpt')
    saver.restore(sess, ckpt.model_checkpoint_path)
    # Load image (can use any image, just need one arbitrary run of model)
    img = get_img('images/input/input_italy.jpg', (256, 256, 3))
    X = np.zeros((1, 256, 256, 3), dtype=np.float32)
    X[0] = img
    # Run the network once on the sample image.
    _preds = sess.run(preds, feed_dict={img_placeholder: X})
    # If you want to freeze your graph instead of outputting a tensorflow
    # SavedModel, uncomment code and comment out all code below
    # frozen_graph_def = tf.graph_util.convert_variables_to_constants(
    #     sess,
def optimize(content_targets, style_target, content_weight, style_weight,
             tv_weight, vgg_path, epochs=2, print_iterations=1000,
             batch_size=4, save_path='saver/fns.ckpt', slow=False,
             learning_rate=1e-3, debug=False):
    """Train the style-transfer network, yielding progress periodically.

    Style Gram matrices are precomputed once from ``style_target``; the
    training graph then minimizes the sum of a content loss, a style loss and
    a total-variation loss with Adam. This is a generator: it yields on every
    print iteration and after the last batch of the last epoch.

    Args:
        content_targets: Sequence of content image paths (read with
            ``get_img`` at 256x256x3). Trimmed to a multiple of the batch size.
        style_target: Style image array; a batch axis is prepended.
        content_weight: Multiplier of the content loss.
        style_weight: Multiplier of the style loss.
        tv_weight: Multiplier of the total-variation loss.
        vgg_path: Path handed to ``vgg.net``.
        epochs: Number of passes over ``content_targets``.
        print_iterations: Yield every this many iterations (in ``slow`` mode:
            every this many epochs).
        batch_size: Images per training step; forced to 1 in ``slow`` mode.
        save_path: Checkpoint path written at each yield (not in ``slow``
            mode).
        slow: Optimize an image variable directly instead of training
            ``transform.net``.
        learning_rate: Adam learning rate.
        debug: Print the duration of every batch.

    Yields:
        Tuple ``(_preds, losses, iterations, epoch)`` where ``losses`` is
        ``(style_loss, content_loss, tv_loss, total_loss)``.
    """
    if slow:
        batch_size = 1
    mod = len(content_targets) % batch_size
    if mod > 0:
        print("Train set has been trimmed slightly..")
        content_targets = content_targets[:-mod]

    style_features = {}

    batch_shape = (batch_size, 256, 256, 3)
    style_shape = (1,) + style_target.shape
    print(style_shape)

    # precompute style features
    with tf.Graph().as_default(), tf.device('/cpu:0'), tf.Session() as sess:
        style_image = tf.placeholder(tf.float32, shape=style_shape,
                                     name='style_image')
        style_image_pre = vgg.preprocess(style_image)
        net = vgg.net(vgg_path, style_image_pre)
        style_pre = np.array([style_target])
        for layer in STYLE_LAYERS:
            features = net[layer].eval(feed_dict={style_image: style_pre})
            # Flatten to (positions, channels) and form the Gram matrix.
            features = np.reshape(features, (-1, features.shape[3]))
            gram = np.matmul(features.T, features) / features.size
            style_features[layer] = gram

    with tf.Graph().as_default(), tf.Session() as sess:
        X_content = tf.placeholder(tf.float32, shape=batch_shape,
                                   name="X_content")
        X_pre = vgg.preprocess(X_content)

        # precompute content features
        content_features = {}
        content_net = vgg.net(vgg_path, X_pre)
        content_features[CONTENT_LAYER] = content_net[CONTENT_LAYER]

        if slow:
            preds = tf.Variable(
                tf.random_normal(X_content.get_shape()) * 0.256
            )
            preds_pre = preds
        else:
            preds = transform.net(X_content / 255.0)
            preds_pre = vgg.preprocess(preds)

        net = vgg.net(vgg_path, preds_pre)

        content_size = _tensor_size(
            content_features[CONTENT_LAYER]) * batch_size
        assert _tensor_size(content_features[CONTENT_LAYER]) == \
            _tensor_size(net[CONTENT_LAYER])
        content_loss = content_weight * (2 * tf.nn.l2_loss(
            net[CONTENT_LAYER] - content_features[CONTENT_LAYER]) /
            content_size
        )

        style_losses = []
        for style_layer in STYLE_LAYERS:
            layer = net[style_layer]
            # Static shape as plain ints.
            bs, height, width, filters = layer.get_shape().as_list()
            size = height * width * filters
            feats = tf.reshape(layer, (bs, height * width, filters))
            feats_T = tf.transpose(feats, perm=[0, 2, 1])
            grams = tf.matmul(feats_T, feats) / size
            style_gram = style_features[style_layer]
            style_losses.append(
                2 * tf.nn.l2_loss(grams - style_gram) / style_gram.size)

        style_loss = style_weight * functools.reduce(
            tf.add, style_losses) / batch_size

        # total variation denoising
        tv_y_size = _tensor_size(preds[:, 1:, :, :])
        tv_x_size = _tensor_size(preds[:, :, 1:, :])
        y_tv = tf.nn.l2_loss(
            preds[:, 1:, :, :] - preds[:, :batch_shape[1] - 1, :, :])
        x_tv = tf.nn.l2_loss(
            preds[:, :, 1:, :] - preds[:, :, :batch_shape[2] - 1, :])
        tv_loss = tv_weight * 2 * (
            x_tv / tv_x_size + y_tv / tv_y_size) / batch_size

        loss = content_loss + style_loss + tv_loss

        # overall loss
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)
        # Build the Saver once. Creating a new one at every print iteration
        # added another set of save/restore ops to the graph each time.
        saver = None if slow else tf.train.Saver()
        sess.run(tf.global_variables_initializer())

        import random
        uid = random.randint(1, 100)
        print("UID: %s" % uid)

        for epoch in range(epochs):
            num_examples = len(content_targets)
            iterations = 0
            while iterations * batch_size < num_examples:
                start_time = time.time()
                curr = iterations * batch_size
                step = curr + batch_size
                X_batch = np.zeros(batch_shape, dtype=np.float32)
                for j, img_p in enumerate(content_targets[curr:step]):
                    X_batch[j] = get_img(
                        img_p, (256, 256, 3)).astype(np.float32)

                iterations += 1
                assert X_batch.shape[0] == batch_size

                feed_dict = {X_content: X_batch}
                train_step.run(feed_dict=feed_dict)

                end_time = time.time()
                delta_time = end_time - start_time
                if debug:
                    print("UID: %s, batch time: %s" % (uid, delta_time))

                is_print_iter = int(iterations) % print_iterations == 0
                if slow:
                    is_print_iter = epoch % print_iterations == 0
                is_last = (epoch == epochs - 1
                           and iterations * batch_size >= num_examples)
                should_print = is_print_iter or is_last
                if should_print:
                    to_get = [style_loss, content_loss, tv_loss, loss, preds]
                    test_feed_dict = {X_content: X_batch}
                    tup = sess.run(to_get, feed_dict=test_feed_dict)
                    _style_loss, _content_loss, _tv_loss, _loss, _preds = tup
                    losses = (_style_loss, _content_loss, _tv_loss, _loss)
                    if slow:
                        _preds = vgg.unprocess(_preds)
                    else:
                        saver.save(sess, save_path)
                    yield (_preds, losses, iterations, epoch)
def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
    """Stylize images and save them to ``paths_out``.

    ``data_in`` is either a sequence of image paths (read with ``get_img``) or
    an indexable collection of already-loaded images whose first axis matches
    ``paths_out``. Full batches are evaluated first; leftovers are handled by
    a recursive call with ``batch_size=1``.

    When ``checkpoint_dir`` is a directory without a checkpoint, no exception
    is raised: a ``fst_checkpoints`` directory is created, a message is
    printed and the network runs with freshly initialized weights.

    Args:
        data_in: Image paths, or an array-like of images.
        paths_out: Output paths, one per input image; must be non-empty.
        checkpoint_dir: Checkpoint directory, or a checkpoint path accepted by
            ``Saver.restore``.
        device_t: Device string passed to ``graph.device``.
        batch_size: Upper bound on images per ``sess.run`` call.

    Raises:
        AssertionError: If the input and output counts differ, or if images
            read from paths have different shapes.
    """
    assert len(paths_out) > 0
    is_paths = isinstance(data_in[0], str)
    if is_paths:
        assert len(data_in) == len(paths_out)
        img_shape = get_img(data_in[0]).shape
    else:
        # The original used `data_in.size[0]` (an int is not subscriptable)
        # and `X[0].shape` (X is not bound yet); both failed for array input.
        assert len(data_in) == len(paths_out)
        img_shape = data_in[0].shape

    g = tf.Graph()
    batch_size = min(len(paths_out), batch_size)
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True
    with g.as_default(), g.device(device_t), tf.Session(
            config=soft_config) as sess:
        batch_shape = (batch_size, ) + img_shape
        img_placeholder = tf.placeholder(tf.float32,
                                         shape=batch_shape,
                                         name='img_placeholder')

        preds = transform.net(img_placeholder)
        saver = tf.train.Saver()

        # Initialize once, BEFORE restoring. The initializers used to run
        # inside the batch loop, after the restore, which overwrote the
        # restored weights with random values on every batch.
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # restore() needs the checkpoint path, not the state object.
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                # Deliberate best-effort: continue without trained weights.
                os.makedirs("fst_checkpoints", exist_ok=True)
                ckpt = os.path.dirname("fst_checkpoints")
                print(ckpt, "variable ckpt status")
                print("...model checkpoints directory created...")
        else:
            saver.restore(sess, checkpoint_dir)

        num_iters = len(paths_out) // batch_size
        for i in range(num_iters):
            pos = i * batch_size
            curr_batch_out = paths_out[pos:pos + batch_size]
            if is_paths:
                curr_batch_in = data_in[pos:pos + batch_size]
                X = np.zeros(batch_shape, dtype=np.float32)
                for j, path_in in enumerate(curr_batch_in):
                    img = get_img(path_in)
                    assert img.shape == img_shape, \
                        'Images have different dimensions. ' + \
                        'Resize images or use --allow-different-dimensions.'
                    X[j] = img
            else:
                X = data_in[pos:pos + batch_size]

            _preds = sess.run(preds, feed_dict={img_placeholder: X})
            for j, path_out in enumerate(curr_batch_out):
                save_img(path_out, _preds[j])

        remaining_in = data_in[num_iters * batch_size:]
        remaining_out = paths_out[num_iters * batch_size:]

    # Leftover images are evaluated one at a time on a new graph.
    if len(remaining_in) > 0:
        ffwd(remaining_in,
             remaining_out,
             checkpoint_dir,
             device_t=device_t,
             batch_size=1)
def export(checkpoint, img_shape):
    """Build a base64-PNG in / base64-PNG out graph and save it for serving.

    The graph decodes a base64 PNG string, drops an alpha channel or expands
    grayscale to RGB, runs ``transform.net`` and re-encodes the result as a
    base64 PNG string. Input and output tensor names are stored as JSON in the
    ``inputs`` and ``outputs`` collections. The meta graph and the restored
    weights are written under ``a.export``.

    Args:
        checkpoint: Checkpoint directory, or a checkpoint path accepted by
            ``Saver.restore``.
        img_shape: Static shape set on the decoded image; defaults to
            ``[256, 256, 3]`` when ``None``.

    Raises:
        Exception: If ``checkpoint`` is a directory without a checkpoint.
    """
    if img_shape is None:
        img_shape = [256, 256, 3]

    # Placeholder for a base64 string that decodes to a PNG image.
    b64_input = tf.placeholder(tf.string, shape=[1])
    png_bytes = tf.decode_base64(b64_input[0])
    decoded = tf.image.decode_png(png_bytes)

    # Drop the alpha channel when the image has four channels.
    without_alpha = tf.cond(tf.equal(tf.shape(decoded)[2], 4),
                            lambda: decoded[:, :, :3],
                            lambda: decoded)
    # Expand single-channel images to RGB.
    rgb = tf.cond(tf.equal(tf.shape(without_alpha)[2], 1),
                  lambda: tf.image.grayscale_to_rgb(without_alpha),
                  lambda: without_alpha)

    input_image = tf.image.convert_image_dtype(rgb, dtype=tf.float32)
    input_image.set_shape(img_shape)

    # The network expects a leading batch axis.
    batch_input = tf.expand_dims(input_image, axis=0)
    batch_output = transform.net(batch_input)

    # Clip to [0, 255], then reinterpret an int8 cast as uint8.
    # NOTE(review): the float -> int8 cast for values above 127 looks
    # platform-dependent -- confirm it round-trips as intended.
    clipped = tf.clip_by_value(batch_output, 0, 255)
    as_uint8 = tf.bitcast(tf.cast(clipped, tf.int8), tf.uint8)

    output_data = tf.image.encode_png(as_uint8[0])
    output = tf.convert_to_tensor([tf.encode_base64(output_data)])

    # Record the input/output tensor names in graph collections.
    key = tf.placeholder(tf.string, shape=[1])
    tf.add_to_collection(
        "inputs", json.dumps({"key": key.name, "input": b64_input.name}))
    tf.add_to_collection(
        "outputs",
        json.dumps({
            "key": tf.identity(key).name,
            "output": output.name,
        }))

    init_op = tf.global_variables_initializer()
    restore_saver = tf.train.Saver()
    export_saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init_op)

        restore_from = checkpoint
        if os.path.isdir(checkpoint):
            state = tf.train.get_checkpoint_state(checkpoint)
            if not (state and state.model_checkpoint_path):
                raise Exception("No checkpoint found...")
            restore_from = state.model_checkpoint_path
        restore_saver.restore(sess, restore_from)

        print("exporting model")
        export_saver.export_meta_graph(
            filename=os.path.join(a.export, "export.meta"))
        export_saver.save(sess,
                          os.path.join(a.export, "export"),
                          write_meta_graph=False)
def optimize(content_targets, style_target, content_weight, style_weight,
             tv_weight, vgg_path, epochs=2, print_iterations=1000,
             batch_size=4, save_path='saver/fns.ckpt', slow=False,
             learning_rate=1e-3, debug=False):
    """Train the style-transfer network, yielding progress periodically.

    Style Gram matrices are precomputed once from ``style_target``; the
    training graph then minimizes the sum of a content loss, a style loss and
    a total-variation loss with Adam. This is a generator: it yields on every
    print iteration and after the last batch of the last epoch.

    Args:
        content_targets: Sequence of content image paths (read with
            ``get_img`` at 256x256x3). Trimmed to a multiple of the batch size.
        style_target: Style image array; a batch axis is prepended.
        content_weight: Multiplier of the content loss.
        style_weight: Multiplier of the style loss.
        tv_weight: Multiplier of the total-variation loss.
        vgg_path: Path handed to ``vgg.net``.
        epochs: Number of passes over ``content_targets``.
        print_iterations: Yield every this many iterations (in ``slow`` mode:
            every this many epochs).
        batch_size: Images per training step; forced to 1 in ``slow`` mode.
        save_path: Checkpoint path written at each yield (not in ``slow``
            mode).
        slow: Optimize an image variable directly instead of training
            ``transform.net``.
        learning_rate: Adam learning rate.
        debug: Print the duration of every batch.

    Yields:
        Tuple ``(_preds, losses, iterations, epoch)`` where ``losses`` is
        ``(style_loss, content_loss, tv_loss, total_loss)``.
    """
    # NOTE(review): the original author questioned whether the default of
    # only two epochs is intended.
    if slow:
        batch_size = 1
    # Images are consumed batch_size at a time (4 by default), so the number
    # of training images must be a multiple of batch_size.
    mod = len(content_targets) % batch_size
    if mod > 0:
        print("Train set has been trimmed slightly..")
        content_targets = content_targets[:-mod]  # drop the remainder

    style_features = {}

    # style_target.shape is (height, width, channels); style_shape adds a
    # leading batch axis of 1.
    batch_shape = (batch_size, 256, 256, 3)
    style_shape = (1, ) + style_target.shape
    # str() is required: concatenating a str with the shape tuple raised
    # TypeError.
    print("所训练的style image属性(图片垂直尺寸/图片水平尺寸/图片通道数):" +
          str(style_target.shape))

    # precompute style features
    # NOTE(review): the original author questioned why this runs on cpu:0.
    with tf.Graph().as_default(), tf.device('/cpu:0'), \
            tf.Session() as sess:
        style_image = tf.placeholder(tf.float32,
                                     shape=style_shape,
                                     name='style_image')
        # 1. Preprocess the style image (original note: image minus mean --
        #    confirm against vgg.preprocess).
        style_image_pre = vgg.preprocess(style_image)
        # 2. Feed the preprocessed style image through VGG.
        net = vgg.net(vgg_path, style_image_pre)
        style_pre = np.array([style_target])  # style image with batch axis

        # Collect the VGG feature maps of the selected style layers.
        for layer in STYLE_LAYERS:
            # Feed style_pre as the value of style_image to obtain the
            # feature map of this layer.
            features = net[layer].eval(feed_dict={style_image: style_pre})
            # Flatten to (positions, channels) and form the Gram matrix.
            features = np.reshape(features, (-1, features.shape[3]))
            gram = np.matmul(features.T, features) / features.size
            style_features[layer] = gram

    # Content features are taken from CONTENT_LAYER of VGG (original note:
    # relu4_2).
    with tf.Graph().as_default(), tf.Session() as sess:
        # One batch of content images.
        X_content = tf.placeholder(tf.float32,
                                   shape=batch_shape,
                                   name="X_content")
        X_pre = vgg.preprocess(X_content)

        # precompute content features
        content_features = {}
        content_net = vgg.net(vgg_path, X_pre)
        content_features[CONTENT_LAYER] = content_net[CONTENT_LAYER]

        # content_features: content images that did NOT pass through the
        # generator network.
        # preds_pre -> net: content images that DID pass through the
        # generator network, i.e. the intermediate (stylized) images.
        if slow:
            preds = tf.Variable(
                tf.random_normal(X_content.get_shape()) * 0.256)
            preds_pre = preds
        else:
            # Scale to [0, 1] and run the generator network on the batch.
            preds = transform.net(X_content / 255.0)
            preds_pre = vgg.preprocess(preds)

        # Feed the generated images through VGG as well.
        net = vgg.net(vgg_path, preds_pre)

        content_size = _tensor_size(
            content_features[CONTENT_LAYER]) * batch_size
        assert _tensor_size(content_features[CONTENT_LAYER]) == _tensor_size(
            net[CONTENT_LAYER])

        # Content loss.
        content_loss = content_weight * (
            2 * tf.nn.l2_loss(net[CONTENT_LAYER] -
                              content_features[CONTENT_LAYER]) / content_size)

        # Style loss.
        #   grams:      Gram matrices of the generated (intermediate) images.
        #   style_gram: precomputed Gram matrix of the style image.
        style_losses = []
        for style_layer in STYLE_LAYERS:
            layer = net[style_layer]
            # Static shape as plain ints.
            bs, height, width, filters = layer.get_shape().as_list()
            size = height * width * filters
            feats = tf.reshape(layer, (bs, height * width, filters))
            feats_T = tf.transpose(feats, perm=[0, 2, 1])
            grams = tf.matmul(feats_T, feats) / size
            style_gram = style_features[style_layer]
            style_losses.append(2 * tf.nn.l2_loss(grams - style_gram) /
                                style_gram.size)

        style_loss = style_weight * functools.reduce(
            tf.add, style_losses) / batch_size

        # total variation denoising
        tv_y_size = _tensor_size(preds[:, 1:, :, :])
        tv_x_size = _tensor_size(preds[:, :, 1:, :])
        y_tv = tf.nn.l2_loss(preds[:, 1:, :, :] -
                             preds[:, :batch_shape[1] - 1, :, :])
        x_tv = tf.nn.l2_loss(preds[:, :, 1:, :] -
                             preds[:, :, :batch_shape[2] - 1, :])
        tv_loss = tv_weight * 2 * (x_tv / tv_x_size +
                                   y_tv / tv_y_size) / batch_size

        # overall loss
        loss = content_loss + style_loss + tv_loss

        # Gradient-based optimization with Adam.
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)
        # Build the Saver once. Creating a new one at every print iteration
        # added another set of save/restore ops to the graph each time.
        saver = None if slow else tf.train.Saver()
        sess.run(tf.global_variables_initializer())

        import random
        uid = random.randint(1, 100)
        print("UID: %s" % uid)

        for epoch in range(epochs):
            # This assignment was commented out, leaving num_examples
            # undefined in the loop condition below.
            num_examples = len(content_targets)
            iterations = 0
            while iterations * batch_size < num_examples:  # one epoch
                start_time = time.time()
                curr = iterations * batch_size
                step = curr + batch_size
                X_batch = np.zeros(batch_shape, dtype=np.float32)
                # Load the next batch_size images.
                for j, img_p in enumerate(content_targets[curr:step]):
                    X_batch[j] = get_img(img_p,
                                         (256, 256, 3)).astype(np.float32)

                iterations += 1
                assert X_batch.shape[0] == batch_size

                feed_dict = {X_content: X_batch}
                train_step.run(feed_dict=feed_dict)

                end_time = time.time()
                delta_time = end_time - start_time
                if debug:
                    print("UID: %s, batch time: %s" % (uid, delta_time))

                is_print_iter = int(iterations) % print_iterations == 0
                if slow:
                    is_print_iter = epoch % print_iterations == 0
                # Report when the configured print interval is reached, or
                # after the final batch of the final epoch.
                is_last = (epoch == epochs - 1
                           and iterations * batch_size >= num_examples)
                should_print = is_print_iter or is_last
                if should_print:
                    to_get = [style_loss, content_loss, tv_loss, loss, preds]
                    test_feed_dict = {X_content: X_batch}
                    tup = sess.run(to_get, feed_dict=test_feed_dict)
                    _style_loss, _content_loss, _tv_loss, _loss, _preds = tup
                    losses = (_style_loss, _content_loss, _tv_loss, _loss)
                    if slow:
                        _preds = vgg.unprocess(_preds)
                    else:
                        # Save a checkpoint at every report.
                        saver.save(sess, save_path)
                    yield (_preds, losses, iterations, epoch)