def get_metagraph():
  """Constructs and returns a MetaGraphDef from the input file.

  Reads `FLAGS.metagraphdef` when it is set, otherwise `FLAGS.graphdef`. A
  file whose name ends in ".pbtxt" is parsed as a text proto; any other file
  is parsed as a serialized binary proto.

  The comma-separated op names in `FLAGS.fetch` are stored in the "train_op"
  collection of the returned MetaGraphDef.

  Returns:
    A `meta_graph_pb2.MetaGraphDef`.

  Raises:
    ValueError: If the graph is loaded from `FLAGS.graphdef` and `FLAGS.fetch`
      is not set.
  """
  if FLAGS.metagraphdef:
    is_text = FLAGS.metagraphdef.endswith(".pbtxt")
    # Serialized protos are arbitrary bytes and must be read in binary mode;
    # a text-mode read would try to decode them as a string.
    with gfile.GFile(FLAGS.metagraphdef, "r" if is_text else "rb") as meta_file:
      metagraph = meta_graph_pb2.MetaGraphDef()
      if is_text:
        text_format.Merge(meta_file.read(), metagraph)
      else:
        metagraph.ParseFromString(meta_file.read())
    if FLAGS.fetch is not None:
      fetch_collection = meta_graph_pb2.CollectionDef()
      fetch_collection.node_list.value.extend(FLAGS.fetch.split(","))
      metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
  else:
    if FLAGS.fetch is None:
      # Fail with an explicit message instead of an AttributeError raised by
      # calling `.split` on None further down.
      raise ValueError(
          "The fetch flag must be set when loading a graph from a GraphDef.")
    is_text = FLAGS.graphdef.endswith(".pbtxt")
    with gfile.GFile(FLAGS.graphdef, "r" if is_text else "rb") as graph_file:
      graph_def = graph_pb2.GraphDef()
      if is_text:
        text_format.Merge(graph_file.read(), graph_def)
      else:
        graph_def.ParseFromString(graph_file.read())
      importer.import_graph_def(graph_def, name="")
      graph = ops.get_default_graph()
      for fetch in FLAGS.fetch.split(","):
        fetch_op = graph.get_operation_by_name(fetch)
        graph.add_to_collection("train_op", fetch_op)
      metagraph = saver.export_meta_graph(
          graph_def=graph.as_graph_def(), graph=graph)
  return metagraph
Beispiel #2
0
  def testDefaultAttrsRemoved(self):
    """Checks how attrs that exist only in `producer_op_list` are imported.

    An attr holding its default value is expected to be absent from the
    imported op, while an attr holding a non-default value is expected to be
    kept.
    """
    op_list_text = """
      op {
        name: 'OpWithFutureDefaultAttr'
        attr { name: 'default_int' type: 'int' default_value { i: 456 } }
      }
    """
    producer_op_list = op_def_pb2.OpList()
    text_format.Merge(op_list_text, producer_op_list)
    import_kwargs = dict(
        return_elements=["A"], producer_op_list=producer_op_list)

    # Default value (456): reading the attr from the imported op must fail.
    with ops.Graph().as_default():
      default_valued = self._MakeGraphDef("""
          node { name: 'A' op: 'OpWithFutureDefaultAttr'
                 attr { key: 'default_int' value { i: 456 } } }
          """)
      imported = importer.import_graph_def(default_valued, **import_kwargs)
      with self.assertRaisesRegexp(ValueError, "No attr named 'default_int'"):
        imported[0].get_attr("default_int")

    # Non-default value (987): the attr must survive the import.
    with ops.Graph().as_default():
      custom_valued = self._MakeGraphDef("""
          node { name: 'A' op: 'OpWithFutureDefaultAttr'
                 attr { key: 'default_int' value { i: 987 } } }
          """)
      imported = importer.import_graph_def(custom_valued, **import_kwargs)
      self.assertEqual(987, imported[0].get_attr("default_int"))
Beispiel #3
0
 def testInvalidInputForInputMap(self):
   """Covers rejected and accepted forms of the `input_map` argument."""
   # A list instead of a dict is expected to raise TypeError.
   not_a_dict_message = ("input_map must be a dictionary mapping strings to "
                         "Tensor objects.")
   with ops.Graph().as_default():
     with self.assertRaises(TypeError) as type_error:
       importer.import_graph_def(
           self._MakeGraphDef(""), input_map=[constant_op.constant(5.0)])
     self.assertEqual(not_a_dict_message, str(type_error.exception))

   placeholder_graph = self._MakeGraphDef("""
        node { name: 'a' op: 'Placeholder'
               attr { key: 'dtype' value { type: DT_FLOAT } }}
        node { name: 'id' op: 'Identity' input: 'a:0'
               attr { key: 'T' value { type: DT_FLOAT } }}""")

   # A Variable as the mapped value together with name="" is expected to
   # raise ValueError.
   non_tensor_prefix = ("tf.import_graph_def() requires a non-empty `name` "
                        "if `input_map` contains non-Tensor values.")
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as value_error:
       variable_map = {"a:0": variables.Variable(5.0)}
       importer.import_graph_def(
           placeholder_graph, input_map=variable_map, name="")
     self.assertStartsWith(str(value_error.exception), non_tensor_prefix)

   # A Tensor as the mapped value is accepted and flows through the Identity.
   with ops.Graph().as_default():
     (identity_out,) = importer.import_graph_def(
         placeholder_graph,
         input_map={"a:0": constant_op.constant(5.0)},
         name="",
         return_elements=["id:0"])
     with self.test_session():
       self.assertEqual(5.0, identity_out.eval())
Beispiel #4
0
  def testWithDeviceFunctionDependingOnInputs(self):
    """Checks that a device function sees each imported op's inputs."""
    if ops._USE_C_API:
      return  # TODO(skyewm): make this work with C API

    # Build a graph with two binary ops (add, subtract) and one unary op.
    with ops.Graph().as_default() as source_graph:
      with ops.device("/job:ps"):
        lhs = constant_op.constant(1.0)
        rhs = constant_op.constant(1.0)
      _ = lhs + rhs
      _ = lhs - rhs
      _ = array_ops.identity(lhs)
    serialized = source_graph.as_graph_def()

    # The device function below records every op it is called with that has
    # exactly two inputs.
    binary_ops_seen = []

    def InputCounter(op):
      if len(op.inputs) == 2:
        binary_ops_seen.append(op)
      return ""

    with ops.Graph().as_default():
      with ops.device(InputCounter):
        importer.import_graph_def(serialized)

    # Only the add and the subtract qualify; identity has a single input.
    self.assertEqual(2, len(binary_ops_seen))
Beispiel #5
0
  def testImportGraphWithFunctionTwice(self):
    """Imports a graph containing a Defun twice and compares both outputs."""
    source_graph = ops.Graph()
    with source_graph.as_default():

      @function.Defun()
      def Add2(x, y):
        return math_ops.add(x, y)

      x_in = array_ops.placeholder(dtype=dtypes.float32, name="x")
      y_in = array_ops.placeholder(dtype=dtypes.float32, name="y")
      _ = Add2(x_in, y_in, name="z")  # pylint: disable=unexpected-keyword-arg

    serialized = source_graph.as_graph_def()

    # Both imports are fed from the same pair of random tensors.
    input_map = {
        "x:0": random_ops.random_uniform(dtype=dtypes.float32, shape=()),
        "y:0": random_ops.random_uniform(dtype=dtypes.float32, shape=()),
    }

    imported_outputs = []
    for scope_name in ("first", "second"):
      with ops.name_scope(scope_name):
        imported_outputs.append(
            importer.import_graph_def(
                serialized, return_elements=["z:0"], input_map=input_map)[0])

    with self.test_session() as sess:
      first_val, second_val = sess.run(tuple(imported_outputs))
      self.assertAllEqual(first_val, second_val)
def run_graph_def(graph_def, input_map, outputs):
  """Imports `graph_def` into a new graph and evaluates `outputs` in it.

  Args:
    graph_def: GraphDef to import (with an empty name prefix).
    input_map: Values handed to `Session.run` as its `feed_dict`. It is not
      used as the importer's `input_map`, which is always empty here.
    outputs: Fetches passed to `Session.run`.

  Returns:
    Whatever `Session.run` returns for `outputs`.
  """
  fresh_graph = ops_lib.Graph()
  with fresh_graph.as_default():
    importer.import_graph_def(graph_def, input_map={}, name="")
  with session.Session(graph=fresh_graph) as sess:
    return sess.run(outputs, feed_dict=input_map)
Beispiel #7
0
  def testNamePrefixColocationAttrsMultipleImport(self):
    """Imports the same colocated graph twice with an empty name prefix.

    The second copy is expected to get uniquified node names ('A_1', 'B_1')
    and a colocation attr rewritten to point at 'A_1'.
    """
    if ops._USE_C_API:
      return  # TODO(skyewm): set uniquify_names

    colocated_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          node { name: 'B' op: 'None'  attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")
    expected_graph_text = """
          node { name: 'A' op: 'None' }
          node { name: 'B' op: 'None'  attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }
          node { name: 'A_1' op: 'None' }
          node { name: 'B_1' op: 'None'  attr {
            key: '_class'
            value { list { s: 'loc:@A_1' } }
          } }"""

    with ops.Graph().as_default():
      imported_b_ops = []
      for _ in range(2):
        (b_op,) = importer.import_graph_def(
            colocated_graph_def, return_elements=["B"], name="")
        imported_b_ops.append(b_op)
      # The graph is inspected through the op returned by the first import.
      self.assertProtoEqualsVersion(expected_graph_text,
                                    imported_b_ops[0].graph.as_graph_def())
def main(_):
  """Loads a graph from the flags, optionally optimizes it, prints its cost.

  The graph comes from `FLAGS.metagraphdef` when set, otherwise from
  `FLAGS.graphdef` (in which case the op named by `FLAGS.fetch` is added to
  the "train_op" collection). A file whose name ends in ".pbtxt" is parsed as
  a text proto; any other file is parsed as a serialized binary proto.

  When `FLAGS.rewriter_config` is set, it is parsed as a text-format
  RewriterConfig and the graph is replaced by the output of
  `tf_optimizer.OptimizeGraph` before the report is generated.

  Args:
    _: Unused positional argument.
  """
  if FLAGS.metagraphdef:
    is_text = FLAGS.metagraphdef.endswith(".pbtxt")
    # Serialized protos are arbitrary bytes and must be read in binary mode;
    # a text-mode read would try to decode them as a string.
    with gfile.GFile(FLAGS.metagraphdef, "r" if is_text else "rb") as meta_file:
      metagraph = meta_graph_pb2.MetaGraphDef()
      if is_text:
        text_format.Merge(meta_file.read(), metagraph)
      else:
        metagraph.ParseFromString(meta_file.read())
  else:
    is_text = FLAGS.graphdef.endswith(".pbtxt")
    with gfile.GFile(FLAGS.graphdef, "r" if is_text else "rb") as graph_file:
      graph_def = graph_pb2.GraphDef()
      if is_text:
        text_format.Merge(graph_file.read(), graph_def)
      else:
        graph_def.ParseFromString(graph_file.read())
      importer.import_graph_def(graph_def, name="")
      graph = ops.get_default_graph()
      fetch = graph.get_operation_by_name(FLAGS.fetch)
      graph.add_to_collection("train_op", fetch)
      metagraph = saver.export_meta_graph(
          graph_def=graph.as_graph_def(), graph=graph)

  if FLAGS.rewriter_config is not None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)
Beispiel #9
0
 def testMissingControlInputInGraphDef(self):
   """Expects a ValueError naming the unknown control input '^A'."""
   expected_error = r"Node 'B': Unknown input node '\^A'"
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       dangling_control_graph = self._MakeGraphDef("""
           node { name: 'B' op: 'None' input: '^A' }
           """)
       importer.import_graph_def(dangling_control_graph)
Beispiel #10
0
 def testDuplicateOperationNames(self):
   """Expects a ValueError when two nodes in the GraphDef are named 'A'."""
   expected_error = "Node 'A' is not unique"
   with self.assertRaisesRegexp(ValueError, expected_error):
     duplicated_names_graph = self._MakeGraphDef("""
         node { name: 'A' op: 'IntOutput' }
         node { name: 'B' op: 'IntOutput' }
         node { name: 'A' op: 'IntOutput' }
         """)
     importer.import_graph_def(duplicated_names_graph)
Beispiel #11
0
 def testMissingInputOpInGraphDef(self):
   """Expects a ValueError when node 'B' consumes the undefined 'A:0'."""
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as raised:
       missing_input_graph = self._MakeGraphDef("""
           node { name: 'B' op: 'If' input: 'A:0' }
           """)
       importer.import_graph_def(missing_input_graph)
     message = str(raised.exception)
     self.assertTrue("Input tensor 'A:0' not found" in message)
Beispiel #12
0
 def testMissingControlInputInGraphDef(self):
   """Expects a ValueError when node 'B' has the undefined control input '^A'."""
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as raised:
       dangling_control_graph = self._MakeGraphDef("""
           node { name: 'B' op: 'None' input: '^A' }
           """)
       importer.import_graph_def(dangling_control_graph)
     message = str(raised.exception)
     self.assertTrue("Control input '^A' not found" in message)
Beispiel #13
0
 def testMissingInputOpInGraphDef(self):
   """Expects a ValueError naming the unknown input node 'A:0'."""
   expected_error = "Node 'B': Unknown input node 'A:0'"
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       missing_input_graph = self._MakeGraphDef("""
           node { name: 'B' op: 'FloatInput' input: 'A:0' }
           """)
       importer.import_graph_def(missing_input_graph)
