def export_graph():
    """Freeze the newest CycleGAN checkpoint into a binary GraphDef file.

    The exported graph takes a scalar ``tf.string`` placeholder named
    ``input_bytes``, converts it with ``preprocess_bitstring_to_float_tensor``,
    runs the generator ``cycle_gan.G.sample`` and converts the result back with
    ``postprocess_float_tensor_to_bitstring``. Variables are restored from the
    latest checkpoint in ``FLAGS.checkpoint_dir`` and folded into constants.
    The graph is then written in binary form (``as_text=False``) to
    ``FLAGS.protobuf_dir`` under the name
    ``FLAGS.model_name + "_v" + str(FLAGS.version)``.

    Relies on the names ``tf``, ``FLAGS``, ``model``,
    ``preprocess_bitstring_to_float_tensor`` and
    ``postprocess_float_tensor_to_bitstring`` being defined at module scope
    elsewhere in the file.

    Raises:
        ValueError: if ``FLAGS.checkpoint_dir`` contains no checkpoint.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Instantiate a CycleGAN.
        cycle_gan = model.CycleGAN(
            ngf=64, norm="instance", image_size=FLAGS.image_size)

        # Placeholder for the encoded image bitstring; this is the first
        # ("injection") layer of the exported graph.
        input_bytes = tf.placeholder(tf.string, shape=[], name="input_bytes")
        # Preprocess input (bitstring -> float tensor).
        input_tensor = preprocess_bitstring_to_float_tensor(
            input_bytes, FLAGS.image_size)
        # Style-transferred tensor produced by generator G.
        output_tensor = cycle_gan.G.sample(input_tensor)
        # Postprocess output (float tensor -> bitstring).
        output_bytes = postprocess_float_tensor_to_bitstring(output_tensor)

        saver = tf.train.Saver()

        # tf.train.latest_checkpoint returns None when the directory holds no
        # checkpoint. Fail here with a message that names the directory,
        # instead of passing None on to saver.restore.
        latest_ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if latest_ckpt is None:
            raise ValueError(
                "No checkpoint found in %r; cannot export graph."
                % (FLAGS.checkpoint_dir,))

        # The session stays inside graph.as_default() so the initializer op
        # created below belongs to `graph`, the graph this session runs.
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            # Overwrite the freshly initialized variables with checkpoint values.
            saver.restore(sess, latest_ckpt)

            # Export graph to ProtoBuf: replace variables with constants,
            # keeping only what the output op depends on.
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), [output_bytes.op.name])
            # NOTE(review): example_usage reads FLAGS.model_version; confirm
            # that FLAGS.version is also defined for this script.
            tf.train.write_graph(
                output_graph_def,
                FLAGS.protobuf_dir,
                FLAGS.model_name + "_v" + str(FLAGS.version),
                as_text=False)
def example_usage(_):
    """Export the CycleGAN generator and package it for TensorFlow Serving.

    Steps:
      1. Freeze ``cycle_gan.G.sample`` into a ProtoBuf via
         ``ServerBuilder.export_graph``.
      2. Wrap that ProtoBuf in a SavedModel via
         ``ServerBuilder.build_saved_model``.
      3. Print a hint for starting ``tensorflow_model_server``.

    The single positional argument is ignored (presumably the argv list
    handed over by an app runner -- TODO confirm).
    """
    # Make the CycleGAN-TensorFlow checkout importable. NOTE(review): this
    # relative path is resolved against the current working directory, not
    # against this file's location.
    sys.path.insert(0, "../CycleGAN-TensorFlow")
    import model  # nopep8

    gan = model.CycleGAN(ngf=64, norm="instance", image_size=FLAGS.image_size)
    builder = ServerBuilder()

    print("Exporting model to ProtoBuf...")
    builder.export_graph(
        gan.G.sample,
        FLAGS.model_name,
        FLAGS.model_version,
        FLAGS.checkpoint_dir,
        FLAGS.protobuf_dir,
        FLAGS.image_size,
    )

    print("Wrapping ProtoBuf in SavedModel...")
    builder.build_saved_model(
        FLAGS.model_name,
        FLAGS.model_version,
        FLAGS.protobuf_dir,
        FLAGS.serve_dir,
    )
    print("Exported successfully!")

    # NOTE(review): "$(path)" is shell command-substitution syntax; it is
    # presumably meant as a placeholder for the SavedModel base path.
    serve_hint = (
        """Run the server with: tensorflow_model_server --rest_api_port=8501 """
        "--model_name=saved_model --model_base_path=$(path)"
    )
    print(serve_hint)
# plt.figure(figsize=(5,5))#图片大一点才可以承载像素 # plt.subplot(2,2,1) # plt.imshow(X[0,:,:,8],cmap='gray') # plt.axis('off') # plt.subplot(2,2,2) # plt.imshow(Y[0,:,:,8],cmap='gray') # plt.axis('off') # plt.subplot(2,2,3) # plt.imshow(mX[0,:,:,8],cmap='gray') # plt.axis('off') # plt.subplot(2,2,4) # plt.imshow(mY[0,:,:,8],cmap='gray') # plt.show() test_set = train_dataset.DataPipeLine(test_path,target_size=[240,240,155],patch_size=[128,128,16],crop="crop_centro") test_set = tf.data.Dataset.from_generator(test_set.generator,output_types=(tf.float32,tf.float32,tf.float32,tf.float32,tf.int32),output_shapes=([240,240,155],[240,240,155],[240,240,155],[240,240,155],[3,2]))\ .map(map_func,num_parallel_calls=num_threads)\ .batch(BATCH_SIZE)\ .prefetch(buffer_size = tf.data.experimental.AUTOTUNE) # for i,(X,Y) in enumerate(dataset): # print(i+1,X.shape,Y.dtype) model = model.CycleGAN(train_set=dataset, test_set=test_set, loss_name="WGAN-GP-SN", mixed_precision=True, learning_rate=1e-4, tmp_path=tmp_path, out_path=out_path) model.build(X_shape=[None,128,128,16,1],Y_shape=[None,128,128,16,1]) model.train(epoches=EPOCHES)