def convert(self):
  """Converts a TensorFlow GraphDef based on instance variables.

  Returns:
    The converted data in serialized format. Either a TFLite Flatbuffer or a
    Graphviz graph depending on value in `output_format`.

  Raises:
    ValueError:
      Input shape is not specified.
      None value for dimension in input_tensor.
      Quantization input stats are missing for one or more input tensors.
  """
  # Checks dimensions in input tensor.
  for tensor in self._input_tensors:
    # A TensorShape of unknown rank is falsy; a scalar (rank-0) shape is not.
    if not tensor.get_shape():
      raise ValueError("Provide an input shape for input array '{0}'.".format(
          tensor_name(tensor)))
    shape = tensor.get_shape().as_list()
    if None in shape[1:]:
      raise ValueError(
          "None is only supported in the 1st dimension. Tensor '{0}' has "
          "invalid shape '{1}'.".format(tensor_name(tensor), shape))
    # `shape` is an empty list for scalar inputs, so guard before indexing
    # the batch dimension; otherwise shape[0] raises IndexError.
    elif shape and shape[0] is None:
      self._set_batch_size(batch_size=1)

  # Get quantization stats. Ensures there is one stat per name if the stats
  # are specified. The stats are looked up by tensor name and collected in
  # the same order as self._input_tensors.
  if self.quantized_input_stats:
    quantized_stats = []
    invalid_stats = []
    for tensor in self._input_tensors:
      name = tensor_name(tensor)
      if name in self.quantized_input_stats:
        quantized_stats.append(self.quantized_input_stats[name])
      else:
        invalid_stats.append(name)

    if invalid_stats:
      raise ValueError("Quantization input stats are not available for input "
                       "tensors '{0}'.".format(",".join(invalid_stats)))
  else:
    quantized_stats = None

  # Converts model.
  result = toco_convert(
      input_data=self._graph_def,
      input_tensors=self._input_tensors,
      output_tensors=self._output_tensors,
      inference_type=self.inference_type,
      inference_input_type=self.inference_input_type,
      input_format=constants.TENSORFLOW_GRAPHDEF,
      output_format=self.output_format,
      quantized_input_stats=quantized_stats,
      default_ranges_stats=self.default_ranges_stats,
      drop_control_dependency=self.drop_control_dependency,
      reorder_across_fake_quant=self.reorder_across_fake_quant,
      change_concat_input_ranges=self.change_concat_input_ranges,
      allow_custom_ops=self.allow_custom_ops,
      quantize_weights=self.quantize_weights)
  return result
def convert(self):
  """Converts a TensorFlow GraphDef based on instance variables.

  Returns:
    The converted data in serialized format. Either a TFLite Flatbuffer or a
    Graphviz graph depending on value in `output_format`.

  Raises:
    ValueError:
      Input shape is not specified.
      None value for dimension in input_tensor.
  """
  # Checks dimensions in input tensor.
  for tensor in self._input_tensors:
    # Reject tensors whose shape is entirely unknown up front with a clear
    # message (a TensorShape of unknown rank is falsy) instead of letting
    # as_list() fail below.
    if not tensor.get_shape():
      raise ValueError("Provide an input shape for input array '{0}'.".format(
          tensor.name))
    shape = tensor.get_shape().as_list()
    if None in shape[1:]:
      raise ValueError(
          "None is only supported in the 1st dimension. Tensor '{0}' has "
          "invalid shape '{1}'.".format(tensor.name, shape))
    # `shape` is an empty list for scalar inputs, so guard before indexing
    # the batch dimension; otherwise shape[0] raises IndexError.
    elif shape and shape[0] is None:
      self._set_batch_size(batch_size=1)

  # Converts model.
  result = toco_convert(
      input_data=self._graph_def,
      input_tensors=self._input_tensors,
      output_tensors=self._output_tensors,
      inference_type=self.inference_type,
      input_format=constants.TENSORFLOW_GRAPHDEF,
      output_format=self.output_format,
      quantized_input_stats=self.quantized_input_stats,
      drop_control_dependency=self.drop_control_dependency)
  return result
def testBasic(self):
  """Checks that a simple float graph converts to a non-empty TFLite model."""
  in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.float32)
  out_tensor = in_tensor + in_tensor
  # Use the session as a context manager so it is closed (and its resources
  # released) once the GraphDef has been converted.
  with session.Session() as sess:
    # Try running on valid graph
    tflite_model = convert.toco_convert(sess.graph_def, [in_tensor],
                                        [out_tensor])
  self.assertTrue(tflite_model)
def testQuantization(self):
  """Checks that a fake-quantized graph converts with QUANTIZED_UINT8."""
  in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.float32)
  out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor,
                                                      min=0., max=1.)
  # Use the session as a context manager so it is closed (and its resources
  # released) once the GraphDef has been converted.
  with session.Session() as sess:
    tflite_model = convert.toco_convert(
        sess.graph_def, [in_tensor], [out_tensor],
        inference_type=lite_constants.QUANTIZED_UINT8,
        # One (mean, std) pair per input tensor.
        quantized_input_stats=[(0., 1.)])
  self.assertTrue(tflite_model)
