예제 #1
0
파일: model.py 프로젝트: vuiseng9/lpot
 def graph_def(self):
     """Freeze this model's session graph into a constant GraphDef.

     The graph is serialized from ``self.sess`` and passed through lpot's
     ``_parse_ckpt_bn_input``. Judging by the name, that helper presumably
     re-links BatchNorm inputs for checkpoint models (TODO confirm). Every
     variable needed to compute ``self.output_node_names`` is then folded
     into a constant.

     Returns:
         The frozen ``GraphDef`` produced by
         ``graph_util.convert_variables_to_constants``.
     """
     from lpot.adaptor.tf_utils.util import _parse_ckpt_bn_input
     from tensorflow.python.framework import graph_util

     linked_def = _parse_ckpt_bn_input(self.sess.graph.as_graph_def())
     freeze_kwargs = dict(
         sess=self.sess,
         input_graph_def=linked_def,
         output_node_names=self.output_node_names,
     )
     return graph_util.convert_variables_to_constants(**freeze_kwargs)
예제 #2
0
def _freeze_checkpoint(sess, ckpt_path):
    """Restore a style-transfer checkpoint into *sess* and freeze it.

    Two placeholders, ``style_input`` and ``content_input``, are added to the
    graph. Every node that consumed the checkpoint's own input tensors
    (``style_image_processing/ResizeBilinear_2`` and
    ``batch_processing/batch``) is rewired to read from those placeholders
    instead.

    Args:
        sess: Session whose graph receives the imported meta graph.
        ckpt_path: Checkpoint prefix; ``ckpt_path + '.meta'`` must exist.

    Returns:
        A GraphDef with variables converted to constants, ending at
        ``transformer/expand/conv3/conv/Sigmoid``.
    """
    # Only the node names matter; the Python handles are never used.
    tf.placeholder(tf.float32, shape=[None, 256, 256, 3], name='style_input')
    tf.placeholder(tf.float32, shape=[None, 256, 256, 3], name='content_input')

    saver = tf.train.import_meta_graph(ckpt_path + '.meta', clear_devices=True)
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, ckpt_path)
    graph_def = sess.graph.as_graph_def()

    # Map each original input-pipeline tensor to the placeholder that replaces it.
    rewire = {
        'batch_processing/batch': 'content_input',
        'style_image_processing/ResizeBilinear_2': 'style_input',
    }
    for node in graph_def.node:
        for idx, input_name in enumerate(node.input):
            if input_name in rewire:
                node.input[idx] = rewire[input_name]

    if FLAGS.tune:
        # Keep the returned graph_def, as the model's graph_def property does.
        graph_def = _parse_ckpt_bn_input(graph_def)
    output_name = 'transformer/expand/conv3/conv/Sigmoid'
    return tf.graph_util.convert_variables_to_constants(
        sess, graph_def, [output_name])


def _make_dataloader():
    """Build the evaluation dataloader.

    When ``FLAGS.tune`` is set, this uses the ``style_transfer`` dataset built
    from ``FLAGS.content_images_paths`` and ``FLAGS.style_images_paths``.
    Otherwise it uses a ``dummy`` dataset with shapes
    ``[(200, 256, 256, 3), (200, 256, 256, 3)]``. Batches have size
    ``FLAGS.batch_size``.
    """
    if FLAGS.tune:
        dataset = DATASETS('tensorflow')['style_transfer'](
            FLAGS.content_images_paths.strip(),
            FLAGS.style_images_paths.strip(),
            crop_ratio=0.2,
            resize_shape=(256, 256))
    else:
        dataset = DATASETS('tensorflow')['dummy'](
            shape=[(200, 256, 256, 3), (200, 256, 256, 3)], label=True)
    return DATALOADERS['tensorflow'](dataset=dataset,
                                     batch_size=FLAGS.batch_size)


def main(args=None):
    """Freeze (and optionally quantize) a style-transfer model, then run it.

    ``FLAGS.input_model`` is either a checkpoint prefix ending in ``.ckpt`` or
    a frozen ``.pb`` GraphDef. When ``FLAGS.tune`` is set, the frozen graph is
    quantized with ``FLAGS.config`` and saved to ``FLAGS.output_model``. The
    resulting graph, quantized or not, is then passed to ``style_transfer``.

    Args:
        args: Unused. Presumably accepted so that ``main`` fits a
            ``tf.app.run``-style entry point.
    """
    import sys

    tf.logging.set_verbosity(tf.logging.INFO)
    if not tf.gfile.Exists(FLAGS.output_dir):
        # MakeDirs also creates missing parent directories; MkDir does not.
        tf.gfile.MakeDirs(FLAGS.output_dir)

    model_format = FLAGS.input_model.rsplit('.', 1)[-1]
    if model_format not in ('ckpt', 'pb'):
        # The site-provided exit() is meant for the interactive interpreter.
        print("not supported model format", file=sys.stderr)
        sys.exit(-1)

    with tf.Session() as sess:
        if model_format == 'ckpt':
            frozen_graph = _freeze_checkpoint(sess, FLAGS.input_model)
        else:
            # A frozen .pb is used as-is.
            with open(FLAGS.input_model, 'rb') as f:
                frozen_graph = tf.GraphDef()
                frozen_graph.ParseFromString(f.read())

        if FLAGS.tune:
            with tf.Graph().as_default() as graph:
                tf.import_graph_def(frozen_graph, name='')
                quantizer = Quantization(FLAGS.config)
                quantizer.model = graph
                quantized_model = quantizer()
                quantized_model.save(FLAGS.output_model)
                frozen_graph = quantized_model.graph_def

    # Validate the (possibly quantized) model in a fresh graph and session.
    with tf.Graph().as_default(), tf.Session() as sess:
        dataloader = _make_dataloader()
        tf.import_graph_def(frozen_graph, name='')
        style_transfer(sess, dataloader, FLAGS.precision)