def testFreezeGraph(self):
  """Converts a session graph that contains a variable via TFLiteConverter.

  Builds `placeholder + variable`, converts it with
  `lite.TFLiteConverter.from_session`, loads the result into an Interpreter
  and checks the reported input and output tensor details.
  """
  tensor_shape = [1, 16, 16, 3]

  def check_single_tensor(details, expected_name):
    # Expect exactly one float32 tensor with the given name, the shared
    # shape and a (0., 0.) quantization entry.
    self.assertEqual(1, len(details))
    tensor = details[0]
    self.assertEqual(expected_name, tensor['name'])
    self.assertEqual(np.float32, tensor['dtype'])
    self.assertTrue((tensor_shape == tensor['shape']).all())
    self.assertEqual((0., 0.), tensor['quantization'])

  # Ops are created in the same order as before (placeholder, variable, add)
  # so the names asserted below are unchanged.
  in_tensor = array_ops.placeholder(shape=tensor_shape, dtype=dtypes.float32)
  var = variable_scope.get_variable(
      'weights', shape=tensor_shape, dtype=dtypes.float32)
  out_tensor = in_tensor + var

  sess = session.Session()
  sess.run(_global_variables_initializer())

  # Convert model and ensure model is not None.
  converter = lite.TFLiteConverter.from_session(
      sess, [in_tensor], [out_tensor])
  tflite_model = converter.convert()
  self.assertTrue(tflite_model)

  # Check values from converted model.
  interpreter = Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()

  check_single_tensor(interpreter.get_input_details(), 'Placeholder')
  check_single_tensor(interpreter.get_output_details(), 'add')
def testFreezeGraph(self):
  """Converts a session graph that contains a variable via TocoConverter.

  The graph is `placeholder + variable`. After conversion with
  `lite.TocoConverter.from_session`, the model is loaded into an Interpreter
  and its single input and single output tensor are inspected.
  """
  expected_shape = [1, 16, 16, 3]

  in_tensor = array_ops.placeholder(
      shape=expected_shape, dtype=dtypes.float32)
  var = variable_scope.get_variable(
      'weights', shape=expected_shape, dtype=dtypes.float32)
  out_tensor = in_tensor + var

  sess = session.Session()
  sess.run(_global_variables_initializer())

  # Convert model and ensure model is not None.
  converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
  tflite_model = converter.convert()
  self.assertTrue(tflite_model)

  # Check values from converted model.
  interpreter = Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()

  # Input side: one float32 tensor named after the placeholder op.
  input_details = interpreter.get_input_details()
  self.assertEqual(1, len(input_details))
  first_input = input_details[0]
  self.assertEqual('Placeholder', first_input['name'])
  self.assertEqual(np.float32, first_input['dtype'])
  self.assertTrue((expected_shape == first_input['shape']).all())
  self.assertEqual((0., 0.), first_input['quantization'])

  # Output side: one float32 tensor named after the add op.
  output_details = interpreter.get_output_details()
  self.assertEqual(1, len(output_details))
  first_output = output_details[0]
  self.assertEqual('add', first_output['name'])
  self.assertEqual(np.float32, first_output['dtype'])
  self.assertTrue((expected_shape == first_output['shape']).all())
  self.assertEqual((0., 0.), first_output['quantization'])
def _freeze_graph(sess, output_tensors):
  """Produces a frozen GraphDef for the session's graph.

  When the session's graph is already frozen its GraphDef is handed back
  unchanged. Otherwise the global variables initializer is run in the session
  and the variables are converted to constants.

  Args:
    sess: TensorFlow Session.
    output_tensors: List of output tensors; each one is mapped to a name with
      `_tensor_name` and those names are passed to the conversion.

  Returns:
    Frozen GraphDef.
  """
  # Nothing to do for a graph that is already frozen.
  if _is_frozen_graph(sess):
    return sess.graph_def

  # NOTE(review): the initializer is run unconditionally before freezing;
  # presumably this resets variables that already hold values in `sess` --
  # confirm this is intended for sessions with trained weights.
  sess.run(_global_variables_initializer())

  output_names = [_tensor_name(out) for out in output_tensors]
  frozen_graph_def = _tf_graph_util.convert_variables_to_constants(
      sess, sess.graph_def, output_names)
  return frozen_graph_def
def from_frozen_graph(cls,
                      graph_def_file,
                      input_arrays,
                      output_arrays,
                      input_shapes=None):
  """Creates a TocoConverter class from a file containing a frozen GraphDef.

  The file is first parsed as a binary serialized GraphDef; if that fails it
  is parsed as a text-format GraphDef. The parsed graph is imported into the
  session's graph, the requested input/output tensors are looked up by name,
  and the graph is verified to be frozen.

  Args:
    graph_def_file: Full filepath of file containing frozen GraphDef.
    input_arrays: List of input tensors to freeze graph with.
    output_arrays: List of output tensors to freeze graph with.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g.,
      {"foo" : None}). (default None)

  Returns:
    TocoConverter class.

  Raises:
    ValueError:
      Unable to parse input file.
      The graph is not frozen.
      input_arrays or output_arrays contains an invalid tensor name.
  """
  with _session.Session() as sess:
    sess.run(_global_variables_initializer())

    # Read GraphDef from file.
    graph_def = _graph_pb2.GraphDef()
    with open(graph_def_file, "rb") as f:
      file_content = f.read()

    try:
      # First attempt: binary serialized proto.
      graph_def.ParseFromString(file_content)
    except (_text_format.ParseError, DecodeError):
      # Second attempt: text-format proto.
      try:
        print("Ignore 'tcmalloc: large alloc' warnings.")

        if not isinstance(file_content, str):
          if PY3:
            file_content = file_content.decode('utf-8')
          else:
            file_content = file_content.encode('utf-8')
        _text_format.Merge(file_content, graph_def)
      except (_text_format.ParseError, DecodeError, UnicodeDecodeError):
        # UnicodeDecodeError covers content that is neither a valid binary
        # proto nor valid UTF-8 text, so that case also surfaces as the
        # documented "unable to parse" ValueError.
        raise ValueError(
            "Unable to parse input file '{}'.".format(graph_def_file))

    # `as_default()` returns a context manager; it must be entered with
    # `with` to have any effect. Entering it makes the import target the
    # session's graph explicitly.
    with sess.graph.as_default():
      _import_graph_def(graph_def, name="")

    # Get input and output tensors.
    input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
    output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
    _set_tensor_shapes(input_tensors, input_shapes)

    # Check if graph is frozen.
    if not _is_frozen_graph(sess):
      raise ValueError("Please freeze the graph using freeze_graph.py.")

    # Create TocoConverter class.
    return cls(sess.graph_def, input_tensors, output_tensors)