def _make_frozen_graph(
    settings: SerializationSettings, graph: tf.Graph, sess: tf.Session
) -> tf.GraphDef:
    """
    Builds the frozen graph definition for the given graph and session.

    The output node names are taken from _process_graph and handed, together
    with the graph's GraphDef and the session, to
    graph_util.convert_variables_to_constants.

    :param settings: serialization settings, forwarded to _process_graph
    :param graph: graph to export; it is made the default graph for the
        duration of the conversion
    :param sess: session passed to convert_variables_to_constants
    :return: the GraphDef returned by convert_variables_to_constants
    """
    with graph.as_default():
        joined_names = ",".join(_process_graph(settings, graph))
        # Round-trip through a comma-separated string: spaces are removed from
        # the names, and an empty node list yields [""] rather than [].
        output_node_names = joined_names.replace(" ", "").split(",")
        return graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_node_names
        )
def _process_graph(settings: SerializationSettings, graph: tf.Graph) -> List[str]:
    """
    Gets the list of the output nodes present in the graph for inference.

    A node is kept when its name is a member of POSSIBLE_OUTPUT_NODES or
    MODEL_CONSTANTS. The selected names are also written to the log at INFO
    level, preceded by a header line naming settings.brain_name.

    :param settings: serialization settings; only brain_name is read here
    :param graph: graph whose GraphDef nodes are inspected
    :return: list of node names, in the order they appear in the GraphDef
    """
    # Build the set of exportable names once instead of once per graph node.
    exportable_names = POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS
    nodes = [
        node.name
        for node in graph.as_graph_def().node
        if node.name in exportable_names
    ]
    # Lazy %-style arguments: the emitted text is unchanged, but formatting is
    # skipped when INFO logging is disabled.
    logger.info("List of nodes to export for brain :%s", settings.brain_name)
    for name in nodes:
        logger.info("\t%s", name)
    return nodes