import h5py
from time import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from tensorflow.keras.models import load_model, Model
from scipy.io import loadmat
import tensorflow.compat.v1.keras.backend as K
# Import-time side effect: default the Keras image layout to channels-first.
# (keras_to_tensorflow() below overrides this from its own arguments.)
K.set_image_data_format('channels_first')
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
# Session configuration: ask TensorFlow to allocate GPU memory incrementally
# (GPUOptions.allow_growth) rather than reserving it all when the session starts.
config = ConfigProto()
config.gpu_options.allow_growth = True
# Created at import time; an InteractiveSession installs itself as the default
# session on construction.
session = InteractiveSession(config=config)


# Helper for inspecting the contents of a Keras weights (.h5) file.
def print_keras_wegiths(weight_file_path):
    """Print each top-level entry of an HDF5 weights file and its attributes.

    For every top-level item in the file, print its name followed by each
    ``key: value`` pair stored in that item's HDF5 attributes. In Keras
    weight files the top-level items are presumably one group per layer;
    confirm for your files.

    The misspelled function name ("wegiths") is kept so that existing
    callers continue to work.

    Args:
        weight_file_path: Path to an existing HDF5 file.
    """
    # Open explicitly read-only. Before h5py 3.0 the default mode was 'a',
    # which opens the file writable and silently creates an empty file if
    # the path does not exist.
    with h5py.File(weight_file_path, 'r') as f:
        # Iterate over the name of each top-level item and the Group/Dataset object.
        for layer, g in f.items():
            print("  {}".format(layer))
            print("    Attributes:")
            # Print the item's HDF5 attributes. For Keras weight files these
            # presumably include the layer's weight names; verify.
            for key, value in g.attrs.items():
                print("      {}: {}".format(key, value))

# Example #2
def keras_to_tensorflow(args):
    """Convert a Keras model into a frozen TensorFlow graph file.

    Steps:
      1. Load the model described by ``args``.
      2. Optionally rename its output nodes.
      3. Optionally save a meta-graph checkpoint and an ASCII ``.pbtxt``
         graph definition.
      4. Optionally quantize the graph.
      5. Freeze variables into constants and write the binary graph to
         ``args.output_model``.

    Args:
        args: Options object. Attributes read: ``output_model``,
            ``channels_first``, ``input_model``, ``input_model_json``,
            ``input_model_yaml``, ``output_nodes_prefix``,
            ``output_meta_ckpt``, ``save_graph_def``, ``quantize``.

    NOTE(review): ``get_custom_objects`` and ``load_input_model`` are not
    defined or imported in this file. They must be provided at module
    level elsewhere, or this function raises NameError.
    """
    # These names are used below but are not imported in the module header.
    # They are imported locally, like TransformGraph further down.
    import logging
    from pathlib import Path

    import tensorflow.compat.v1 as tf

    # A bare filename (parent == '.') is resolved against the current
    # working directory so that the output folder is explicit.
    output_path = Path(args.output_model)
    if str(output_path.parent) == '.':
        output_path = Path.cwd() / output_path

    output_fld = output_path.parent
    output_model_name = output_path.name
    output_model_stem = output_path.stem
    output_model_pbtxt_name = output_model_stem + '.pbtxt'

    # Create the output directory if it does not exist.
    output_fld.mkdir(parents=True, exist_ok=True)

    K.set_image_data_format(
        'channels_first' if args.channels_first else 'channels_last')

    custom_object_dict = get_custom_objects()

    model = load_input_model(args.input_model,
                             args.input_model_json,
                             args.input_model_yaml,
                             custom_objects=custom_object_dict)

    # TODO(amirabdi): Support networks with multiple inputs
    # Tensor names have the form '<node_name>:<index>'. The graph node name
    # is the part before the colon.
    orig_output_node_names = [
        node.name.split(':')[0] for node in model.outputs
    ]
    if args.output_nodes_prefix:
        # Rename the outputs by adding identity nodes named '<prefix><i>'.
        # The nodes live in the graph, so the returned tensors need not be kept.
        converted_output_node_names = [
            '{}{}'.format(args.output_nodes_prefix, i)
            for i in range(len(orig_output_node_names))
        ]
        for output, node_name in zip(model.outputs,
                                     converted_output_node_names):
            tf.identity(output, name=node_name)
    else:
        converted_output_node_names = orig_output_node_names
    logging.info('Converted output node names are: %s',
                 str(converted_output_node_names))

    sess = K.get_session()
    if args.output_meta_ckpt:
        saver = tf.train.Saver()
        saver.save(sess, str(output_fld / output_model_stem))

    if args.save_graph_def:
        tf.train.write_graph(sess.graph.as_graph_def(),
                             str(output_fld),
                             output_model_pbtxt_name,
                             as_text=True)
        logging.info('Saved the graph definition in ascii format at %s',
                     str(output_fld / output_model_pbtxt_name))

    graph_def = sess.graph.as_graph_def()
    if args.quantize:
        from tensorflow.tools.graph_transforms import TransformGraph
        transforms = ["quantize_weights", "quantize_nodes"]
        graph_def = TransformGraph(graph_def, [],
                                   converted_output_node_names,
                                   transforms)

    # Freeze the graph: variables needed by the output nodes become constants.
    constant_graph = tf.graph_util.convert_variables_to_constants(
        sess, graph_def, converted_output_node_names)

    tf.train.write_graph(constant_graph,
                         str(output_fld),
                         output_model_name,
                         as_text=False)
    logging.info('Saved the freezed graph at %s',
                 str(output_fld / output_model_name))