def _constructGraphDef(self):
  """Builds a tiny graph and returns it converted by the constant-folding API.

  The graph contains a float32 placeholder named `in_tensor` with shape
  `[None, 16, 16, 3]` and an op named `add` computing `in_tensor + in_tensor`.

  Returns:
    The `GraphDef` returned by
    `convert_variables_to_constants_from_session_graph` when asked to keep
    the `add` node as the sole output.
  """
  with ops.Graph().as_default():
    in_tensor = array_ops.placeholder(
        shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
    math_ops.add(in_tensor, in_tensor, name='add')
    # Created inside the graph context so it binds to the graph built above.
    # Using it as a context manager guarantees it is closed once the
    # conversion is done; previously the session was never closed.
    with session.Session() as sess:
      return (
          convert_to_constants.convert_variables_to_constants_from_session_graph(
              sess, sess.graph_def, ['add']))
Example #2
0
def load_meta_graph(
        saved_model_dir: str, saved_model_tags: str,
        saved_model_signature_key: str) -> meta_graph_pb2.MetaGraphDef:
    """Loads a `tf.MetaGraphDef` in TF1 and freezes its variables.

    The SavedModel is loaded into a freshly created graph with its own
    session. The MetaGraphDef's `graph_def` is then replaced by the graph
    obtained from `convert_variables_to_constants_from_session_graph`, keeping
    the output nodes of the requested signature.

    Args:
      saved_model_dir: Directory of the SavedModel to load.
      saved_model_tags: Tags selecting the MetaGraphDef to load. Either a
        collection of tag strings or a single tag string, which is treated as
        a one-element set.
      saved_model_signature_key: Key of the signature whose outputs determine
        the output node names used during freezing.

    Returns:
      The loaded MetaGraphDef with `graph_def` replaced by the frozen graph.

    Raises:
      ValueError: If `saved_model_signature_key` is not a signature of the
        loaded MetaGraphDef.
    """
    # Local import so this block does not depend on a module-level alias.
    # A private graph keeps repeated calls from piling ops into (and clashing
    # on names within) the process-wide default graph.
    from tensorflow.python.framework import ops as framework_ops

    if isinstance(saved_model_tags, str):
        # The loader expects a collection of tags. A bare string would be
        # iterated character by character and never match the stored tag set.
        saved_model_tags = {saved_model_tags}

    with framework_ops.Graph().as_default() as graph, session.Session(
            graph=graph) as sess:
        meta_graph = saved_model_loader.load(
            sess=sess,
            export_dir=saved_model_dir,
            tags=saved_model_tags,
        )
        signatures = meta_graph.signature_def
        # `signature_def` is a protobuf message map. Indexing a missing key
        # silently inserts an empty SignatureDef instead of raising, which
        # would freeze a graph with no outputs. Check membership explicitly.
        if saved_model_signature_key not in signatures:
            raise ValueError(
                f"Signature key '{saved_model_signature_key}' not found in "
                f"the loaded MetaGraphDef; available keys: "
                f"{sorted(signatures)}")
        # Tensor names look like "<node>:<output index>"; keep the node part.
        output_node_names = [
            tensor.name.split(":")[0] for tensor in
            signatures[saved_model_signature_key].outputs.values()
        ]
        graph_def = convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, meta_graph.graph_def, output_node_names)
        meta_graph.graph_def.CopyFrom(graph_def)
    return meta_graph
Example #3
0
def load_meta_graph(
    saved_model_dir: str, saved_model_tags: str,
    saved_model_signature_key: str) -> meta_graph_pb2.MetaGraphDef:
  """Loads a `tf.MetaGraphDef` in TF1 and freezes its variables.

  The SavedModel is loaded into a freshly created graph with its own session.
  The MetaGraphDef's `graph_def` is then replaced by the graph obtained from
  `convert_variables_to_constants_from_session_graph`, keeping the output
  nodes of the requested signature.

  Args:
    saved_model_dir: Directory of the SavedModel to load.
    saved_model_tags: Tags selecting the MetaGraphDef to load. Either a
      collection of tag strings or a single tag string, which is treated as a
      one-element set.
    saved_model_signature_key: Key of the signature whose outputs determine
      the output node names used during freezing.

  Returns:
    The loaded MetaGraphDef with `graph_def` replaced by the frozen graph.

  Raises:
    ValueError: If `saved_model_signature_key` is not a signature of the
      loaded MetaGraphDef.
  """
  if isinstance(saved_model_tags, str):
    # The loader expects a collection of tags. A bare string would be iterated
    # character by character and never match the stored tag set.
    saved_model_tags = {saved_model_tags}

  # A private graph keeps the loaded ops out of the process-wide default graph.
  with framework_ops.Graph().as_default() as graph, session.Session(
      graph=graph) as sess:
    meta_graph = saved_model_loader.load(
        sess=sess,
        export_dir=saved_model_dir,
        tags=saved_model_tags,
    )
    signatures = meta_graph.signature_def
    # `signature_def` is a protobuf message map. Indexing a missing key
    # silently inserts an empty SignatureDef instead of raising, which would
    # freeze a graph with no outputs. Check membership explicitly.
    if saved_model_signature_key not in signatures:
      raise ValueError(
          f"Signature key '{saved_model_signature_key}' not found in the "
          f"loaded MetaGraphDef; available keys: {sorted(signatures)}")
    output_node_names = [
        _remove_graph_sequence_number(tensor.name) for tensor in
        signatures[saved_model_signature_key].outputs.values()
    ]
    graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, meta_graph.graph_def, output_node_names))
    meta_graph.graph_def.CopyFrom(graph_def)
  return meta_graph