def convert(self):
  """Converts a TensorFlow GraphDef based on instance variables.

  Returns:
    The converted data in serialized format. Either a TFLite Flatbuffer or a
    Graphviz graph depending on value in `output_format`.

  Raises:
    ValueError:
      Input shape is not specified.
      None value for dimension in input_tensor.
      Quantization input stats are not available for every input tensor.
  """
  # Checks dimensions in input tensor.
  for tensor in self._input_tensors:
    # A falsy shape (presumably a TensorShape of unknown rank) means the
    # caller never provided one.
    if not tensor.get_shape():
      raise ValueError("Provide an input shape for input array '{0}'.".format(
          tensor_name(tensor)))
    shape = tensor.get_shape().as_list()
    if None in shape[1:]:
      raise ValueError(
          "None is only supported in the 1st dimension. Tensor '{0}' has "
          "invalid shape '{1}'.".format(tensor_name(tensor), shape))
    # A rank-0 (scalar) input yields an empty list here; check for that
    # before indexing so scalars don't raise IndexError. An unknown leading
    # dimension is pinned to 1 via `_set_batch_size`.
    elif shape and shape[0] is None:
      self._set_batch_size(batch_size=1)

  # Get quantization stats. Ensures there is one stat per name if the stats
  # are specified.
  if self.quantized_input_stats:
    quantized_stats = []
    invalid_stats = []
    for tensor in self._input_tensors:
      name = tensor_name(tensor)
      if name in self.quantized_input_stats:
        quantized_stats.append(self.quantized_input_stats[name])
      else:
        invalid_stats.append(name)

    if invalid_stats:
      raise ValueError("Quantization input stats are not available for input "
                       "tensors '{0}'.".format(",".join(invalid_stats)))
  else:
    quantized_stats = None

  # Converts model.
  result = toco_convert(
      input_data=self._graph_def,
      input_tensors=self._input_tensors,
      output_tensors=self._output_tensors,
      inference_type=self.inference_type,
      inference_input_type=self.inference_input_type,
      input_format=constants.TENSORFLOW_GRAPHDEF,
      output_format=self.output_format,
      quantized_input_stats=quantized_stats,
      default_ranges_stats=self.default_ranges_stats,
      drop_control_dependency=self.drop_control_dependency,
      reorder_across_fake_quant=self.reorder_across_fake_quant,
      change_concat_input_ranges=self.change_concat_input_ranges,
      allow_custom_ops=self.allow_custom_ops,
      quantize_weights=self.quantize_weights)
  return result
def convert(self):
  """Converts a TensorFlow GraphDef based on instance variables.

  Returns:
    The converted data in serialized format. Either a TFLite Flatbuffer or a
    Graphviz graph depending on value in `output_format`.

  Raises:
    ValueError:
      Input shape is not specified.
      None value for dimension in input_tensor.
      Quantization input stats are not available for every input tensor.
  """
  # Checks dimensions in input tensor.
  for tensor in self._input_tensors:
    # A falsy shape (presumably a TensorShape of unknown rank) means the
    # caller never provided one.
    if not tensor.get_shape():
      raise ValueError("Provide an input shape for input array '{0}'.".format(
          tensor_name(tensor)))
    shape = tensor.get_shape().as_list()
    if None in shape[1:]:
      raise ValueError(
          "None is only supported in the 1st dimension. Tensor '{0}' has "
          "invalid shape '{1}'.".format(tensor_name(tensor), shape))
    # A rank-0 (scalar) input yields an empty list here; check for that
    # before indexing so scalars don't raise IndexError. An unknown leading
    # dimension is pinned to 1 via `_set_batch_size`.
    elif shape and shape[0] is None:
      self._set_batch_size(batch_size=1)

  # Get quantization stats. Ensures there is one stat per name if the stats
  # are specified.
  if self.quantized_input_stats:
    quantized_stats = []
    invalid_stats = []
    for tensor in self._input_tensors:
      name = tensor_name(tensor)
      if name in self.quantized_input_stats:
        quantized_stats.append(self.quantized_input_stats[name])
      else:
        invalid_stats.append(name)

    if invalid_stats:
      raise ValueError("Quantization input stats are not available for input "
                       "tensors '{0}'.".format(",".join(invalid_stats)))
  else:
    quantized_stats = None

  # Converts model.
  result = toco_convert(
      input_data=self._graph_def,
      input_tensors=self._input_tensors,
      output_tensors=self._output_tensors,
      inference_type=self.inference_type,
      inference_input_type=self.inference_input_type,
      input_format=constants.TENSORFLOW_GRAPHDEF,
      output_format=self.output_format,
      quantized_input_stats=quantized_stats,
      default_ranges_stats=self.default_ranges_stats,
      drop_control_dependency=self.drop_control_dependency,
      reorder_across_fake_quant=self.reorder_across_fake_quant,
      change_concat_input_ranges=self.change_concat_input_ranges,
      allow_custom_ops=self.allow_custom_ops)
  return result