Beispiel #14
0
 def testInvalidTensorNameInGraphDef(self):
   """Expects a ValueError for the malformed input name 'A:B:0'."""
   expected_error = "Node 'B': Unknown input node 'A:B:0'"
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       bad_tensor_name_graph = self._MakeGraphDef("""
           node { name: 'B' op: 'None' input: 'A:B:0' }
           """)
       importer.import_graph_def(bad_tensor_name_graph)
Beispiel #15
0
 def testVersionLow(self):
   """Expects import to reject a GraphDef whose producer version is -1."""
   expected_error = (
       r"GraphDef producer version -1 below min producer %d supported "
       r"by TensorFlow \S+\.  Please regenerate your graph.$" %
       versions.GRAPH_DEF_VERSION_MIN_PRODUCER)
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(Exception, expected_error):
       too_old_graph = self._MakeGraphDef("", producer=-1)
       importer.import_graph_def(too_old_graph)
 def SetProducerVersion(self, graph, producer_version):
   """Sets the GraphDef producer version of `graph` to `producer_version`.

   Args:
     graph: Graph whose `graph_def_versions.producer` should be changed.
     producer_version: Producer version to apply.

   Raises:
     AssertionError: If the graph's producer version does not equal
       `producer_version` after the import.
   """
   # The C API doesn't expose altering GraphDefVersions. We can indirectly set
   # it via import_graph_def though.
   graph_def = graph_pb2.GraphDef()
   graph_def.versions.producer = producer_version
   with graph.as_default():
     importer.import_graph_def(graph_def)
   actual_version = graph.graph_def_versions.producer
   # Compare explicitly: `assert a, b` would only test the truthiness of `a`
   # and use `b` as the failure message.
   assert actual_version == producer_version, (
       "producer version is %r, expected %r" %
       (actual_version, producer_version))
Beispiel #17
0
 def testVersionHigh(self):
   """Expects import to reject a GraphDef needing a too-new consumer."""
   required_consumer = 1 << 30
   expected_error = (
       r"GraphDef min consumer version %d above current version %d "
       r"for TensorFlow \S+\.  Please upgrade TensorFlow\.$" %
       (required_consumer, versions.GRAPH_DEF_VERSION))
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       too_new_graph = self._MakeGraphDef("", min_consumer=1 << 30)
       importer.import_graph_def(too_new_graph)
Beispiel #18
0
  def _TestTrtGraphConverter(self,
                             input_saved_model_dir=None,
                             output_saved_model_dir=None,
                             need_calibration=False,
                             is_dynamic_op=False):
    """General method to test trt_convert.TrtGraphConverter().

    Converts a graph through `self._ConvertGraph`, then checks the node
    name -> op mapping of the returned GraphDef and, when
    `output_saved_model_dir` is given, of the GraphDef read back from that
    SavedModel. With `need_calibration`, the TRTEngineOp nodes must carry
    calibration data and the graph is run on ten inputs.
    """
    converted_graph_def = self._ConvertGraph(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        need_calibration=need_calibration,
        is_dynamic_op=is_dynamic_op,
        use_function_backup=need_calibration)
    candidates = [converted_graph_def]

    if output_saved_model_dir:
      if context.executing_eagerly():
        loaded = load.load(output_saved_model_dir)
        serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        exported_graph_def = loaded.signatures[serving_key].graph.as_graph_def()
      else:
        meta_graph_def = saved_model_utils.get_meta_graph_def(
            output_saved_model_dir, tag_constants.SERVING)
        exported_graph_def = meta_graph_def.graph_def
      self.assertTrue(isinstance(exported_graph_def, graph_pb2.GraphDef))
      candidates.append(exported_graph_def)

    for candidate in candidates:
      op_by_name = {}
      for top_level_node in candidate.node:
        op_by_name[top_level_node.name] = top_level_node.op

      if context.executing_eagerly():
        # In V2 the actual graph could be inside a function.
        for func in candidate.library.function:
          for fn_node in func.node_def:
            op_by_name[fn_node.name] = fn_node.op
        self.assertIn("TRTEngineOp_0", op_by_name)
        self.assertEqual("TRTEngineOp", op_by_name["TRTEngineOp_0"])
      else:
        expected_ops = {
            "input": "Placeholder",
            "TRTEngineOp_0": "TRTEngineOp",
            "output": "Identity"
        }
        self.assertEqual(expected_ops, op_by_name)

      if not need_calibration:
        continue

      engine_nodes = [n for n in candidate.node if n.op == "TRTEngineOp"]
      self.assertNotEmpty(engine_nodes)
      for engine in engine_nodes:
        self.assertTrue(len(engine.attr["calibration_data"].s))
      # Run the calibrated graph.
      # TODO(laigd): consider having some input where the answer is different.
      with ops.Graph().as_default():
        importer.import_graph_def(candidate, name="")
        with self.session(config=self._GetConfigProto()) as sess:
          for test_data in range(10):
            expected = (test_data + 1.0)**2
            actual = sess.run(
                "output:0", feed_dict={"input:0": [[[test_data]]]})
            self.assertEqual(expected, actual)
Beispiel #19
0
 def testMissingInputMap(self):
   """Expects a ValueError when `input_map` names the absent tensor 'B:0'."""
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as raised:
       single_node_graph = self._MakeGraphDef("""
           node { name: 'A' op: 'None' }
           """)
       unknown_mapping = {"B:0": constant_op.constant(5.0)}
       importer.import_graph_def(single_node_graph, input_map=unknown_mapping)
     message = str(raised.exception)
     self.assertTrue("not found in graph_def: [B:0]" in message)
Beispiel #20
0
 def testInvalidTensorNameInGraphDef(self):
   """Expects a ValueError for the malformed input name 'A:B:0'."""
   expected_message = "Cannot convert 'A:B:0' to a tensor name."
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as raised:
       bad_tensor_name_graph = self._MakeGraphDef("""
           node { name: 'B' op: 'None' input: 'A:B:0' }
           """)
       importer.import_graph_def(bad_tensor_name_graph)
     self.assertEqual(expected_message, str(raised.exception))
Beispiel #21
0
 def testMissingReturnOperation(self):
   """Expects a ValueError when `return_elements` names the absent op 'B'."""
   expected_error = "Requested return node 'B' not found in graph def"
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       single_node_graph = self._MakeGraphDef("""
           node { name: 'A' op: 'None' }
           """)
       importer.import_graph_def(single_node_graph, return_elements=["B"])
Beispiel #22
0
 def testNamePrefixColocationAttrsNotFound(self):
   """Expects a ValueError when a colocation attr refers to a missing node."""
   # Node 'B' is colocated with 'A', but 'A' is not part of the GraphDef.
   orphan_colocation_graph = self._MakeGraphDef("""
         node { name: 'B' op: 'None'  attr {
           key: '_class'
           value { list { s: 'loc:@A' } }
         } }""")
   expected_error = "does not exist during import"
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       importer.import_graph_def(
           orphan_colocation_graph,
           return_elements=["B"],
           name="imported_graph")
Beispiel #23
0
 def testDuplicateOperationNames(self):
   """Expects a ValueError when two nodes in the GraphDef are named 'A'."""
   expected_message = "Duplicate name 'A' in GraphDef."
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as raised:
       duplicated_names_graph = self._MakeGraphDef("""
           node { name: 'A' op: 'Oi' }
           node { name: 'B' op: 'Oi' }
           node { name: 'A' op: 'Oi' }
           """)
       importer.import_graph_def(duplicated_names_graph)
     self.assertEqual(expected_message, str(raised.exception))
Beispiel #24
0
 def testMissingReturnOperation(self):
   """Expects a ValueError when `return_elements` names the absent op 'B'."""
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as raised:
       single_node_graph = self._MakeGraphDef("""
           node { name: 'A' op: 'None' }
           """)
       importer.import_graph_def(single_node_graph, return_elements=["B"])
     message = str(raised.exception)
     self.assertTrue("return_element 'B' not found in graph_def." in message)
Beispiel #25
0
  def testInvalidInputForReturnOperations(self):
    """Expects errors for non-string and malformed `return_elements`."""
    # (expected exception type, expected message regex, return_elements value)
    invalid_cases = (
        (TypeError, "return_elements must be a list of strings.", [7]),
        (ValueError, "Cannot convert 'a:b:c' to a tensor name.", ["a:b:c"]),
    )
    with ops.Graph().as_default():
      for error_type, error_regex, bad_elements in invalid_cases:
        with self.assertRaisesRegexp(error_type, error_regex):
          importer.import_graph_def(
              self._MakeGraphDef(""), return_elements=bad_elements)
Beispiel #26
0
 def testInvalidSignatureNotEnoughInputsInGraphDef(self):
   """Expects a ValueError when node 'B' is given fewer inputs than its op takes."""
   expected_fragment = ("Input types mismatch (expected 'int32, float32' but "
                        "got 'int32')")
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as raised:
       too_few_inputs_graph = self._MakeGraphDef("""
           node { name: 'A' op: 'Oi' }
           node { name: 'B' op: 'Iif' input: 'A:0' }
           """)
       importer.import_graph_def(too_few_inputs_graph)
     self.assertTrue(expected_fragment in str(raised.exception))
Beispiel #27
0
  def _convert_graph_def(self):
    """Convert the input GraphDef.

    Imports `self._input_graph_def` into a new graph, stores the exported
    MetaGraphDef (with shapes added) in `self._grappler_meta_graph_def`, then
    calls `self._add_nodes_blacklist()` and `self._run_conversion()`.
    """
    scratch_graph = ops.Graph()
    with scratch_graph.as_default():
      importer.import_graph_def(self._input_graph_def, name="")

    graph_def_with_shapes = scratch_graph.as_graph_def(add_shapes=True)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=graph_def_with_shapes, graph=scratch_graph)

    self._add_nodes_blacklist()
    self._run_conversion()
Beispiel #28
0
 def testMissingInputMap(self):
   """Expects a ValueError when `input_map` names the absent tensor 'B:0'."""
   expected_error = (
       r"Attempted to map inputs that were not found in graph_def: \[B:0\]")
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       single_node_graph = self._MakeGraphDef("""
           node { name: 'A' op: 'None' }
           """)
       unknown_mapping = {"B:0": constant_op.constant(5.0)}
       importer.import_graph_def(single_node_graph, input_map=unknown_mapping)
Beispiel #29
0
  def testMissingControlInputInGraphDef(self):
    """Expects a ValueError when node 'B' has the undefined control input '^A'."""
    if ops._USE_C_API:
      return  # TODO(skyewm): make this work with C API

    with ops.Graph().as_default():
      with self.assertRaises(ValueError) as raised:
        dangling_control_graph = self._MakeGraphDef("""
            node { name: 'B' op: 'None' input: '^A' }
            """)
        importer.import_graph_def(dangling_control_graph)
      message = str(raised.exception)
      self.assertTrue("Control input '^A' not found" in message)
Beispiel #30
0
 def testInvalidSignatureTooManyInputsInGraphDef(self):
   """Expects a ValueError when node 'B' is given an input its op does not take."""
   expected_fragment = "More inputs specified ('A:0') than the op expects"
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as raised:
       too_many_inputs_graph = self._MakeGraphDef("""
           node { name: 'A' op: 'Oi' }
           node { name: 'B' op: 'None' input: 'A:0' }
           """)
       importer.import_graph_def(too_many_inputs_graph)
     self.assertTrue(expected_fragment in str(raised.exception))
Beispiel #31
0
    def testOldGraph(self):
        # Load graph generated from earlier version of TF where
        # placeholder shape was not set.
        #
        # a = tf.placeholder(tf.float32)
        # b = a + 1.0
        #
        # Older graph's default shape is 'shape {}', not 'shape {
        # unknown_rank: true }'
        graph = """
node {
  name: "Placeholder"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "add/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "add"
  op: "Add"
  input: "Placeholder"
  input: "add/y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 21
}
"""
        gdef = graph_pb2.GraphDef()
        text_format.Merge(graph, gdef)
        with self.test_session():
            p, ret = importer.import_graph_def(
                gdef, return_elements=["Placeholder:0", "add:0"])

            # Feed in a vector of two elements.  Since the producer version
            # of 21, a shape of {} is interpreted as "any shape".  If
            # producer version were 22, then we'd get a shape mismatch
            # error.
            self.assertAllEqual([2.0, 3.0],
                                ret.eval(feed_dict={p: [1.0, 2.0]}))