def convert(self):
  """Converts a TensorFlow GraphDef based on instance variables.

  Returns:
    The converted data in serialized format. Either a TFLite Flatbuffer or a
    Graphviz graph depending on value in `output_format`.

  Raises:
    ValueError:
      Input shape is not specified.
      None value for dimension in input_tensor.
      Quantization input stats are missing for one or more input tensors.
  """
  # Checks dimensions in input tensor.
  for tensor in self._input_tensors:
    # A TensorShape of unknown rank is falsy; a scalar (rank-0) shape is not.
    if not tensor.get_shape():
      raise ValueError(
          "Provide an input shape for input array '{0}'.".format(
              tensor_name(tensor)))
    shape = tensor.get_shape().as_list()
    if None in shape[1:]:
      raise ValueError(
          "None is only supported in the 1st dimension. Tensor '{0}' has "
          "invalid shape '{1}'.".format(tensor_name(tensor), shape))
    # `shape` is an empty list for scalar inputs, so guard before indexing
    # the batch dimension; otherwise shape[0] raises IndexError.
    elif shape and shape[0] is None:
      self._set_batch_size(batch_size=1)

  # Get quantization stats. Ensures there is one stat per name if the stats
  # are specified. The stats are looked up by tensor name and collected in
  # the same order as self._input_tensors.
  if self.quantized_input_stats:
    quantized_stats = []
    invalid_stats = []
    for tensor in self._input_tensors:
      name = tensor_name(tensor)
      if name in self.quantized_input_stats:
        quantized_stats.append(self.quantized_input_stats[name])
      else:
        invalid_stats.append(name)

    if invalid_stats:
      raise ValueError(
          "Quantization input stats are not available for input "
          "tensors '{0}'.".format(",".join(invalid_stats)))
  else:
    quantized_stats = None

  # Converts model.
  result = toco_convert(
      input_data=self._graph_def,
      input_tensors=self._input_tensors,
      output_tensors=self._output_tensors,
      inference_type=self.inference_type,
      inference_input_type=self.inference_input_type,
      input_format=constants.TENSORFLOW_GRAPHDEF,
      output_format=self.output_format,
      quantized_input_stats=quantized_stats,
      default_ranges_stats=self.default_ranges_stats,
      drop_control_dependency=self.drop_control_dependency,
      reorder_across_fake_quant=self.reorder_across_fake_quant,
      change_concat_input_ranges=self.change_concat_input_ranges,
      allow_custom_ops=self.allow_custom_ops)
  return result
def tflite_from_saved_model(
    saved_model_dir,
    output_file=None,
    input_arrays=None,
    input_shapes=None,
    output_arrays=None,
    tag_set=None,
    signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
    batch_size=1,
    inference_type=lite_constants.FLOAT,
    input_format=lite_constants.TENSORFLOW_GRAPHDEF,
    output_format=lite_constants.TFLITE,
    quantized_input_stats=None,
    drop_control_dependency=True):
  """Converts a SavedModel to TFLite FlatBuffer.

  Args:
    saved_model_dir: SavedModel directory to convert.
    output_file: File path to write result TFLite FlatBuffer. Nothing is
      written when None. (default None)
    input_arrays: List of input tensors to freeze graph with. Uses input arrays
      from SignatureDef when none are provided. (default None)
    input_shapes: Map of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g.,
      {"foo": None}). (default None)
    output_arrays: List of output tensors to freeze graph with. Uses output
      arrays from SignatureDef when none are provided. (default None)
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present. (default None; the
      effective tag set is resolved when freezing the SavedModel and is
      documented as "serve")
    signature_key: Key identifying SignatureDef containing inputs and outputs.
    batch_size: Batch size for the model. Replaces the first dimension of an
      input size array if undefined. (default 1)
    inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
    input_format: Type of data to read (currently must be
      TENSORFLOW_GRAPHDEF).
    output_format: Type of data to write (currently must be TFLITE or
      GRAPHVIZ_DOT).
    quantized_input_stats: For each member of input_tensors the mean and std
      deviation of training data. Only needed if `inference_type` is
      `QUANTIZED_UINT8`.
    drop_control_dependency: Drops control dependencies silently. This is due
      to TFLite not supporting control dependencies.

  Returns:
    The converted data. For example if TFLite was the destination, then this
    will be a TFLite flatbuffer in a bytes array.

  Raises:
    ValueError: Unable to convert to frozen graph.
  """
  # Freeze the SavedModel into a GraphDef plus the resolved input/output
  # tensors; arguments are passed positionally in the helper's order.
  frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model(
      saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set,
      signature_key, batch_size)

  result = convert.toco_convert(
      input_data=frozen_graph_def,
      input_tensors=in_tensors,
      output_tensors=out_tensors,
      inference_type=inference_type,
      input_format=input_format,
      output_format=output_format,
      quantized_input_stats=quantized_input_stats,
      drop_control_dependency=drop_control_dependency)

  # Optionally persist the serialized result; opened in binary mode because
  # the converted data is written as raw bytes.
  if output_file is not None:
    with gfile.Open(output_file, "wb") as f:
      f.write(result)
    logging.info("Successfully converted to: %s", output_file)

  return result