def saved_model_to_frozen_graphdef(
    saved_model_dir,
    output_file_model,
    output_file_flags,
    input_arrays=None,
    input_shapes=None,
    output_arrays=None,
    tag_set=None,
    signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
    batch_size=1):
  """Freezes a SavedModel and writes the graph plus its ModelFlags to disk.

  The frozen GraphDef is serialized to `output_file_model`. A ModelFlags proto
  describing the input arrays (name and static shape) and output arrays
  (name) is serialized to `output_file_flags`.

  Args:
    saved_model_dir: SavedModel directory to convert.
    output_file_model: Full file path to save frozen graph.
    output_file_flags: Full file path to save ModelFlags.
    input_arrays: List of input tensors to freeze graph with. When None, the
      input arrays come from the SignatureDef. (default None)
    input_shapes: Map of input tensor names to lists of integers giving their
      shapes (e.g., {"foo": [1, 16, 16, 3]}). A shape of None (e.g.,
      {"foo": None}) means it is determined automatically. (default None)
    output_arrays: List of output tensors to freeze graph with. When None, the
      output arrays come from the SignatureDef. (default None)
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present. None defers to the
      default chosen by `_freeze_saved_model` (documented as "serve").
    signature_key: Key identifying SignatureDef containing inputs and outputs.
    batch_size: Batch size for the model. Replaces the first dimension of an
      input size array if undefined. (default 1)

  Returns:
    None.

  Raises:
    ValueError: Unable to convert to frozen graph.
  """
  frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model(
      saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set,
      signature_key, batch_size)

  # Build the ModelFlags that TOCO reads alongside the graph: every input gets
  # its name and integer dims, every output gets its name.
  model_flags = model_flags_pb2.ModelFlags()
  for tensor in in_tensors:
    array = model_flags.input_arrays.add()
    array.name = convert.tensor_name(tensor)
    array.shape.dims.extend(int(dim) for dim in tensor.get_shape())
  model_flags.output_arrays.extend(
      convert.tensor_name(tensor) for tensor in out_tensors)

  # The graph is written first, then the flags.
  _write_and_flush_file(output_file_model,
                        frozen_graph_def.SerializeToString())
  _write_and_flush_file(output_file_flags, model_flags.SerializeToString())
