def graph_def(self):
    """Return a frozen ``GraphDef`` for the wrapped session.

    Pulls the current graph out of ``self.sess``, runs the LPOT
    checkpoint batch-norm input fix-up on it, and folds all variables
    into constants so the result is self-contained.

    Returns:
        tf.GraphDef: frozen graph with ``self.output_node_names`` as
        the retained outputs.
    """
    from lpot.adaptor.tf_utils.util import _parse_ckpt_bn_input
    from tensorflow.python.framework import graph_util

    raw_graph_def = self.sess.graph.as_graph_def()
    # Repair BatchNorm node inputs that come out of a ckpt malformed.
    fixed_graph_def = _parse_ckpt_bn_input(raw_graph_def)
    frozen = graph_util.convert_variables_to_constants(
        sess=self.sess,
        input_graph_def=fixed_graph_def,
        output_node_names=self.output_node_names)
    return frozen
def main(args=None):
    """Freeze, optionally quantize, and validate a style-transfer model.

    Reads ``FLAGS.input_model`` (either a TF checkpoint prefix ``*.ckpt``
    or an already-frozen ``*.pb``), produces a frozen GraphDef, runs LPOT
    quantization when ``FLAGS.tune`` is set, and finally evaluates the
    resulting graph with ``style_transfer``.

    Args:
        args: unused; present for the ``tf.app.run`` entry-point contract.

    Raises:
        SystemExit: if the input model extension is neither ckpt nor pb.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MkDir(FLAGS.output_dir)

    with tf.Session() as sess:
        if FLAGS.input_model.rsplit('.', 1)[-1] == 'ckpt':
            # These placeholders are never referenced by name here, but
            # creating them registers graph nodes called 'style_input' /
            # 'content_input' that the rewiring loop below points to.
            style_img_ph = tf.placeholder(
                tf.float32, shape=[None, 256, 256, 3], name='style_input')
            content_img_ph = tf.placeholder(
                tf.float32, shape=[None, 256, 256, 3], name='content_input')

            # Import meta_graph and restore the checkpoint weights.
            meta_data_path = FLAGS.input_model + '.meta'
            saver = tf.train.import_meta_graph(meta_data_path, clear_devices=True)
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, FLAGS.input_model)
            graph_def = sess.graph.as_graph_def()

            # Rewire the training-time input pipeline nodes to the two
            # placeholders so the frozen graph is fed directly.
            replace_style = 'style_image_processing/ResizeBilinear_2'
            replace_content = 'batch_processing/batch'
            for node in graph_def.node:
                for idx, input_name in enumerate(node.input):
                    # replace style input and content input nodes to placeholder
                    if replace_content == input_name:
                        node.input[idx] = 'content_input'
                    if replace_style == input_name:
                        node.input[idx] = 'style_input'

            if FLAGS.tune:
                # BUG FIX: the return value was previously discarded; keep
                # the fixed GraphDef (consistent with BaseModel.graph_def,
                # which re-assigns the helper's result).
                graph_def = _parse_ckpt_bn_input(graph_def)

            output_name = 'transformer/expand/conv3/conv/Sigmoid'
            frozen_graph = tf.graph_util.convert_variables_to_constants(
                sess, graph_def, [output_name])
        # use frozen pb instead
        elif FLAGS.input_model.rsplit('.', 1)[-1] == 'pb':
            with open(FLAGS.input_model, 'rb') as f:
                frozen_graph = tf.GraphDef()
                frozen_graph.ParseFromString(f.read())
        else:
            print("not supported model format")
            # raise SystemExit instead of the site-provided exit() helper,
            # which is not guaranteed to exist in every interpreter setup.
            raise SystemExit(-1)

    if FLAGS.tune:
        # Quantize the frozen graph with LPOT and save the result.
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(frozen_graph, name='')
            quantizer = Quantization(FLAGS.config)
            quantizer.model = graph
            quantized_model = quantizer()
            quantized_model.save(FLAGS.output_model)
            frozen_graph = quantized_model.graph_def

    # validate the quantized model here
    with tf.Graph().as_default(), tf.Session() as sess:
        if FLAGS.tune:
            # create dataloader using default style_transfer dataset
            # generate stylized images
            dataset = DATASETS('tensorflow')['style_transfer'](
                FLAGS.content_images_paths.strip(),
                FLAGS.style_images_paths.strip(),
                crop_ratio=0.2,
                resize_shape=(256, 256))
        else:
            dataset = DATASETS('tensorflow')['dummy'](
                shape=[(200, 256, 256, 3), (200, 256, 256, 3)], label=True)
        dataloader = DATALOADERS['tensorflow'](
            dataset=dataset, batch_size=FLAGS.batch_size)
        tf.import_graph_def(frozen_graph, name='')
        style_transfer(sess, dataloader, FLAGS.precision)