Ejemplo n.º 1
0
def _minimal_operations(sess):
    """Return the operations of a frozen copy of ``sess``'s graph.

    The graph is converted with ``executor.convert_variables_to_constants``
    (presumably replacing variables by constants holding their current
    values -- see that helper) and re-imported into a brand-new ``tf.Graph``.

    Args:
        sess: The TensorFlow session whose graph should be frozen.

    Returns:
        The list returned by ``get_operations()`` on the re-imported graph.
    """
    frozen_def = executor.convert_variables_to_constants(sess)

    fresh_graph = tf.Graph()
    with fresh_graph.as_default():
        # name="" keeps node names unchanged (no default "import/" prefix).
        tf.import_graph_def(frozen_def, name="")

    return fresh_graph.get_operations()
Ejemplo n.º 2
0
def _scope_indices(graph_def):
    """Map each top-level name scope of ``graph_def`` to an integer index.

    Index 0 is reserved for ``"_TFProfRoot"``. Every other scope (the part of
    a node name before the first "/") is numbered from 1 upward in order of
    first appearance in ``graph_def.node``.
    """
    scopes = {"_TFProfRoot": 0}
    for node in graph_def.node:
        # len(scopes) is evaluated before insertion, so it is the next index.
        scopes.setdefault(node.name.split("/")[0], len(scopes))
    return scopes


def _profile(config, restore_path, bit, unquant_layers):
    """Profile a model's parameters and FLOPs and render the report.

    Builds the network described by ``config`` in a fresh graph, initializes
    it (optionally restoring a checkpoint), freezes it with
    ``executor.convert_variables_to_constants``, then feeds the results of
    ``_profile_params`` and ``_profile_flops`` into ``_render``.

    Args:
        config: Experiment config. Reads ``NETWORK_CLASS``, ``NETWORK``,
            ``CLASSES``, ``IS_DEBUG`` and ``IMAGE_SIZE``.
        restore_path: Checkpoint path passed to ``tf.train.Saver.restore``.
            If falsy, freshly initialized variables are profiled.
        bit: Forwarded to ``_profile_params`` and ``_render`` (presumably
            the quantization bit width -- confirm against those helpers).
        unquant_layers: Forwarded to ``_profile_params`` (presumably layers
            excluded from quantization -- confirm).

    Side effects:
        Creates the directory
        ``<EXPERIMENT_DIR>/export[/<basename(restore_path)>]/<IMAGE_SIZE[0]>x<IMAGE_SIZE[1]>``
        and prints a message when restoring.
    """
    output_root_dir = os.path.join(environment.EXPERIMENT_DIR, "export")
    if restore_path:
        output_root_dir = os.path.join(output_root_dir,
                                       os.path.basename(restore_path))
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_root_dir, exist_ok=True)

    graph = tf.Graph()
    ModelClass = config.NETWORK_CLASS
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    with graph.as_default():
        model = ModelClass(
            classes=config.CLASSES,
            is_debug=config.IS_DEBUG,
            **network_kwargs,
        )

        is_training = tf.constant(False, name="is_training")

        images_placeholder, _ = model.placeholderes()
        output = model.inference(images_placeholder, is_training)
        model.summary(output)
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver(max_to_keep=50)

    session_config = tf.ConfigProto()
    sess = tf.Session(graph=graph, config=session_config)
    # The session is only needed to initialize/restore variables and to
    # freeze the graph; close it afterwards instead of leaking it.
    try:
        sess.run(init_op)

        if restore_path:
            print("Restore from {}".format(restore_path))
            saver.restore(sess, restore_path)

        # NOTE(review): this directory is created but not used below; it is
        # kept so the on-disk layout stays unchanged.
        main_output_dir = os.path.join(
            output_root_dir, "{}x{}".format(config.IMAGE_SIZE[0],
                                            config.IMAGE_SIZE[1]))
        os.makedirs(main_output_dir, exist_ok=True)

        inference_graph_def = executor.convert_variables_to_constants(sess)
    finally:
        sess.close()

    inference_graph = tf.Graph()
    with inference_graph.as_default():
        # NOTE(review): no name="" is passed here (unlike
        # _minimal_operations), so imported ops get TF's default "import/"
        # prefix; presumably _profile_flops accounts for that -- confirm
        # before changing.
        tf.import_graph_def(inference_graph_def)

    scopes = _scope_indices(inference_graph_def)

    # [level, node name, total param, 32 bits size, quantized size, flops]
    res = []
    res = _profile_params(graph, res, bit, unquant_layers)
    res = _profile_flops(inference_graph, res, scopes)

    name = ModelClass.__name__
    image_size = config.IMAGE_SIZE
    num_classes = len(config.CLASSES)
    _render(name, image_size, num_classes, bit, res)