def saved_model_to_frozen_graphdef(
    saved_model_dir,
    output_file_model,
    output_file_flags,
    input_arrays=None,
    input_shapes=None,
    output_arrays=None,
    tag_set=None,
    signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
    batch_size=1):
  """Freezes a SavedModel and writes the graph plus its ModelFlags to disk.

  The frozen GraphDef is serialized to `output_file_model`. A ModelFlags proto
  describing the input arrays (name and static shape) and output arrays
  (name) is serialized to `output_file_flags`.

  Args:
    saved_model_dir: SavedModel directory to convert.
    output_file_model: Full file path to save frozen graph.
    output_file_flags: Full file path to save ModelFlags.
    input_arrays: List of input tensors to freeze graph with. When None, the
      input arrays come from the SignatureDef. (default None)
    input_shapes: Map of input tensor names to lists of integers giving their
      shapes (e.g., {"foo": [1, 16, 16, 3]}). A shape of None (e.g.,
      {"foo": None}) means it is determined automatically. (default None)
    output_arrays: List of output tensors to freeze graph with. When None, the
      output arrays come from the SignatureDef. (default None)
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present. None defers to the
      default chosen by `_freeze_saved_model` (documented as "serve").
    signature_key: Key identifying SignatureDef containing inputs and outputs.
    batch_size: Batch size for the model. Replaces the first dimension of an
      input size array if undefined. (default 1)

  Returns:
    None.

  Raises:
    ValueError: Unable to convert to frozen graph.
  """
  frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model(
      saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set,
      signature_key, batch_size)

  # Build the ModelFlags that TOCO reads alongside the graph: every input gets
  # its name and integer dims, every output gets its name.
  model_flags = model_flags_pb2.ModelFlags()
  for tensor in in_tensors:
    array = model_flags.input_arrays.add()
    array.name = convert.tensor_name(tensor)
    array.shape.dims.extend(int(dim) for dim in tensor.get_shape())
  model_flags.output_arrays.extend(
      convert.tensor_name(tensor) for tensor in out_tensors)

  # The graph is written first, then the flags.
  _write_and_flush_file(output_file_model,
                        frozen_graph_def.SerializeToString())
  _write_and_flush_file(output_file_flags, model_flags.SerializeToString())
def get_input_arrays(self): """Returns a list of the names of the input tensors. Returns: List of strings. """ return [tensor_name(tensor) for tensor in self._input_tensors]
def get_tensors_from_tensor_names(graph, tensor_names):
  """Looks up Tensors in `graph` by name.

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings that represent names of tensors in the graph.

  Returns:
    A list of Tensor objects, ordered like `tensor_names`.

  Raises:
    ValueError: tensor_names contains an invalid tensor name.
  """
  # Index every tensor produced by every op in the graph by its name.
  lookup = {}
  for op in graph.get_operations():
    for tensor in op.values():
      lookup[tensor_name(tensor)] = tensor

  # Resolve each requested name exactly once; unknown names map to None.
  resolved = [(name, lookup.get(name)) for name in tensor_names]

  missing = [name for name, tensor in resolved if tensor is None]
  if missing:
    raise ValueError("Invalid tensors '{}' were found.".format(
        ",".join(missing)))

  return [tensor for _, tensor in resolved]
def get_tensors_from_tensor_names(graph, tensor_names):
  """Looks up Tensors in `graph` by name.

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings that represent names of tensors in the graph.

  Returns:
    A list of Tensor objects, ordered like `tensor_names`.

  Raises:
    ValueError: tensor_names contains an invalid tensor name.
  """
  # Index every tensor produced by every op in the graph by its name.
  lookup = {}
  for op in graph.get_operations():
    for tensor in op.values():
      lookup[tensor_name(tensor)] = tensor

  # Resolve each requested name exactly once; unknown names map to None.
  resolved = [(name, lookup.get(name)) for name in tensor_names]

  missing = [name for name, tensor in resolved if tensor is None]
  if missing:
    raise ValueError("Invalid tensors '{}' were found.".format(
        ",".join(missing)))

  return [tensor for _, tensor in resolved]