Beispiel #32
0
 def testInvalidInputForGraphDef(self):
   """Expects a TypeError when a string is passed instead of a GraphDef."""
   expected_message = "graph_def must be a GraphDef proto."
   with ops.Graph().as_default():
     with self.assertRaises(TypeError) as raised:
       importer.import_graph_def("")
     self.assertEqual(expected_message, str(raised.exception))
 def _imports_graph_def():
     """Imports `graph_def` with an empty name prefix.

     `graph_def` is not a parameter; it is resolved from a scope outside this
     function.
     """
     importer.import_graph_def(graph_def, name="")
    def _TestTrtGraphConverter(self,
                               input_saved_model_dir=None,
                               output_saved_model_dir=None,
                               need_calibration=False,
                               is_dynamic_op=False):
        """General method to test trt_convert.TrtGraphConverter().

        Converts a graph through `self._ConvertGraph`, then checks the node
        name -> op mapping of the returned GraphDef and, when
        `output_saved_model_dir` is given, of the GraphDef read back from
        that SavedModel. With `need_calibration`, the TRTEngineOp nodes must
        carry calibration data and the graph is run on ten inputs.
        """
        converted_graph_def = self._ConvertGraph(
            input_saved_model_dir=input_saved_model_dir,
            output_saved_model_dir=output_saved_model_dir,
            need_calibration=need_calibration,
            is_dynamic_op=is_dynamic_op,
            use_function_backup=need_calibration)
        candidates = [converted_graph_def]

        if output_saved_model_dir:
            if context.executing_eagerly():
                loaded = load.load(output_saved_model_dir)
                signature = loaded.signatures[_SAVED_MODEL_SIGNATURE_KEY]
                exported_graph_def = signature.graph.as_graph_def()
            else:
                meta_graph_def = saved_model_utils.get_meta_graph_def(
                    output_saved_model_dir, tag_constants.SERVING)
                exported_graph_def = meta_graph_def.graph_def
            self.assertTrue(
                isinstance(exported_graph_def, graph_pb2.GraphDef))
            candidates.append(exported_graph_def)

        for candidate in candidates:
            op_by_name = {}
            for top_level_node in candidate.node:
                op_by_name[top_level_node.name] = top_level_node.op

            if context.executing_eagerly():
                # In V2 the actual graph could be inside a function.
                for func in candidate.library.function:
                    for fn_node in func.node_def:
                        op_by_name[fn_node.name] = fn_node.op
                self.assertIn("TRTEngineOp_0", op_by_name)
                self.assertEqual("TRTEngineOp", op_by_name["TRTEngineOp_0"])
            else:
                expected_ops = {
                    "input": "Placeholder",
                    "TRTEngineOp_0": "TRTEngineOp",
                    "output": "Identity"
                }
                self.assertEqual(expected_ops, op_by_name)

            if not need_calibration:
                continue

            engine_nodes = [
                n for n in candidate.node if n.op == "TRTEngineOp"
            ]
            self.assertNotEmpty(engine_nodes)
            for engine in engine_nodes:
                self.assertTrue(len(engine.attr["calibration_data"].s))
            # Run the calibrated graph.
            # TODO(laigd): consider having some input where the answer is different.
            with ops.Graph().as_default():
                importer.import_graph_def(candidate, name="")
                with self.session(config=self._GetConfigProto()) as sess:
                    for test_data in range(10):
                        expected = (test_data + 1.0)**2
                        actual = sess.run(
                            "output:0",
                            feed_dict={"input:0": [[[test_data]]]})
                        self.assertEqual(expected, actual)
Beispiel #35
0
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph: Path of the GraphDef file to freeze.
        input_saver: Optional path of a SaverDef file used to restore the
            checkpoint. When falsy, a Saver is built from the checkpoint
            variables that have a matching tensor in the graph.
        input_binary: Whether `input_graph` (and `input_saver`) hold binary
            protos rather than text protos.
        input_checkpoint: Checkpoint path (or prefix) to restore from.
        output_node_names: Comma-separated names of the output nodes.
        restore_op_name: Unused.
        filename_tensor_name: Unused.
        output_graph: Path the frozen GraphDef is written to.
        clear_devices: Whether to blank the device field of every node.
        initializer_nodes: Passed to `sess.run` after restoring when no
            `input_saver` is given; skipped when falsy.
        variable_names_blacklist: Accepted for interface compatibility; it
            is not forwarded to `convert_variables_to_constants`.

    Returns:
        -1 when a required input is missing or empty, otherwise None.
    """

    # Unused by updated loading code.
    del restore_op_name, filename_tensor_name

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            graph_text = f.read()
            # A text-mode read may already yield `str`, which has no
            # `.decode`; only decode when raw bytes were returned.
            if isinstance(graph_text, bytes):
                graph_text = graph_text.decode("utf-8")
            text_format.Merge(graph_text, input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver:
            with gfile.FastGFile(input_saver, mode) as f:
                saver_def = saver_pb2.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = saver_lib.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            # Restore only the checkpoint variables that have a matching
            # tensor in the imported graph; the others are skipped.
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    continue
                var_list[key] = tensor

            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes)

        # NOTE: `variable_names_blacklist` is intentionally not passed on
        # here, matching the previous behaviour of this function.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names.split(","))

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
Beispiel #36
0
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: Path of the serialized `GraphDef` to freeze.
    input_saver: Optional path of a serialized `SaverDef`. When given, a Saver
      built from it restores the checkpoint.
    input_binary: If true, `input_graph` and `input_saver` are read as binary
      protos; otherwise they are parsed as text protos.
    input_checkpoint: Checkpoint (or checkpoint prefix) to restore from.
    output_node_names: Comma separated names of the output nodes.
    restore_op_name: Op that is run to restore variables when no `input_saver`
      is given.
    filename_tensor_name: Tensor that is fed with `input_checkpoint` when
      `restore_op_name` is run.
    output_graph: Path the frozen `GraphDef` is written to.
    clear_devices: If true, the device field of every node is cleared.
    initializer_nodes: Comma separated string (or an already split list) of
      initializer nodes to run before freezing. Only used when no
      `input_saver` is given.
    variable_names_blacklist: Comma separated names of variables that must not
      be converted to constants (optional).

  Returns:
    -1 if one of the inputs fails validation, otherwise None.
  """

  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  _ = importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver:
      with gfile.FastGFile(input_saver, mode) as f:
        saver_def = saver_pb2.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = saver_lib.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
      if initializer_nodes:
        # A comma separated string must be split into individual fetches;
        # passing it raw would make the session look up one node literally
        # named "a,b". Non-string values are forwarded unchanged.
        if isinstance(initializer_nodes, str):
          initializer_nodes = initializer_nodes.replace(" ", "").split(",")
        sess.run(initializer_nodes)

    variable_names_blacklist = (variable_names_blacklist.split(",") if
                                variable_names_blacklist else None)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(","),
        variable_names_blacklist=variable_names_blacklist)

  with gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
Beispiel #37
0
  def save(self, output_saved_model_dir):
    """Save the converted graph as a SavedModel.

    Args:
      output_saved_model_dir: construct a SavedModel using the converted
        GraphDef and save it to the specified directory. This option only works
        when the input graph is loaded from a SavedModel, i.e. when
        input_saved_model_dir is specified and input_graph_def is None in
        __init__().

    Raises:
      ValueError: if the input to the converter is a GraphDef instead of a
      SavedModel.
      AssertionError: if the graph has not been converted yet, or calibration
      is needed but its data has not been collected.
    """
    assert self._converted
    if self._need_calibration:
      assert self._calibration_data_collected
    if self._input_graph_def:
      raise ValueError(
          "Not able to save to a SavedModel since input is a GraphDef")

    def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
      """Restores collections that we need to keep."""
      scope = ""
      for key in collection_keys:
        collection_def = src_meta_graph_def.collection_def[key]
        kind = collection_def.WhichOneof("kind")
        if kind is None:
          tf_logging.error(
              "Cannot identify data type for collection %s. Skipping.", key)
          continue
        from_proto = ops.get_from_proto_function(key)
        if from_proto and kind == "bytes_list":
          proto_type = ops.get_collection_proto_type(key)
          # It is assumed that there are no Variables Keys in collections
          for value in collection_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            try:
              new_value = from_proto(proto, import_scope=scope)
            except Exception:  # pylint: disable=broad-except
              # Best effort: an entry that cannot be rebuilt in the converted
              # graph is dropped. Catching Exception (not a bare except) lets
              # KeyboardInterrupt/SystemExit propagate.
              continue
            dest_graph.add_to_collection(key, new_value)
        else:
          field = getattr(collection_def, kind)
          if kind == "node_list":
            for value in field.value:
              name = ops.prepend_name_scope(value, scope)
              # Since the graph has been optimized, the node may no longer
              # exists
              try:
                col_op = dest_graph.as_graph_element(name)
              except (TypeError, ValueError, KeyError):
                continue
              dest_graph.add_to_collection(key, col_op)
          elif kind == "int64_list":
            # NOTE(opensource): This force conversion is to work around the
            # fact that Python2 distinguishes between int and long, while
            # Python3 has only int.
            for value in field.value:
              dest_graph.add_to_collection(key, int(value))
          else:
            for value in field.value:
              dest_graph.add_to_collection(key,
                                           ops.prepend_name_scope(value, scope))

    # Write the transformed graphdef as SavedModel.
    saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
    with ops.Graph().as_default():
      importer.import_graph_def(self._converted_graph_def, name="")
      _restore_collections(
          ops.get_default_graph(), self._grappler_meta_graph_def,
          self._collections_to_keep(
              self._grappler_meta_graph_def.collection_def))
      # We don't use any specific converter here.
      with session.Session(config=self._session_config) as sess:
        saved_model_builder.add_meta_graph_and_variables(
            sess,
            self._input_saved_model_tags,
            signature_def_map=self._grappler_meta_graph_def.signature_def)
    # Ignore other meta graphs from the input SavedModel.
    saved_model_builder.save()
Beispiel #38
0
  def testColocationWithDeviceFn(self):
    """Checks device assignment when colocated nodes meet a device function."""
    graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }
          node { name: 'B' op: 'None'  attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")
    expected_groups = [b"loc:@imported_graph/A"]

    def _import_and_check(device_fn, expected_device):
      """Imports A and B under `device_fn` and verifies device/colocation."""
      with ops.Graph().as_default():
        with ops.device(device_fn):
          op_a, op_b = importer.import_graph_def(
              graph_def,
              return_elements=["A", "B"],
              name="imported_graph")
        self.assertEqual(op_a.device, expected_device)
        self.assertEqual(op_b.device, expected_device)
        self.assertEqual(op_a.colocation_groups(), expected_groups)
        self.assertEqual(op_b.colocation_groups(), expected_groups)

    # The device function asks for different devices for "A" and "B". B is
    # colocated with A, so B is expected to end up on A's device.
    def _split_devices(op):
      return "/device:A:0" if "A" in op.name else "/device:B:0"

    _import_and_check(_split_devices, "/device:A:0")

    # Only "B" is offered a device. A gets none, and since B follows A through
    # the colocation attribute, B is expected to report no device either.
    def _only_b(op):
      return "/device:B:0" if "B" in op.name else ""

    _import_and_check(_only_b, "")

    # Only "A" is offered a device; B picks it up through colocation.
    def _only_a(op):
      return "/device:A:0" if "A" in op.name else ""

    _import_and_check(_only_a, "/device:A:0")
Beispiel #39
0
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph_def: A `GraphDef`.
      input_saver_def: A `SaverDef` (optional).
      input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
        priority.  Typically the result of `Saver.save()` or that of
        `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
        V1/V2.
      output_node_names: The name(s) of the output nodes, comma separated.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      output_graph: String where to write the frozen `GraphDef`. Nothing is
        written when this is empty.
      clear_devices: A Bool whether to remove device specifications.
      initializer_nodes: Comma separated string of initializer nodes to run
        before freezing.
      variable_names_whitelist: The set of variable names to convert (optional,
        by default, all variables are converted).
      variable_names_blacklist: The set of variable names to omit converting
        to constants (optional).
      input_meta_graph_def: A `MetaGraphDef` (optional),
      input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
        and variables (optional).
      saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
        load, in string format (optional).
      checkpoint_version: Tensorflow variable file format
        (saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2)

    Returns:
      The frozen `GraphDef` (also written to `output_graph` when a path is
      given).

    Raises:
      ValueError: If the checkpoint does not exist (and no SavedModel dir is
        given), if `output_node_names` is empty, or if building the Saver
        fails because the model has partition variables or no variables.
      TypeError: Re-raised from Saver construction when neither of the above
        diagnoses applies.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir
            and not checkpoint_management.checkpoint_exists(input_checkpoint)):
        raise ValueError("Input checkpoint '" + input_checkpoint +
                         "' doesn't exist!")

    if not output_node_names:
        raise ValueError(
            "You need to supply the name of a node to --output_node_names.")

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # List of all partition variables. Because the condition is heuristic
            # based, the list could include false positives.
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]
            has_partition_var = False

            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                if any(key in name for name in all_partition_variable_names):
                    has_partition_var = True
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(var_list=var_list,
                                        write_version=checkpoint_version)
            except TypeError:
                # `var_list` is required to be a map of variable names to Variable
                # tensors. Partition variables are Identity tensors that cannot be
                # handled by Saver.
                if has_partition_var:
                    raise ValueError(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.")
                # Models that have been frozen previously do not contain Variables.
                if _has_no_variables(sess):
                    raise ValueError(
                        "No variables were found in this model. It is likely the model "
                        "was frozen previously. You cannot freeze a graph twice."
                    )
                # Neither diagnosis applies: surface the original TypeError.
                raise

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (variable_names_whitelist.replace(
            " ", "").split(",") if variable_names_whitelist else None)
        variable_names_blacklist = (variable_names_blacklist.replace(
            " ", "").split(",") if variable_names_blacklist else None)

        # The meta graph, when supplied, is the graph that was imported into the
        # session, so it is the one whose variables get converted.
        graph_def_to_freeze = (input_meta_graph_def.graph_def
                               if input_meta_graph_def else input_graph_def)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            graph_def_to_freeze,
            output_node_names.replace(" ", "").split(","),
            variable_names_whitelist=variable_names_whitelist,
            variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false. Note that the device fields are cleared
      on the `graph_def` of the `MetaGraphDef` itself.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e. whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs, or if eager
      execution is enabled.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        # Normalize once: collection values may be bytes while `input_map` is
        # keyed by str, so both the check and the message use the str form.
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (
            not input_map or
            sorted(unbound_names) != sorted(input_map)):
          missing_names = [name for name in unbound_names
                           if not input_map or name not in input_map]
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(missing_names))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    # Computed before the import so that collection entries are prefixed with
    # the scope name the graph hands out for `import_scope`.
    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    # Maps a serialized variable proto to the object built from it, so a
    # variable listed in several collections is only constructed once.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    # Key the returned variables by their name with the import scope stripped.
    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
Beispiel #41
0
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph_def: A `GraphDef` to freeze. Its node devices are cleared
        in place when `clear_devices` is true.
      input_saver_def: A `SaverDef` (optional). When given, a Saver built from
        it restores the checkpoint.
      input_checkpoint: Checkpoint (or checkpoint prefix) to restore from.
      output_node_names: Comma separated names of the output nodes.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      output_graph: Path the frozen `GraphDef` is written to.
      clear_devices: If true, the device field of every node is cleared.
      initializer_nodes: Comma separated string (or an already split list) of
        initializer nodes to run before freezing. Only used when no
        `input_saver_def` is given.
      variable_names_blacklist: Comma separated names of variables that must
        not be converted to constants (optional).

    Returns:
      -1 if one of the inputs fails validation, otherwise None.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def)
            saver.restore(sess, input_checkpoint)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                # A comma separated string must be split into individual
                # fetches; passing it raw would make the session look up one
                # node literally named "a,b". Non-string values are forwarded
                # unchanged.
                if isinstance(initializer_nodes, str):
                    initializer_nodes = initializer_nodes.replace(
                        " ", "").split(",")
                sess.run(initializer_nodes)

        variable_names_blacklist = (variable_names_blacklist.split(",")
                                    if variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(","),
            variable_names_blacklist=variable_names_blacklist)

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
    def testFunctions(self):
        """Checks that Defun functions survive GraphDef export and import.

        A graph calling several `Defun` functions is exported to a GraphDef,
        imported into a second graph, evaluated, then exported and imported
        once more into a third graph and evaluated again with the same
        expected values.
        """
        dtype = dtypes.float32

        @function.Defun(dtype, dtype, dtype, dtype)
        def Grad(x, y, dout1, dout2):  # pylint: disable=unused-argument
            # Return the inputs for simplicity of testing. The correct return value
            # would be (dout1 + dout2, dout1 - dout2)
            return x, y

        # Function with a custom gradient function attached.
        @function.Defun(dtype, dtype, grad_func=Grad)
        def FuncWithGrad(x, y):
            return x + y, x - y

        # Function that reads `c`, a name that is only assigned further down,
        # inside the `g1` graph block.
        @function.Defun(dtypes.int32)
        def ExternalTensorFunc(x):
            # c must be defined in the containing graph
            return x + c

        # Function whose body defines and calls another function.
        @function.Defun(dtypes.int32, dtypes.int32)
        def OuterFunc(x, y):
            @function.Defun(dtypes.int32)
            def InnerFunc(x):
                return x + x

            return InnerFunc(x) + y

        # Create graph with function calls and export to GraphDef
        with ops.Graph().as_default() as g1:
            p1 = array_ops.placeholder(dtype, name="p1")
            p2 = array_ops.placeholder(dtype, name="p2")
            # pylint: disable=unexpected-keyword-arg
            a, b = FuncWithGrad(p1, p2, name="f")

            c = constant_op.constant(10, dtype=dtypes.int32)
            ExternalTensorFunc(1, name="external")

            OuterFunc(10, 1, name="outer")
            # pylint: enable=unexpected-keyword-arg

        gdef = g1.as_graph_def()

        # Import GraphDef into new graph, add imported gradients, and test that
        # imported functions can be run
        with ops.Graph().as_default() as g2:
            p1, p2, a, b = importer.import_graph_def(
                gdef, return_elements=["p1:0", "p2:0", "f:0", "f:1"], name="")
            grad = gradients_impl.gradients([a], [p1, p2])

            with self.test_session(graph=g2) as sess:
                feed_dict = {p1: 1, p2: 2}
                a_val, b_val, grad_val = sess.run([a, b, grad],
                                                  feed_dict=feed_dict)
                self.assertEqual(a_val, 3.0)
                self.assertEqual(b_val, -1.0)
                # Grad function returns inputs values for testing
                self.assertEqual(grad_val, [1.0, 2.0])
                # external: 1 + c (10); outer: (10 + 10) + 1.
                self.assertEqual(sess.run("external:0"), 11)
                self.assertEqual(sess.run("outer:0"), 21)

        # Export the new graph and reimport to test that imported functions can be
        # successfully exported/imported again
        gdef = g2.as_graph_def()
        with ops.Graph().as_default() as g3:
            p1, p2, a, b = importer.import_graph_def(
                gdef, return_elements=["p1:0", "p2:0", "f:0", "f:1"], name="")
            # Create new gradient functions (in additional to the imported gradient
            # functions created in g2).
            grad = gradients_impl.gradients([a], [p1, p2])

            with self.test_session(graph=g3) as sess:
                feed_dict = {p1: 1, p2: 2}
                a_val, b_val, grad_val = sess.run([a, b, grad],
                                                  feed_dict=feed_dict)
                self.assertEqual(a_val, 3.0)
                self.assertEqual(b_val, -1.0)
                self.assertEqual(grad_val, [1.0, 2.0])
                self.assertEqual(sess.run("external:0"), 11)
                self.assertEqual(sess.run("outer:0"), 21)
 def TestFunc():
     """Imports `gdef` and returns its "z:0" element."""
     imported_elements = importer.import_graph_def(
         gdef, return_elements=["z:0"])
     return imported_elements[0]
Beispiel #44
0
  def calibrate(self,
                fetch_names,
                num_runs,
                feed_dict_fn=None,
                input_map_fn=None):
    """Runs calibration on the converted graph and returns the GraphDef.

    The converted GraphDef is imported into a fresh graph, the fetches are run
    `num_runs` times, and the calibration data obtained for every TRT engine
    node is stored in that node's "calibration_data" attribute.

    Args:
      fetch_names: list of output tensor names to fetch during calibration.
      num_runs: how many times the graph is run during calibration.
      feed_dict_fn: callable returning the feed dict used for each run.
        Exactly one of `feed_dict_fn` and `input_map_fn` must be given.
      input_map_fn: callable returning a dict that maps input names (str) in
        the GraphDef to `tf.Tensor` objects; it is used as the `input_map`
        when the GraphDef is imported. Exactly one of `feed_dict_fn` and
        `input_map_fn` must be given.

    Raises:
      ValueError: if both or neither of `feed_dict_fn` / `input_map_fn` are
        given, or if `input_map_fn` returns keys that are not str or values
        that are not `tf.Tensor`.
      AssertionError: if the graph is not converted, does not need
        calibration, or calibration data was already collected.

    Returns:
      The GraphDef after the calibration.
    """
    assert self._converted
    assert self._need_calibration
    assert not self._calibration_data_collected

    # Exactly one of the two input providers has to be supplied.
    if bool(feed_dict_fn) == bool(input_map_fn):
      raise ValueError(
          "Should specify one and only one of feed_dict_fn and input_map_fn.")

    if input_map_fn:
      for input_name, input_tensor in input_map_fn().items():
        if not isinstance(input_name, str):
          raise ValueError("Keys of input_map_fn must be of type str")
        if not isinstance(input_tensor, tf.Tensor):
          raise ValueError("Values of input_map_fn must be of type tf.Tensor")

    self._calibration_graph = ops.Graph()
    with self._calibration_graph.as_default():
      calibration_input_map = input_map_fn() if input_map_fn else None
      fetches = importer.import_graph_def(
          self._converted_graph_def,
          input_map=calibration_input_map,
          return_elements=fetch_names,
          name="")

    with session.Session(
        graph=self._calibration_graph,
        config=self._session_config) as calibration_sess:
      for _ in range(num_runs):
        run_feed_dict = feed_dict_fn() if feed_dict_fn else None
        calibration_sess.run(fetches, feed_dict=run_feed_dict)

      # One get_calibration_data op per device, keyed by device name.
      #
      # TODO(laigd): a better way would be to use calibration_sess to list
      # all the devices, add one get_calibration_data for each device, and
      # fetch each such op for every resource until its found. This can work
      # even when the device of the TRTEngineOp is empty or not fully specified.
      calibration_op_by_device = {}

      with self._calibration_graph.as_default():
        resource_name_input = array_ops.placeholder(dtypes.string)

        for node in self._converted_graph_def.node:
          if node.op != _TRT_ENGINE_OP_NAME:
            continue

          # Create the op lazily, the first time a device is encountered.
          # TODO(laigd): What if the device is empty?????
          if node.device not in calibration_op_by_device:
            with self._calibration_graph.device(node.device):
              calibration_data_output = (
                  gen_trt_ops.get_calibration_data_op(resource_name_input))
            calibration_op_by_device[node.device] = calibration_data_output

          # Fetch the calibration resource for this engine and attach it.
          engine_feed = {
              resource_name_input: _get_canonical_engine_name(node.name)
          }
          node.attr["calibration_data"].s = calibration_sess.run(
              calibration_op_by_device[node.device], feed_dict=engine_feed)

      self._calibration_data_collected = True

    return self._converted_graph_def
    def _test_variable_to_const_conversion(self, use_resource):
        """Exercises `graph_util.convert_variables_to_constants`.

        A graph with one used variable ("variable_node"), one unused variable
        ("unused_variable_node") and a multiply op ("output_node") is frozen
        three ways: with a whitelist, with no list at all, and with a
        blacklist. The frozen graph is then re-imported and evaluated.

        Args:
          use_resource: passed to `variable_scope.variable_scope`; also selects
            the op type expected for the blacklisted variable ("VarHandleOp"
            when truthy, "VariableV2" otherwise).
        """
        with ops.Graph().as_default(), variable_scope.variable_scope(
                "", use_resource=use_resource):
            used_var = variable_scope.get_variable("variable_node",
                                                   initializer=1.0)
            unused_var = variable_scope.get_variable("unused_variable_node",
                                                     initializer=1.0)
            doubled = math_ops_lib.multiply(used_var, 2.0, name="output_node")
            with session.Session() as sess:
                self.evaluate(used_var.initializer)
                self.assertNear(2.0, self.evaluate(doubled), 0.00001)
                variable_graph_def = sess.graph.as_graph_def()

                # Freeze with a whitelist first: without it the conversion
                # would fail because unused_variable_node has not been
                # initialized yet.
                constant_graph_def = graph_util.convert_variables_to_constants(
                    sess,
                    variable_graph_def, ["output_node"],
                    variable_names_whitelist={"variable_node"})

                # Once the unused variable is initialized, freezing without a
                # whitelist becomes possible.
                self.evaluate(unused_var.initializer)
                unrestricted_graph_def = (
                    graph_util.convert_variables_to_constants(
                        sess, variable_graph_def, ["output_node"]))

                # The unused variable is expected to be dropped, so both
                # frozen graphs should serialize to the same text.
                self.assertEqual(str(constant_graph_def),
                                 str(unrestricted_graph_def))

                # A blacklisted variable must survive as a variable op.
                blacklisted_graph_def = (
                    graph_util.convert_variables_to_constants(
                        sess,
                        variable_graph_def, ["output_node"],
                        variable_names_blacklist={"variable_node"}))
                matching_nodes = [
                    candidate for candidate in blacklisted_graph_def.node
                    if candidate.name == "variable_node"
                ]
                # Keep the last match, mirroring a plain overwrite-in-loop scan.
                surviving_node = matching_nodes[-1] if matching_nodes else None
                self.assertIsNotNone(surviving_node)
                expected_op = "VarHandleOp" if use_resource else "VariableV2"
                self.assertEqual(surviving_node.op, expected_op)

        # Re-import the frozen graph: it must contain no variables and still
        # compute the same value.
        with ops.Graph().as_default():
            importer.import_graph_def(constant_graph_def, name="")
            self.assertEqual(4, len(constant_graph_def.node))
            self._ensure_no_variables_in_graph(constant_graph_def)
            with session.Session() as sess:
                frozen_output = sess.graph.get_tensor_by_name("output_node:0")
                self.assertNear(2.0, self.evaluate(frozen_output), 0.00001)
Beispiel #46
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.summary import summary

# Path of the frozen GraphDef (.pb) to visualize, and the directory that will
# receive the TensorBoard event file.
model_dir = 'D:/data/glomeruli/20180202_glomeruli_detection_noquant.pb'
log_dir = 'd:/temp/tf'

with session.Session(graph=ops.Graph()) as sess:
    # Parse the serialized GraphDef; the file handle is only needed while
    # reading, so it is released before the graph is imported.
    with gfile.FastGFile(model_dir, "rb") as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())

    # Entering the session made its graph the default graph, so the nodes are
    # imported into `sess.graph`.
    importer.import_graph_def(graph_def)

    # Write the graph for TensorBoard. The writer is closed explicitly so the
    # pending graph event is flushed to disk before the script exits.
    file_writer = summary.FileWriter(log_dir, sess.graph)
    file_writer.close()
    print("Model Imported. Visualize by running: tensorboard --logdir={}".
          format(log_dir))
Beispiel #47
0
def import_scoped_meta_graph_with_return_elements(
        meta_graph_or_file,
        clear_devices=False,
        graph=None,
        import_scope=None,
        input_map=None,
        unbound_inputs_col_name="unbound_inputs",
        restore_collections_predicate=(lambda key: True),
        return_elements=None):
    """Imports graph from `MetaGraphDef` and returns vars and return elements.

    This function takes a `MetaGraphDef` protocol buffer as input. If
    the argument is a file containing a `MetaGraphDef` protocol buffer,
    it constructs a protocol buffer from the file content. The function
    then adds all the nodes from the `graph_def` field to the target graph,
    recreates the desired collections, and returns a dictionary of
    all the Variables imported into the name scope.

    In combination with `export_scoped_meta_graph()`, this function can be
    used to

    * Serialize a graph along with other Python objects such as `QueueRunner`,
      `Variable` into a `MetaGraphDef`.

    * Restart training from a saved graph and checkpoints.

    * Run inference from a saved graph and checkpoints.

    Note: when `clear_devices` is true, the `device` field of every node in
    `meta_graph_def.graph_def` is cleared in place, so a `MetaGraphDef` passed
    in by the caller is modified.

    Args:
      meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
        the path) containing a `MetaGraphDef`.
      clear_devices: Boolean which controls whether to clear device information
        from graph_def. Defaults to False.
      graph: The `Graph` to import into. If `None`, use the default graph.
      import_scope: Optional `string`. Name scope into which to import the
        subgraph. If `None`, the graph is imported to the root name scope.
      input_map: A dictionary mapping input names (as strings) in `graph_def`
        to `Tensor` objects. The values of the named input tensors in the
        imported graph will be re-mapped to the respective `Tensor` values.
      unbound_inputs_col_name: Collection name for looking up unbound inputs.
        If falsy, the unbound-inputs check is skipped.
      restore_collections_predicate: a predicate on collection names. A
        collection named c (i.e. whose key is c) will be restored iff
        1) `restore_collections_predicate(c)` is True, and
        2) `c != unbound_inputs_col_name`.
      return_elements: A list of strings containing operation names in the
        `MetaGraphDef` that will be returned as `Operation` objects; and/or
        tensor names in `MetaGraphDef` that will be returned as `Tensor`
        objects.

    Returns:
      A tuple of (
        dictionary mapping scope-stripped variable names to the variables
        found in the `GLOBAL_VARIABLES` collection under the import scope,
        the value returned by `importer.import_graph_def` for
        `return_elements`).

    Raises:
      ValueError: If eager execution is enabled, or if the graph_def contains
        unbound inputs that are not all supplied through `input_map`.

    """
    if context.executing_eagerly():
        raise ValueError(
            "Exporting/importing meta graphs is not supported when "
            "eager execution is enabled.")
    # Accept either an in-memory proto or a path to a serialized one.
    if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
        meta_graph_def = meta_graph_or_file
    else:
        meta_graph_def = read_meta_graph_file(meta_graph_or_file)

    # Verify that the sorted names recorded in the unbound-inputs collection
    # equal the sorted keys of input_map; otherwise report the ones missing.
    if unbound_inputs_col_name:
        for key, col_def in meta_graph_def.collection_def.items():
            if key == unbound_inputs_col_name:
                kind = col_def.WhichOneof("kind")
                field = getattr(col_def, kind)
                if field.value and (not input_map or sorted(
                    [compat.as_str(v)
                     for v in field.value]) != sorted(input_map)):
                    raise ValueError(
                        "Graph contains unbound inputs: %s. Must "
                        "provide these inputs through input_map." % ",".join(
                            compat.as_str(v) for v in field.value
                            if not input_map or v not in input_map))
                # Only one collection can carry this key; stop scanning.
                break

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Gathers the list of nodes we are interested in.
    with graph.as_default():
        producer_op_list = None
        if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
            producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
        # NOTE: this is a reference to the proto field, not a copy, so the
        # device clearing below mutates meta_graph_def itself.
        input_graph_def = meta_graph_def.graph_def
        # Remove all the explicit device specifications from every node. This
        # helps to make the graph more portable.
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        # Compute the scope name without reserving it (mark_as_used=False);
        # the same value is used below to prefix collection entries and to
        # strip variable names.
        scope_to_prepend_to_names = graph.unique_name(import_scope or "",
                                                      mark_as_used=False)

        imported_return_elements = importer.import_graph_def(
            input_graph_def,
            name=(import_scope or scope_to_prepend_to_names),
            input_map=input_map,
            producer_op_list=producer_op_list,
            return_elements=return_elements)

        # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
        # without a VariableDef.trainable field set. A missing version string
        # is treated as if the field is present.
        tf_version = meta_graph_def.meta_info_def.tensorflow_version
        if not tf_version:
            variables_have_trainable = True
        else:
            variables_have_trainable = (
                distutils_version.LooseVersion(tf_version) >=
                distutils_version.LooseVersion("1.9"))

        # Sort collections so we see TRAINABLE_VARIABLES first and can default these
        # variables to trainable if the value is not set in their VariableDef.
        sorted_collections = []
        if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
            sorted_collections.append((ops.GraphKeys.TRAINABLE_VARIABLES,
                                       meta_graph_def.collection_def[
                                           ops.GraphKeys.TRAINABLE_VARIABLES]))
        for key, value in sorted(meta_graph_def.collection_def.items()):
            if key != ops.GraphKeys.TRAINABLE_VARIABLES:
                sorted_collections.append((key, value))

        # Restores all the other collections.
        # Maps a serialized VariableDef (bytes) to the variable object built
        # from it, so a variable listed in several collections is created once.
        variable_objects = {}
        for key, col_def in sorted_collections:
            # Don't add unbound_inputs to the new graph.
            if key == unbound_inputs_col_name:
                continue
            if not restore_collections_predicate(key):
                continue

            kind = col_def.WhichOneof("kind")
            if kind is None:
                logging.error(
                    "Cannot identify data type for collection %s. Skipping.",
                    key)
                continue
            from_proto = ops.get_from_proto_function(key)

            # Temporary change to allow the TFMA evaluator to read metric variables
            # saved as a bytes list.
            # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
            if key == ops.GraphKeys.METRIC_VARIABLES:
                # Metric variables will use the same proto functions as GLOBAL_VARIABLES
                from_proto = ops.get_from_proto_function(
                    ops.GraphKeys.GLOBAL_VARIABLES)
            if from_proto and kind == "bytes_list":
                # Serialized protos: parse each entry and rebuild the Python
                # object through the registered from_proto function.
                proto_type = ops.get_collection_proto_type(key)
                if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
                    for value in col_def.bytes_list.value:
                        variable = variable_objects.get(value, None)
                        if variable is None:
                            proto = proto_type()
                            proto.ParseFromString(value)
                            if not variables_have_trainable:
                                # If the VariableDef proto does not contain a "trainable"
                                # property because it was exported before that property was
                                # added, we default it to whether the variable is in the
                                # TRAINABLE_VARIABLES collection. We've sorted
                                # TRAINABLE_VARIABLES to be first, so trainable variables will
                                # be created from that collection.
                                proto.trainable = (
                                    key == ops.GraphKeys.TRAINABLE_VARIABLES)
                            variable = from_proto(
                                proto, import_scope=scope_to_prepend_to_names)
                            variable_objects[value] = variable
                        graph.add_to_collection(key, variable)
                else:
                    for value in col_def.bytes_list.value:
                        proto = proto_type()
                        proto.ParseFromString(value)
                        graph.add_to_collection(
                            key,
                            from_proto(proto,
                                       import_scope=scope_to_prepend_to_names))
            else:
                # No from_proto function (or not a bytes_list): restore the
                # raw values according to the collection's kind.
                field = getattr(col_def, kind)
                if key in _COMPAT_COLLECTION_LIST:
                    logging.warning(
                        "The saved meta_graph is possibly from an older release:\n"
                        "'%s' collection should be of type 'byte_list', but instead "
                        "is of type '%s'.", key, kind)
                if kind == "node_list":
                    # Entries are graph element names; resolve them in the
                    # target graph after prefixing the import scope.
                    for value in field.value:
                        col_op = graph.as_graph_element(
                            ops.prepend_name_scope(value,
                                                   scope_to_prepend_to_names))
                        graph.add_to_collection(key, col_op)
                elif kind == "int64_list":
                    # NOTE(opensource): This force conversion is to work around the fact
                    # that Python2 distinguishes between int and long, while Python3 has
                    # only int.
                    for value in field.value:
                        graph.add_to_collection(key, int(value))
                else:
                    # Remaining kinds are stored as scope-prefixed values.
                    for value in field.value:
                        graph.add_to_collection(
                            key,
                            ops.prepend_name_scope(value,
                                                   scope_to_prepend_to_names))

        # Build the result: global variables under the import scope, keyed by
        # their names with that scope stripped.
        var_list = {}
        variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                         scope=scope_to_prepend_to_names)
        for v in variables:
            var_list[ops.strip_name_scope(v.name,
                                          scope_to_prepend_to_names)] = v

    return var_list, imported_return_elements
Beispiel #48
0
    def _TestTrtGraphConverter(self,
                               device,
                               output_saved_model_dir=None,
                               need_calibration=False,
                               is_dynamic_op=False):
        """Converts the test graph and verifies the resulting GraphDef(s).

        The GraphDef returned by `self._ConvertGraphV1` is always checked; when
        `output_saved_model_dir` is set, the GraphDef stored in that SavedModel
        is checked as well.

        Args:
          device: device string passed to the conversion, or None. A value
            starting with "/CPU:" means no TRTEngineOp is expected.
          output_saved_model_dir: optional directory of the exported SavedModel.
          need_calibration: whether calibration data must be present; when
            true the calibrated graph is also executed.
          is_dynamic_op: forwarded to `self._ConvertGraphV1`.
        """
        converted_graph_def = self._ConvertGraphV1(
            output_saved_model_dir=output_saved_model_dir,
            need_calibration=need_calibration,
            is_dynamic_op=is_dynamic_op,
            device=device)
        graph_defs_to_verify = [converted_graph_def]

        if output_saved_model_dir:
            exported_meta_graph = saved_model_utils.get_meta_graph_def(
                output_saved_model_dir, tag_constants.SERVING)
            exported_graph_def = exported_meta_graph.graph_def
            self.assertIsInstance(exported_graph_def, graph_pb2.GraphDef)
            graph_defs_to_verify.append(exported_graph_def)

        for graph_def in graph_defs_to_verify:
            actual_ops = {}
            for node in graph_def.node:
                actual_ops[self._MayRemoveGraphSequenceNumber(
                    node.name)] = node.op

            placed_on_cpu = device is not None and device.startswith("/CPU:")
            if placed_on_cpu:
                # Nothing is converted on CPU: the original ops remain.
                expected_ops = {
                    "add": "AddV2",
                    "v1": "Const",
                    "add_1": "AddV2",
                    "add_2": "AddV2",
                    "input1": "Placeholder",
                    "input2": "Placeholder",
                    "mul": "Mul",
                    "output": "Identity"
                }
            else:
                # The compute ops are replaced by a single TRTEngineOp.
                expected_ops = {
                    "input1": "Placeholder",
                    "input2": "Placeholder",
                    "TRTEngineOp_000": "TRTEngineOp",
                    "output": "Identity"
                }
            self.assertEqual(expected_ops, actual_ops)

            if not need_calibration:
                continue

            engine_nodes = []
            for node in graph_def.node:
                if node.op == "TRTEngineOp":
                    engine_nodes.append(node)
            if placed_on_cpu:
                self.assertEmpty(engine_nodes)
                return

            self.assertNotEmpty(engine_nodes)
            for engine_node in engine_nodes:
                self.assertTrue(len(engine_node.attr["calibration_data"].s))
            # Run the calibrated graph.
            # TODO(laigd): consider having some input where the answer is different.
            with ops.Graph().as_default():
                importer.import_graph_def(graph_def, name="")
                with self.session(config=self._GetConfigProto()) as sess:
                    for sample in range(10):
                        feeds = {
                            "input1:0": [[[sample]]],
                            "input2:0": [[[sample]]]
                        }
                        expected_value = (sample + 1.0)**2 + sample
                        self.assertEqual(
                            expected_value,
                            sess.run("output:0", feed_dict=feeds))
def quantize(saved_model_path: str,
             signature_keys=None,
             tags=None,
             output_directory=None,
             representative_dataset=None):
  """Quantizes the given SavedModel.

  QAT models are quantized directly. Other models go through post-training
  quantization (PTQ): a pre-calibration graph is exported to a temporary
  SavedModel, `representative_dataset` is run through it, the min/max values
  collected for each `CustomAggregator` node are written into the graph, and
  the calibrated graph is then quantized.

  Args:
    saved_model_path: Path to the saved model. When representative_dataset is
      not provided, this should be a model trained with QAT.
    signature_keys: List of keys identifying SignatureDef containing inputs and
      outputs. Defaults to the default serving signature key.
    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. Defaults to the serving tag.
    output_directory: The path to save the output SavedModel (must be an empty
      directory). When None, a new temporary directory is created.
    representative_dataset: a callable returning an iterable whose items are
      either a dictionary in {input_name: input_tensor} format, or a tuple of
      a signature key and a dictionary in {input_name: input_tensor} format,
      that feeds calibration data for quantizing the model. This should be
      provided when the model is not a QAT model.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when representative_dataset is not provided for a non-QAT
      model, when a sample from representative_dataset has an unsupported
      format, or when no valid signature can be resolved for the model.
  """
  if tags is None:
    tags = set([tag_constants.SERVING])
  if signature_keys is None:
    signature_keys = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

  is_qat_saved_model = _is_qat_saved_model(saved_model_path)
  signatures = _get_signatures_from_saved_model(saved_model_path,
                                                signature_keys, tags)

  # A non-QAT model cannot be quantized without calibration data.
  if representative_dataset is None and not is_qat_saved_model:
    raise ValueError(
        'When `representative_dataset` is not provided, the model should be '
        'trained with quantization-aware training (QAT).')

  if is_qat_saved_model:
    # QAT path: the model is quantized directly, no calibration is run.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_qat_model(saved_model_path,
                                                  ','.join(signature_keys),
                                                  ','.join(tags)))
  else:
    # PTQ path, step 1: obtain the pre-calibration graph.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_ptq_model_pre_calibration(
            saved_model_path, ','.join(signature_keys), ','.join(tags)))

    graph_def = graph_pb2.GraphDef()
    graph_def.ParseFromString(graph_def_serialized)

    # Step 2: export the pre-calibration graph as a temporary SavedModel so
    # it can be loaded and executed on the representative dataset.
    float_model_dir = tempfile.mkdtemp()
    v1_builder = builder.SavedModelBuilder(float_model_dir)

    with session.Session(graph=ops.Graph()) as sess:
      # Give every CustomAggregator node a fresh unique id; the same id is
      # used further down to query the calibrator for that node's statistics.
      for function_def in graph_def.library.function:
        for node_def in function_def.node_def:
          if node_def.op == 'CustomAggregator':
            node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii')

      importer.import_graph_def(graph_def, name='')
      working_graph = ops.get_default_graph()
      graph_def = working_graph.as_graph_def()

      signatures = _fix_tensor_names(signatures, working_graph)
      if signatures is None:
        raise ValueError(
            "The input SavedModel doesn't contain a valid signature")

      v1_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING], signature_def_map=signatures)

    v1_builder.save()

    float_model = saved_model_load(float_model_dir)

    # Step 3: run every calibration sample through the matching signature.
    # The results are discarded; the calls are made for their side effect of
    # feeding the calibrator (presumably via the CustomAggregator ops).
    for sample in representative_dataset():
      # TODO(b/214311251): Add a test case with multiple signatures.
      if isinstance(sample, tuple):
        if not isinstance(sample[1], dict):
          raise ValueError('You need to provide a dictionary with input '
                           'names and values in the second argument in the '
                           'tuple')
        signature_key = sample[0]
        input_data_map = sample[1]
      elif isinstance(sample, dict):
        # A bare dict is only unambiguous when there is a single signature.
        if len(signature_keys) > 1:
          raise ValueError('When the model has multiple signatures, you need '
                           'to provide a tuple with signature key and a '
                           'dictionary with input names and values')
        signature_key = signature_keys[0]
        input_data_map = sample
      else:
        raise ValueError('You need to provide either a dictionary with input '
                         'names and values or a tuple with signature key and a '
                         'dictionary with input names and values')
      func = float_model.signatures[signature_key]
      func(**input_data_map)

    # Step 4: copy the min/max reported by the calibrator into each
    # CustomAggregator node, then clear the calibrator's data for that id.
    for function_def in graph_def.library.function:
      for node_def in function_def.node_def:
        if node_def.op == 'CustomAggregator':
          node_id = node_def.attr['id'].s
          min_val = quantize_model_wrapper.get_min_from_calibrator(node_id)
          max_val = quantize_model_wrapper.get_max_from_calibrator(node_id)
          quantize_model_wrapper.clear_data_from_calibrator(node_id)
          node_def.attr['min'].f = float(min_val)
          node_def.attr['max'].f = float(max_val)

    # Step 5: export the calibrated graph as another temporary SavedModel.
    calibrated_model_dir = tempfile.mkdtemp()
    v1_builder = builder.SavedModelBuilder(calibrated_model_dir)

    with session.Session(graph=ops.Graph()) as sess:
      importer.import_graph_def(graph_def, name='')
      working_graph = ops.get_default_graph()
      graph_def = working_graph.as_graph_def()

      v1_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING], signature_def_map=signatures)

    v1_builder.save()
    # Re-read the signatures from the calibrated model; they are fixed up
    # against the final graph below.
    signatures = _get_signatures_from_saved_model(calibrated_model_dir,
                                                  signature_keys, tags)

    # Step 6: quantize the calibrated model.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_ptq_model_post_calibration(
            calibrated_model_dir,
            ','.join(signature_keys),
            ','.join(tags),
        ))

  # Both paths end here: export the quantized GraphDef as the output
  # SavedModel and load it back.
  graph_def = graph_pb2.GraphDef()
  graph_def.ParseFromString(graph_def_serialized)

  if output_directory is None:
    output_directory = tempfile.mkdtemp()
  v1_builder = builder.SavedModelBuilder(output_directory)

  with session.Session(graph=ops.Graph()) as sess:
    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()

    signatures = _fix_tensor_names(signatures, working_graph)
    if signatures is None:
      raise ValueError("The input SavedModel doesn't contain a valid signature")

    v1_builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=signatures)

  v1_builder.save()

  return saved_model_load(output_directory)
  def _testFreezeGraph(self, saver_write_version):
    """Freezes a one-variable graph and checks the frozen result.

    A graph holding a single variable (1.0) multiplied by 2 is checkpointed
    and written to disk, `freeze_graph.freeze_graph` is run on it, and the
    output graph is verified to contain no variable ops while still
    evaluating to 2.0.

    Args:
      saver_write_version: checkpoint format version, used both for the
        `Saver` that writes the checkpoint and as `checkpoint_version` for
        `freeze_graph`.
    """
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # We'll create an input graph that has a single variable containing 1.0,
    # and that then multiplies it by 2.
    with ops.Graph().as_default():
      variable_node = variables.VariableV1(1.0, name="variable_node")
      output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
      # Use the session as a context manager so it is closed once the
      # checkpoint and graph have been written (it was previously leaked).
      with session.Session() as sess:
        init = variables.global_variables_initializer()
        sess.run(init)
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)
        saver = saver_lib.Saver(write_version=saver_write_version)
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)
        graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

    # We save out the graph to disk, and then call the const conversion
    # routine.
    input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
    input_saver_def_path = ""
    input_binary = False
    output_node_names = "output_node"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
    clear_devices = False

    freeze_graph.freeze_graph(
        input_graph_path,
        input_saver_def_path,
        input_binary,
        checkpoint_path,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        output_graph_path,
        clear_devices,
        "",
        "",
        "",
        checkpoint_version=saver_write_version)

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      output_graph_def = graph_pb2.GraphDef()
      with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
      importer.import_graph_def(output_graph_def, name="")

      self.assertEqual(4, len(output_graph_def.node))
      for node in output_graph_def.node:
        self.assertNotEqual("VariableV2", node.op)
        self.assertNotEqual("Variable", node.op)

      with session.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)
def function_def_to_graph(fdef, input_shapes=None):
    """Builds a FuncGraph from a FunctionDef.

    The FunctionDef is first lowered to a GraphDef, which is imported into a
    new FuncGraph named after the function signature. The graph's `inputs`,
    `outputs` and `control_outputs` are then populated from the signature, and
    any `_output_shapes` attribute found on a node is applied to that op's
    output tensors.

    Args:
      fdef: FunctionDef.
      input_shapes: Optional. A list of TensorShape objects of the shapes of
        function inputs. Defaults to the function's "_input_shapes" attribute.
        If specified, its length must match length of
        `fdef.signature.input_arg`. If a shape is None, the corresponding
        input placeholder will have unknown shape.

    Returns:
      A FuncGraph.
    """
    func_graph = FuncGraph(fdef.signature.name)

    # Fall back to the shapes recorded on the function itself, if any.
    if input_shapes is None:
        recorded_shapes = fdef.attr.get("_input_shapes", None)
        if recorded_shapes is not None:
            input_shapes = recorded_shapes.list.shape

    graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
        fdef, input_shapes)

    with func_graph.as_default():
        # Materialize every node of the function body in the FuncGraph.
        importer.import_graph_def(graph_def, name="")

        # Inputs: resolve all flat tensor names first, then look them up.
        flat_input_names = []
        for input_arg in fdef.signature.input_arg:
            flat_input_names.append(nested_to_flat_tensor_name[input_arg.name])
        resolved_inputs = []
        for tensor_name in flat_input_names:
            resolved_inputs.append(func_graph.get_tensor_by_name(tensor_name))
        func_graph.inputs = resolved_inputs

        # Outputs: signature arg name -> nested name (fdef.ret) -> flat name.
        flat_output_names = []
        for output_arg in fdef.signature.output_arg:
            nested_name = fdef.ret[output_arg.name]
            flat_output_names.append(nested_to_flat_tensor_name[nested_name])
        resolved_outputs = []
        for tensor_name in flat_output_names:
            resolved_outputs.append(func_graph.get_tensor_by_name(tensor_name))
        func_graph.outputs = resolved_outputs

        # Control outputs are operations rather than tensors.
        control_ops = []
        for control_name in fdef.signature.control_output:
            control_ops.append(
                func_graph.get_operation_by_name(
                    fdef.control_ret[control_name]))
        func_graph.control_outputs = control_ops

        # Apply the shapes stored in each node's "_output_shapes" attribute.
        for node in graph_def.node:
            shapes_attr = node.attr.get("_output_shapes", None)
            if shapes_attr is None:
                continue
            op = func_graph.get_operation_by_name(node.name)
            for position, shape in enumerate(shapes_attr.list.shape):
                op.outputs[position].set_shape(shape)

    return func_graph
Beispiel #52
0
import os
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import importer

# Expose only the first GPU to TensorFlow.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Serialized GraphDef whose float operations are counted.
pb_path = 'mymodel.pb'

run_meta = tf.RunMetadata()
with tf.Graph().as_default() as graph:
    # Deserialize the GraphDef and import it, unprefixed, into the new graph.
    output_graph_def = graph_pb2.GraphDef()
    with open(pb_path, "rb") as pb_file:
        serialized_graph = pb_file.read()
    output_graph_def.ParseFromString(serialized_graph)
    importer.import_graph_def(output_graph_def, name="")
    print('model loaded!')

    # Sorted node names of the imported graph, handy when inspecting it.
    all_keys = sorted(node.name for node in graph.as_graph_def().node)

    with tf.Session() as sess:
        # Ask the profiler for the float-operation count of the graph.
        profile_options = tf.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.profiler.profile(
            graph, run_meta=run_meta, options=profile_options)
        print("test flops:{:,}".format(flops.total_float_ops))
Beispiel #53
0
    def from_frozen_graph(cls,
                          graph_def_file,
                          input_arrays,
                          output_arrays,
                          input_shapes=None):
        """Creates a TocoConverter class from a file containing a frozen GraphDef.

        The file content is first parsed as a binary GraphDef; if that fails
        it is parsed again as a text-format GraphDef.

        Args:
          graph_def_file: Full filepath of file containing TensorFlow GraphDef.
          input_arrays: List of input tensors to freeze graph with.
          output_arrays: List of output tensors to freeze graph with.
          input_shapes: Dict of strings representing input tensor names to
            list of integers representing input shapes (e.g.,
            {"foo" : [1, 16, 16, 3]}). Automatically determined when input
            shapes is None (e.g., {"foo" : None}). (default None)

        Returns:
          TocoConverter class.

        Raises:
          ValueError:
            Unable to parse input file.
            The graph is not frozen.
            input_arrays or output_arrays contains an invalid tensor name.
        """
        with _session.Session() as sess:
            sess.run(global_variables_initializer())

            # Read GraphDef from file.
            graph_def = _graph_pb2.GraphDef()
            with open(graph_def_file, "rb") as f:
                file_content = f.read()
            try:
                graph_def.ParseFromString(file_content)
            except (_text_format.ParseError, DecodeError):
                # Not a binary GraphDef; fall back to the text format.
                try:
                    print("Ignore 'tcmalloc: large alloc' warnings.")

                    if not isinstance(file_content, str):
                        if PY3:
                            file_content = file_content.decode('utf-8')
                        else:
                            file_content = file_content.encode('utf-8')
                    _text_format.Merge(file_content, graph_def)
                except (_text_format.ParseError, DecodeError,
                        UnicodeDecodeError):
                    # Bytes that are neither a binary GraphDef nor valid
                    # UTF-8 text are reported with the same documented error.
                    raise ValueError("Unable to parse input file '{}'.".format(
                        graph_def_file))

            # `as_default()` only takes effect as a context manager, so enter
            # it explicitly around the import.
            with sess.graph.as_default():
                import_graph_def(graph_def, name="")

            # Get input and output tensors.
            input_tensors = get_tensors_from_tensor_names(
                sess.graph, input_arrays)
            output_tensors = get_tensors_from_tensor_names(
                sess.graph, output_arrays)
            set_tensor_shapes(input_tensors, input_shapes)

            # Check if graph is frozen.
            if not _is_frozen_graph(sess):
                raise ValueError(
                    "Please freeze the graph using freeze_graph.py.")

            # Create TocoConverter class.
            return cls(sess.graph_def, input_tensors, output_tensors)
Beispiel #54
0
 def testInvalidInputForReturnOperations(self):
   """Non-string entries in `return_elements` raise a TypeError."""
   expected_message = "return_elements must be a list of strings."
   with ops.Graph().as_default():
     with self.assertRaises(TypeError) as raised:
       importer.import_graph_def(
           self._MakeGraphDef(""), return_elements=[7])
     self.assertEqual(expected_message, str(raised.exception))
Beispiel #55
0
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Variable values are loaded into a session from the first available source,
  in this order: `input_saver_def`, `input_meta_graph_def`,
  `input_saved_model_dir`, and finally the checkpoint's own variable list.

  Args:
    input_graph_def: GraphDef imported into the default graph when given, and
      the graph that is frozen when no meta graph is supplied.
    input_saver_def: SaverDef used to build the restoring Saver, if given.
    input_checkpoint: Checkpoint path (possibly a prefix) to restore from.
    output_node_names: Comma-separated names of the output nodes.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: Path the frozen GraphDef is written to; skipped when empty.
    clear_devices: Whether to blank the device field of every node.
    initializer_nodes: Comma-separated names of nodes to run after restoring
      (only on the meta-graph and plain-checkpoint paths).
    variable_names_whitelist: Comma-separated variable names passed on to the
      constant conversion, or empty for none.
    variable_names_blacklist: Comma-separated variable names passed on to the
      constant conversion, or empty for none.
    input_meta_graph_def: MetaGraphDef to import and freeze, if given.
    input_saved_model_dir: SavedModel directory to load, if given.
    saved_model_tags: Tags used when loading the SavedModel.
    checkpoint_version: Saver write version handed to the created Savers.

  Returns:
    The frozen GraphDef, or -1 when the checkpoint does not exist or no output
    node names were supplied.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  def _split_names(names):
    """Splits a comma-separated string into names, dropping all spaces."""
    return names.replace(" ", "").split(",")

  # With the Saver V2 format 'input_checkpoint' may be a prefix, so ask the
  # saver library whether it exists. A SavedModel input needs no checkpoint.
  checkpoint_needed = not input_saved_model_dir
  if checkpoint_needed and not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Dropping explicit device assignments makes the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      nodes_to_clear = input_meta_graph_def.graph_def.node
    elif input_graph_def:
      nodes_to_clear = input_graph_def.node
    else:
      nodes_to_clear = []
    for graph_node in nodes_to_clear:
      graph_node.device = ""

  if input_graph_def:
    importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver_def:
      saver_def_restorer = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver_def_restorer.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      meta_graph_restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      meta_graph_restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(_split_names(initializer_nodes))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      # Restore only the checkpoint variables that have a tensor in the graph.
      checkpoint_reader = pywrap_tensorflow.NewCheckpointReader(
          input_checkpoint)
      var_list = {}
      for var_name in checkpoint_reader.get_variable_to_shape_map():
        try:
          var_list[var_name] = sess.graph.get_tensor_by_name(var_name + ":0")
        except KeyError:
          # Not part of the graph (for example 'global_step' or a similar
          # housekeeping element), so leave it out.
          continue
      checkpoint_restorer = saver_lib.Saver(
          var_list=var_list, write_version=checkpoint_version)
      checkpoint_restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(_split_names(initializer_nodes))

    whitelist = None
    if variable_names_whitelist:
      whitelist = _split_names(variable_names_whitelist)
    blacklist = None
    if variable_names_blacklist:
      blacklist = _split_names(variable_names_blacklist)

    # A supplied meta graph takes precedence over the plain graph def.
    if input_meta_graph_def:
      graph_def_to_freeze = input_meta_graph_def.graph_def
    else:
      graph_def_to_freeze = input_graph_def
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def_to_freeze,
        _split_names(output_node_names),
        variable_names_whitelist=whitelist,
        variable_names_blacklist=blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as out_file:
      out_file.write(output_graph_def.SerializeToString())

  return output_graph_def
def main(unused_argv=None):
    """Stylizes live webcam frames with the currently active style.

    Opens capture device 1, requests the capture resolution, and hands the
    device to `_run_stylization_loop`. The camera and any OpenCV windows are
    released when the loop ends, including when it ends with an exception.

    Args:
      unused_argv: Ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.tensorrt:
        print(trt.trt_convert.get_linked_tensorrt_version())
        gpu_options = cpb2.GPUOptions(
            per_process_gpu_memory_fraction=_GPU_MEM_FRACTION)
        sessconfig = cpb2.ConfigProto(gpu_options=gpu_options)
    else:
        sessconfig = None

    # Instantiate video capture object.
    cap = cv2.VideoCapture(1)
    try:
        # Set resolution.
        x_length, y_length = (1024, 1280)
        cap.set(3, x_length)  # 3 and 4 are OpenCV property IDs.
        cap.set(4, y_length)
        cap.read()
        x_new = int(cap.get(3))
        y_new = int(cap.get(4))
        print('Resolution is: {0} by {1}'.format(x_new, y_new))

        _run_stylization_loop(cap, sessconfig)
    finally:
        # Always free the camera and close the display windows.
        cap.release()
        cv2.destroyAllWindows()


def _run_stylization_loop(cap, sessconfig):
    """Builds the model and stylizes frames from `cap` until asked to stop.

    Each iteration blends the bottleneck features of the active `Style` image
    with those of the current frame (the identity transform) using the
    interpolation weight, runs the transformer, and displays the result. With
    `FLAGS.tensorrt` set, the frozen transformer is converted with TensorRT
    and the converted graph performs the stylization.

    Keys: Esc (27) stops the loop, '<' (60) and '>' (62) lower and raise the
    interpolation weight by 0.25, key code 192 leaves full-screen display and
    key codes 233 / 193 enter it. The loop also stops when no frame can be
    read from the camera.

    Args:
      cap: Opened `cv2.VideoCapture` to read frames from.
      sessconfig: Session `ConfigProto`, or None for the default.
    """
    with tf.Graph().as_default(), tf.Session(config=sessconfig) as sess:

        # TODO - calculate these dimensions dynamically (they can't use None
        # since TensorRT needs precalculated dimensions).

        # Defines place holder for the style image.
        style_img_ph = tf.placeholder(tf.float32,
                                      shape=[200, 1200, 3],
                                      name="style_img_ph")
        if FLAGS.style_square_crop:
            style_img_preprocessed = image_utils.center_crop_resize_image(
                style_img_ph, FLAGS.style_image_size)
        else:
            style_img_preprocessed = image_utils.resize_image(
                style_img_ph, FLAGS.style_image_size)

        # Defines place holder for the content image.
        content_img_ph = tf.placeholder(tf.float32,
                                        shape=[200, 1200, 3],
                                        name="content_img_ph")
        if FLAGS.content_square_crop:
            content_img_preprocessed = image_utils.center_crop_resize_image(
                content_img_ph, FLAGS.image_size)
        else:
            content_img_preprocessed = image_utils.resize_image(
                content_img_ph, FLAGS.image_size)

        # Defines the model.
        stylized_images, _, _, bottleneck_feat = build_model.build_model(
            content_img_preprocessed,
            style_img_preprocessed,
            trainable=False,
            is_training=False,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        print(stylized_images)
        print(bottleneck_feat)

        if tf.gfile.IsDirectory(FLAGS.checkpoint):
            checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
        else:
            checkpoint = FLAGS.checkpoint
            tf.logging.info(
                'loading latest checkpoint file: {}'.format(checkpoint))

        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint, slim.get_variables_to_restore())
        sess.run([tf.local_variables_initializer()])
        init_fn(sess)

        tf.train.write_graph(sess.graph_def, '.', 'model.pbtxt')

        if FLAGS.tensorrt:
            # Freeze the variables into constants, keeping only the nodes
            # needed to compute the stylized output.
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess,
                tf.get_default_graph().as_graph_def(),
                ['transformer/expand/conv3/conv/Sigmoid'])

            trt_graph = trt.create_inference_graph(
                input_graph_def=output_graph_def,
                outputs=["transformer/expand/conv3/conv/Sigmoid"],
                max_workspace_size_bytes=5 << 30,
                max_batch_size=1,
                # TRT Engine precision "FP32","FP16" or "INT8"
                precision_mode="FP16",
                minimum_segment_size=10)

            # Import the converted graph and fetch the operations that are
            # fed and run during stylization.
            (bottleneck_feat_O, content_img_ph_O,
             stylized_images_O) = importer.import_graph_def(
                 graph_def=trt_graph,
                 return_elements=[
                     "Conv/BiasAdd", "content_img_ph",
                     "transformer/expand/conv3/conv/Sigmoid"
                 ])
            bottleneck_feat_O = bottleneck_feat_O.outputs[0]
            content_img_ph_O = content_img_ph_O.outputs[0]
            stylized_images_O = stylized_images_O.outputs[0]

            print("bottleneck opt:" + str(bottleneck_feat_O))
            print(content_img_ph_O)
            print(stylized_images_O)

        interpolation_weight = FLAGS.interpolation_weight
        activate_style = None
        # Bottleneck features of the last loaded style; None until a style
        # has been active at least once.
        style_params = None

        while True:
            start = timer()
            # Calculating style isn't the major FPS bottleneck.
            current_style = Style.objects.filter(is_active=True).first()
            # Reload only when a different style became active; when no style
            # is active the previously loaded one (if any) stays in use.
            if current_style is not None and activate_style != current_style:
                activate_style = current_style
                style_img_path = activate_style.source_file.path
                print("current image is " + style_img_path)
                style_image_np = image_utils.load_np_image_uint8(
                    style_img_path)[:, :, :3]
                style_image_np = cv2.resize(style_image_np, (1200, 200))

                # Computes bottleneck features of the style prediction
                # network for the given style image.
                style_params = sess.run(
                    bottleneck_feat, feed_dict={style_img_ph: style_image_np})

            ret, frame = cap.read()
            if not ret or frame is None:
                # Without a frame there is nothing to stylize.
                tf.logging.error('Could not read a frame from the webcam.')
                break
            print("webcam image: " + str(frame.shape))
            # Crop to get the weird 1200x200 format.
            content_img_np = frame[500:700, 80:1280]
            print("cropped image:" + str(content_img_np.shape))
            print("Input image:" + str(content_img_np.shape))

            # Computes bottleneck features of the style prediction network
            # for the identity transform.
            identity_params = sess.run(
                bottleneck_feat, feed_dict={style_img_ph: content_img_np})

            # Interpolates between the parameters of the identity transform
            # and style parameters of the given style image. With no style
            # loaded yet, only the identity parameters are available.
            wi = interpolation_weight
            if style_params is None:
                style_np = identity_params
            else:
                style_np = identity_params * (1 - wi) + style_params * wi

            if FLAGS.tensorrt:
                style_np = np.reshape(style_np, (1, 100, 1, 1))

                stylized_image_res = sess.run(stylized_images_O,
                                              feed_dict={
                                                  bottleneck_feat_O: style_np,
                                                  content_img_ph_O:
                                                  content_img_np
                                              })
            else:
                stylized_image_res = sess.run(stylized_images,
                                              feed_dict={
                                                  bottleneck_feat: style_np,
                                                  content_img_ph:
                                                  content_img_np
                                              })

            end = timer()
            print(end - start)
            print(stylized_image_res.shape)
            display_np_image(stylized_image_res, FLAGS.showFullScreen)
            print(stylized_image_res.shape)

            key = cv2.waitKey(10)
            print("Key " + str(key))
            if key == 27:
                break
            elif key == 192:
                FLAGS.showFullScreen = False
                cv2.setWindowProperty("window", cv2.WND_PROP_FULLSCREEN,
                                      cv2.WINDOW_NORMAL)
            elif (key == 233 or key == 193):
                FLAGS.showFullScreen = True
                cv2.setWindowProperty("window", cv2.WND_PROP_FULLSCREEN,
                                      cv2.WINDOW_FULLSCREEN)
            elif key == 60:  # < less
                interpolation_weight -= 0.25
            elif key == 62:  # > more
                interpolation_weight += 0.25
Beispiel #57
0
    def testStripUnusedMultipleInputs(self):
        """Strips a two-input graph and checks it still multiplies its inputs."""
        input_graph_name = "input_graph.pb"
        output_graph_name = "output_graph.pb"

        # We'll create an input graph that multiplies two input nodes.
        with ops.Graph().as_default():
            constant_node1 = constant_op.constant(1.0, name="constant_node1")
            constant_node2 = constant_op.constant(2.0, name="constant_node2")
            input_node1 = math_ops.sub(constant_node1, 3.0, name="input_node1")
            input_node2 = math_ops.sub(constant_node2, 5.0, name="input_node2")
            output_node = math_ops.multiply(input_node1,
                                            input_node2,
                                            name="output_node")
            math_ops.add(output_node, 2.0, name="later_node")
            # Use the session as a context manager so it is closed afterwards.
            with session.Session() as sess:
                output = sess.run(output_node)
                self.assertNear(6.0, output, 0.00001)
                graph_io.write_graph(sess.graph, self.get_temp_dir(),
                                     input_graph_name)

        # The graph has been saved to disk; now run the strip-unused routine
        # on the saved file.
        input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
        input_binary = False
        input_node_names = "input_node1,input_node2"
        input_node_types = [
            dtypes.float32.as_datatype_enum, dtypes.float32.as_datatype_enum
        ]
        output_binary = True
        output_node_names = "output_node"
        output_graph_path = os.path.join(self.get_temp_dir(),
                                         output_graph_name)

        strip_unused_lib.strip_unused_from_files(
            input_graph_path, input_binary, output_graph_path, output_binary,
            input_node_names, output_node_names, input_node_types)

        # Now we make sure the unused nodes were stripped, and that the graph
        # still produces the expected result.
        with ops.Graph().as_default():
            output_graph_def = graph_pb2.GraphDef()
            with open(output_graph_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
            importer.import_graph_def(output_graph_def, name="")

            self.assertEqual(3, len(output_graph_def.node))
            for node in output_graph_def.node:
                self.assertNotEqual("Add", node.op)
                self.assertNotEqual("Sub", node.op)
                # NOTE(review): `input_node_names` is the whole comma-separated
                # string, so this comparison never matches a single node name
                # and the "shape" check never runs. Confirm the intended
                # per-input check before changing it.
                if node.name == input_node_names:
                    self.assertTrue("shape" in node.attr)

            with session.Session() as sess:
                input_node1 = sess.graph.get_tensor_by_name("input_node1:0")
                input_node2 = sess.graph.get_tensor_by_name("input_node2:0")
                output_node = sess.graph.get_tensor_by_name("output_node:0")
                output = sess.run(output_node,
                                  feed_dict={
                                      input_node1: [10.0],
                                      input_node2: [-5.0]
                                  })
                self.assertNear(-50.0, output, 0.00001)
Beispiel #58
0
 def testEmptyGraph(self):
   """Importing an empty GraphDef leaves the graph version unchanged."""
   with ops.Graph().as_default() as graph:
     version_before_import = graph.version
     importer.import_graph_def(self._MakeGraphDef(""))
     self.assertEqual(version_before_import, graph.version)
Beispiel #59
0
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20,
                           precision_mode=TrtPrecisionMode.FP32,
                           minimum_segment_size=3,
                           is_dynamic_op=False,
                           maximum_cached_engines=1,
                           cached_engine_batches=None,
                           use_calibration=True,
                           input_saved_model_dir=None,
                           input_saved_model_tags=None,
                           output_saved_model_dir=None,
                           session_config=None):
    """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: a GraphDef object containing a model to be transformed. If
      set to None, the graph will be read from the SavedModel loaded from
      input_saved_model_dir.
    outputs: list of tensors or node names for the model outputs. Only used when
      input_graph_def is not None.
    max_batch_size: max size for the input batch.
    max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
      engine can use at execution time. This corresponds to the 'workspaceSize'
      parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
    precision_mode: one of TrtPrecisionMode.supported_precision_modes().
    minimum_segment_size: the minimum number of nodes required for a subgraph to
      be replaced by TRTEngineOp.
    is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
      network and engine at run time.
    maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
      If the number of cached engines is already at max but none of them can
      serve the input, the TRTEngineOp will fall back to run the TF function
      based on which the TRTEngineOp is created.
    cached_engine_batches: a list of batch sizes used to create cached
      engines, only used when is_dynamic_op is True. The length of the list
      should be <= maximum_cached_engines, and the dynamic TRT op will
      use this list to determine the batch sizes of the cached engines, instead
      of making the decision on the fly. This is useful when we know the most
      common batch size(s) the application is going to generate.
    use_calibration: this argument is ignored if precision_mode is not INT8. If
      set to True, a calibration graph will be created to calibrate the missing
      ranges. The calibration graph must be converted to an inference graph
      using calib_graph_to_infer_graph() after running calibration. If set to
      False, quantization nodes will be expected for every tensor in the graph
      (excluding those which will be fused). If a range is missing, an error
      will occur. Please note that accuracy may be negatively affected if there
      is a mismatch between which tensors TRT quantizes and which tensors were
      trained with fake quantization.
    input_saved_model_dir: the directory to load the SavedModel which contains
      the input graph to transforms. Used only when input_graph_def is None.
    input_saved_model_tags: list of tags to load the SavedModel.
    output_saved_model_dir: if not None, construct a SavedModel using the
      returned GraphDef and save it to the specified directory. This option only
      works when the input graph is loaded from a SavedModel, i.e. when
      input_saved_model_dir is specified and input_graph_def is None.
    session_config: the ConfigProto used to create a Session. It's also used as
      a template to create a TRT-enabled ConfigProto for conversion. If not
      specified, a default ConfigProto will be used.

  Returns:
    A GraphDef transformed from input_graph_def (or the SavedModel graph def
    loaded from input_saved_model_dir, if input_graph_def is not present), where
    all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
    function is added for each of the subgraphs.

    If is_dynamic_op is True, each TRTEngineOp will contain a serialized
    subgraph GraphDef, which will be converted to a TRT engine at execution time
    and the TRT engine will be cached for future usage. A new TRT engine will be
    created each time when none of the cached engines match the input shapes. If
    it fails to execute the TRT engine or the number of cached engines reaches
    maximum_cached_engines, the op will fall back to call the corresponding TF
    function.

    If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
    engine created from the corresponding subgraph. No more engines will be
    created on the fly, and the op will fall back to call the corresponding TF
    function when it fails to execute the engine.

  Raises:
    ValueError: if the combination of the parameters is invalid.
    RuntimeError: if the TensorRT library version is incompatible.
  """
    compiled_version = get_linked_tensorrt_version()
    loaded_version = get_loaded_tensorrt_version()
    version_mismatch = False
    if loaded_version[0] < compiled_version[0]:
        tf_logging.error(
            "TensorRT version mismatch. Tensorflow was compiled against " +
            "TensorRT %s but library loaded from environment is TensorRT %s" %
            (".".join([str(x) for x in compiled_version]),
             ".".join([str(x) for x in loaded_version])) +
            ". Please make sure that correct version of TensorRT " +
            "is available in the system and added to ldconfig or LD_LIBRARY_PATH"
        )
        raise RuntimeError("Incompatible TensorRT library version")
    for i in zip(loaded_version, compiled_version):
        if i[0] != i[1]:
            tf_logging.warn("TensorRT mismatch. Compiled against version " +
                            "%s, but loaded %s. Things may not work" %
                            (".".join([str(x) for x in compiled_version]),
                             ".".join([str(x) for x in loaded_version])))
            version_mismatch = True
            break
    if not version_mismatch:
        tf_logging.info("Running against TensorRT version %s" %
                        ".".join([str(x) for x in loaded_version]))

    if session_config is None:
        session_config = config_pb2.ConfigProto()

    if input_saved_model_tags is None:
        input_saved_model_tags = [tag_constants.SERVING]
    saved_model_loader = None
    grappler_meta_graph_def = None

    if input_graph_def is None:
        # Read from SavedModel and freeze the graph if necessary.
        if input_saved_model_dir is None:
            raise ValueError(
                "input_graph_def and input_saved_model_dir cannot be "
                "both None")
        with ops.Graph().as_default():
            with session.Session(config=session_config) as sess:
                saved_model_loader = loader_impl.SavedModelLoader(
                    input_saved_model_dir)
                input_meta_graph_def = saved_model_loader.load(
                    sess, input_saved_model_tags)
                output_node_names = set()

                def _gather_names(tensor_info):
                    """Get the node names from a TensorInfo."""
                    return set([
                        tensor_info[key].name.split(":")[0]
                        for key in tensor_info
                    ])

                # Get input and outputs from all SignatureDef.
                for key in input_meta_graph_def.signature_def:
                    signature_def = input_meta_graph_def.signature_def[key]
                    output_node_names.update(
                        _gather_names(signature_def.inputs))
                    output_node_names.update(
                        _gather_names(signature_def.outputs))

                # Freeze the variables in the SavedModel graph and copy the frozen
                # graph over.
                frozen_graph_def = graph_util.convert_variables_to_constants(
                    sess, sess.graph.as_graph_def(add_shapes=True),
                    list(output_node_names))
                grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
                grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)

                # Copy the collections that are not variables.
                for key in input_meta_graph_def.collection_def:
                    # TODO(laigd): currently we use the collection key to filter out
                    # collections that depend on variable ops, but this may miss some
                    # other user-defined collections. A better way would be to use
                    # CollectionDef::NodeList for the filtering.
                    if key not in [
                            "variables", "local_variables", "model_variables",
                            "trainable_variables", "train_op",
                            "table_initializer"
                    ]:
                        grappler_meta_graph_def.collection_def[key].CopyFrom(
                            input_meta_graph_def.collection_def[key])

                # Copy other information.
                grappler_meta_graph_def.meta_info_def.CopyFrom(
                    input_meta_graph_def.meta_info_def)
                for key in input_meta_graph_def.signature_def:
                    grappler_meta_graph_def.signature_def[key].CopyFrom(
                        input_meta_graph_def.signature_def[key])
                # TODO(laigd): maybe add back AssetFileDef.
    else:
        if output_saved_model_dir is not None:
            raise ValueError("output_saved_model_dir cannot be set when "
                             "input_graph_def is set")
        # Create MetaGraphDef from input graph.
        graph = ops.Graph()
        with graph.as_default():
            importer.import_graph_def(input_graph_def, name="")
        grappler_meta_graph_def = saver.export_meta_graph(
            graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
        if outputs:
            output_collection = meta_graph_pb2.CollectionDef()
            output_list = output_collection.node_list.value
            for i in outputs:
                if isinstance(i, ops.Tensor):
                    output_list.append(_to_bytes(i.name))
                else:
                    output_list.append(_to_bytes(i))
            # TODO(laigd): use another key as the outputs are really not train_op.
            grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
                output_collection)

    # Create TRT-enabled ConfigProto.
    session_config_with_trt = config_pb2.ConfigProto()
    session_config_with_trt.CopyFrom(session_config)
    rewriter_config = None
    if (session_config_with_trt.HasField("graph_options") and
            session_config_with_trt.graph_options.HasField("rewrite_options")):
        rewriter_config = session_config_with_trt.graph_options.rewrite_options
    rewriter_config_with_trt = get_tensorrt_rewriter_config(
        rewriter_config, max_batch_size, max_workspace_size_bytes,
        precision_mode, minimum_segment_size, is_dynamic_op,
        maximum_cached_engines, cached_engine_batches, use_calibration)
    session_config_with_trt.graph_options.rewrite_options.CopyFrom(
        rewriter_config_with_trt)

    # Run Grappler.
    transformed_graph_def = tf_optimizer.OptimizeGraph(session_config_with_trt,
                                                       grappler_meta_graph_def,
                                                       graph_id=b"tf_graph")

    # Optionally write the transformed graphdef as SavedModel.
    if output_saved_model_dir is not None:
        saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
        with ops.Graph().as_default():
            importer.import_graph_def(transformed_graph_def, name="")
            # We don't use TRT here.
            with session.Session(config=session_config) as sess:
                saved_model_builder.add_meta_graph_and_variables(
                    sess,
                    input_saved_model_tags,
                    signature_def_map=grappler_meta_graph_def.signature_def)
        # Ignore other meta graphs from the input SavedModel.
        saved_model_builder.save()

    return transformed_graph_def
Beispiel #60
0
  def testMultipleImport(self):
    """Checks the op names produced by successive imports into one graph.

    The same two-node graph def is imported several times, followed by
    imports whose node names collide with previously de-duplicated names and
    with existing name scopes. After each import the resulting op names and
    the A -> B input wiring are asserted.
    """
    ab_graph_def = self._MakeGraphDef("""
    node { name: 'A' op: 'IntOutput' }
    node { name: 'B' op: 'IntInput' input: 'A:0' }
    """)

    with ops.Graph().as_default():
      # The first import keeps the names as-is; importing the identical graph
      # def a second and third time is expected to append "_1" and "_2".
      for expected_suffix in ("", "_1", "_2"):
        op_a, op_b = importer.import_graph_def(
            ab_graph_def, return_elements=["A", "B"], name="")
        self.assertEqual(op_a.name, "A" + expected_suffix)
        self.assertEqual(op_b.name, "B" + expected_suffix)
        self.assertEqual(list(op_b.inputs), [op_a.outputs[0]])

      # Using an already-taken name ("A") as the import prefix.
      scoped_a, scoped_b = importer.import_graph_def(
          ab_graph_def, return_elements=["A", "B"], name="A")
      self.assertEqual(scoped_a.name, "A_3/A")
      self.assertEqual(scoped_b.name, "A_3/B")
      self.assertEqual(list(scoped_b.inputs), [scoped_a.outputs[0]])

      # Node names that are themselves equal to previously de-duplicated
      # names ("A_1", "B_1").
      suffixed_graph_def = self._MakeGraphDef("""
          node { name: 'A_1' op: 'IntOutput' }
          node { name: 'B_1' op: 'IntInput' input: 'A_1:0' }
          """)
      resuffixed_a, resuffixed_b = importer.import_graph_def(
          suffixed_graph_def, return_elements=["A_1", "B_1"], name="")
      self.assertEqual(resuffixed_a.name, "A_1_1")
      self.assertEqual(resuffixed_b.name, "B_1_1")
      self.assertEqual(list(resuffixed_b.inputs), [resuffixed_a.outputs[0]])

      # Open a name scope called "foo", then import a node with that name.
      with ops.name_scope("foo"):
        constant_op.constant(1)
      foo_graph_def = self._MakeGraphDef("node { name: 'foo' op: 'IntOutput' }")
      (foo_op,) = importer.import_graph_def(
          foo_graph_def, return_elements=["foo"], name="")
      self.assertEqual(foo_op.name, "foo_1")

      # Imported node name can't conflict with intermediate name scope (but can
      # conflict with outer scope and full name scope).
      with ops.name_scope("outer"), ops.name_scope("inner"):
        nested_const = constant_op.constant(1, name="c")
        self.assertEqual(nested_const.op.name, "outer/inner/c")

      scope_clash_graph_def = self._MakeGraphDef(
          "node { name: 'outer' op: 'IntOutput' }"
          "node { name: 'inner' op: 'IntOutput' }"
          "node { name: 'c' op: 'IntOutput' }"
          "node { name: 'outer/inner' op: 'IntOutput' }"
          "node { name: 'outer/inner/c' op: 'IntOutput' }")
      (outer_op, inner_op, c_op, outer_inner_op,
       outer_inner_c_op) = importer.import_graph_def(
           scope_clash_graph_def,
           return_elements=["outer", "inner", "c", "outer/inner",
                            "outer/inner/c"],
           name="")
      # Each pair is (imported op, name it is expected to have received).
      for imported_op, expected_name in (
          (outer_op, "outer_1"),
          (inner_op, "inner"),
          (c_op, "c"),
          (outer_inner_op, "outer/inner_1"),
          (outer_inner_c_op, "outer/inner/c_1"),
      ):
        self.assertEqual(imported_op.name, expected_name)