def get_input_arrays(self): """Returns a list of the names of the input tensors. Returns: List of strings. """ return [tensor_name(tensor) for tensor in self._input_tensors]
def from_session(cls,
                 sess,
                 input_tensors,
                 output_tensors,
                 freeze_variables=False):
  """Builds a TocoConverter from the graph held by a TensorFlow Session.

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.get_shape()` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    freeze_variables: Boolean indicating whether the variables need to be
      converted into constants via the freeze_graph.py script.
      (default False)

  Returns:
    TocoConverter class.
  """
  if not freeze_variables:
    return cls(sess.graph_def, input_tensors, output_tensors)

  # NOTE(review): running the global initializer appears to reset every global
  # variable to its initial value before freezing, discarding any values set
  # earlier in `sess` (e.g. trained weights). Confirm this is intended.
  sess.run(global_variables_initializer())
  output_names = list(map(tensor_name, output_tensors))
  frozen_graph_def = tf_graph_util.convert_variables_to_constants(
      sess, sess.graph_def, output_names)
  return cls(frozen_graph_def, input_tensors, output_tensors)
def from_session(cls,
                 sess,
                 input_tensors,
                 output_tensors,
                 freeze_variables=False):
  """Builds a TocoConverter from the graph held by a TensorFlow Session.

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.get_shape()` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    freeze_variables: Boolean indicating whether the variables need to be
      converted into constants via the freeze_graph.py script.
      (default False)

  Returns:
    TocoConverter class.
  """
  if not freeze_variables:
    return cls(sess.graph_def, input_tensors, output_tensors)

  # NOTE(review): running the global initializer appears to reset every global
  # variable to its initial value before freezing, discarding any values set
  # earlier in `sess` (e.g. trained weights). Confirm this is intended.
  sess.run(global_variables_initializer())
  output_names = list(map(tensor_name, output_tensors))
  frozen_graph_def = tf_graph_util.convert_variables_to_constants(
      sess, sess.graph_def, output_names)
  return cls(frozen_graph_def, input_tensors, output_tensors)
def set_tensor_shapes(tensors, shapes):
  """Applies user-provided static shapes to the matching tensors.

  Tensors whose name is absent from `shapes`, or maps to None, are left
  untouched. Nothing happens when `shapes` is empty or None.

  Args:
    tensors: Iterable of TensorFlow ops.Tensor objects.
    shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
  """
  if not shapes:
    return
  for tensor in tensors:
    requested = shapes.get(tensor_name(tensor))
    if requested is None:
      continue
    tensor.set_shape(requested)
def set_tensor_shapes(tensors, shapes):
  """Applies user-provided static shapes to the matching tensors.

  Tensors whose name is absent from `shapes`, or maps to None, are left
  untouched. Nothing happens when `shapes` is empty or None.

  Args:
    tensors: Iterable of TensorFlow ops.Tensor objects.
    shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
  """
  if not shapes:
    return
  for tensor in tensors:
    requested = shapes.get(tensor_name(tensor))
    if requested is None:
      continue
    tensor.set_shape(requested)
def _freeze_graph(sess, output_tensors):
  """Returns the session's GraphDef with variables folded into constants.

  If `_is_frozen_graph(sess)` reports the graph as already frozen, the
  session's GraphDef is returned unchanged.

  Args:
    sess: TensorFlow Session.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  if _is_frozen_graph(sess):
    return sess.graph_def

  # NOTE(review): running the global initializer appears to reset every global
  # variable to its initial value before freezing, discarding any values set
  # earlier in `sess` (e.g. trained weights). Confirm this is intended.
  sess.run(global_variables_initializer())
  output_names = list(map(tensor_name, output_tensors))
  return tf_graph_util.convert_variables_to_constants(
      sess, sess.graph_def, output_names)
def _freeze_graph(sess, output_tensors):
  """Returns the session's GraphDef with variables folded into constants.

  If `_is_frozen_graph(sess)` reports the graph as already frozen, the
  session's GraphDef is returned unchanged.

  Args:
    sess: TensorFlow Session.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  if _is_frozen_graph(sess):
    return sess.graph_def

  # NOTE(review): running the global initializer appears to reset every global
  # variable to its initial value before freezing, discarding any values set
  # earlier in `sess` (e.g. trained weights). Confirm this is intended.
  sess.run(global_variables_initializer())
  output_names = list(map(tensor_name, output_tensors))
  return tf_graph_util.convert_variables_to_constants(
      sess, sess.graph_def, output_names)