def _convert_tf1_model(flags): """Calls function to convert the TensorFlow 1.X model into a TFLite model. Args: flags: argparse.Namespace object. Raises: ValueError: Invalid flags. """ # Register custom opdefs before converter object creation. if flags.custom_opdefs: register_custom_opdefs(_parse_array(flags.custom_opdefs)) # Create converter. converter = _get_tflite_converter(flags) if flags.inference_type: converter.inference_type = _parse_inference_type( flags.inference_type, "inference_type") if flags.inference_input_type: converter.inference_input_type = _parse_inference_type( flags.inference_input_type, "inference_input_type") if flags.output_format: converter.output_format = _toco_flags_pb2.FileFormat.Value( flags.output_format) if flags.mean_values and flags.std_dev_values: input_arrays = converter.get_input_arrays() std_dev_values = _parse_array(flags.std_dev_values, type_fn=float) # In quantized inference, mean_value has to be integer so that the real # value 0.0 is exactly representable. if converter.inference_type == dtypes.float32: mean_values = _parse_array(flags.mean_values, type_fn=float) else: mean_values = _parse_array(flags.mean_values, type_fn=int) quant_stats = list(zip(mean_values, std_dev_values)) if ((not flags.input_arrays and len(input_arrays) > 1) or (len(input_arrays) != len(quant_stats))): raise ValueError( "Mismatching --input_arrays, --std_dev_values, and " "--mean_values. The flags must have the same number of " "items. The current input arrays are '{0}'. " "--input_arrays must be present when specifying " "--std_dev_values and --mean_values with multiple input " "tensors in order to map between names and " "values.".format(",".join(input_arrays))) converter.quantized_input_stats = dict( list(zip(input_arrays, quant_stats))) if (flags.default_ranges_min is not None) and (flags.default_ranges_max is not None): converter.default_ranges_stats = (flags.default_ranges_min, flags.default_ranges_max) if flags.drop_control_dependency: converter.drop_control_dependency = flags.drop_control_dependency if flags.reorder_across_fake_quant: converter.reorder_across_fake_quant = flags.reorder_across_fake_quant if flags.change_concat_input_ranges: converter.change_concat_input_ranges = ( flags.change_concat_input_ranges == "TRUE") if flags.allow_custom_ops: converter.allow_custom_ops = flags.allow_custom_ops if flags.target_ops: ops_set_options = lite.OpsSet.get_options() converter.target_spec.supported_ops = set() for option in six.ensure_str(flags.target_ops).split(","): if option not in ops_set_options: raise ValueError("Invalid value for --target_ops. Options: " "{0}".format(",".join(ops_set_options))) converter.target_spec.supported_ops.add(lite.OpsSet(option)) if flags.experimental_select_user_tf_ops: if lite.OpsSet.SELECT_TF_OPS not in converter.target_spec.supported_ops: raise ValueError( "--experimental_select_user_tf_ops can only be set if " "--target_ops contains SELECT_TF_OPS.") user_op_set = set() for op_name in six.ensure_str( flags.experimental_select_user_tf_ops).split(","): user_op_set.add(op_name) converter.target_spec.experimental_select_user_tf_ops = list( user_op_set) if flags.post_training_quantize: converter.optimizations = [lite.Optimize.DEFAULT] if converter.inference_type != dtypes.float32: print( "--post_training_quantize quantizes a graph of inference_type " "FLOAT. Overriding inference_type to FLOAT.") converter.inference_type = dtypes.float32 if flags.quantize_to_float16: converter.target_spec.supported_types = [dtypes.float16] if not flags.post_training_quantize: print("--quantize_to_float16 will only take effect with the " "--post_training_quantize flag enabled.") if flags.dump_graphviz_dir: converter.dump_graphviz_dir = flags.dump_graphviz_dir if flags.dump_graphviz_video: converter.dump_graphviz_vode = flags.dump_graphviz_video if flags.conversion_summary_dir: converter.conversion_summary_dir = flags.conversion_summary_dir converter.experimental_new_converter = flags.experimental_new_converter if flags.experimental_new_quantizer is not None: converter.experimental_new_quantizer = flags.experimental_new_quantizer # Convert model. output_data = converter.convert() with open(flags.output_file, "wb") as f: f.write(six.ensure_binary(output_data))
def _convert_tf1_model(flags): """Calls function to convert the TensorFlow 1.X model into a TFLite model. Args: flags: argparse.Namespace object. Raises: ValueError: Invalid flags. """ # Create converter. converter = _get_toco_converter(flags) if flags.inference_type: converter.inference_type = _parse_inference_type( flags.inference_type, "inference_type") if flags.inference_input_type: converter.inference_input_type = _parse_inference_type( flags.inference_input_type, "inference_input_type") if flags.output_format: converter.output_format = _toco_flags_pb2.FileFormat.Value( flags.output_format) if flags.mean_values and flags.std_dev_values: input_arrays = converter.get_input_arrays() std_dev_values = _parse_array(flags.std_dev_values, type_fn=float) # In quantized inference, mean_value has to be integer so that the real # value 0.0 is exactly representable. if converter.inference_type == lite_constants.QUANTIZED_UINT8: mean_values = _parse_array(flags.mean_values, type_fn=int) else: mean_values = _parse_array(flags.mean_values, type_fn=float) quant_stats = list(zip(mean_values, std_dev_values)) if ((not flags.input_arrays and len(input_arrays) > 1) or (len(input_arrays) != len(quant_stats))): raise ValueError( "Mismatching --input_arrays, --std_dev_values, and " "--mean_values. The flags must have the same number of " "items. The current input arrays are '{0}'. " "--input_arrays must be present when specifying " "--std_dev_values and --mean_values with multiple input " "tensors in order to map between names and " "values.".format(",".join(input_arrays))) converter.quantized_input_stats = dict(zip(input_arrays, quant_stats)) if (flags.default_ranges_min is not None) and (flags.default_ranges_max is not None): converter.default_ranges_stats = (flags.default_ranges_min, flags.default_ranges_max) if flags.drop_control_dependency: converter.drop_control_dependency = flags.drop_control_dependency if flags.reorder_across_fake_quant: converter.reorder_across_fake_quant = flags.reorder_across_fake_quant if flags.change_concat_input_ranges: converter.change_concat_input_ranges = ( flags.change_concat_input_ranges == "TRUE") if flags.allow_custom_ops: converter.allow_custom_ops = flags.allow_custom_ops if flags.target_ops: ops_set_options = lite.OpsSet.get_options() converter.target_ops = set() for option in flags.target_ops.split(","): if option not in ops_set_options: raise ValueError("Invalid value for --target_ops. Options: " "{0}".format(",".join(ops_set_options))) converter.target_ops.add(lite.OpsSet(option)) if flags.post_training_quantize: converter.post_training_quantize = flags.post_training_quantize if converter.inference_type == lite_constants.QUANTIZED_UINT8: print( "--post_training_quantize quantizes a graph of inference_type " "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.") converter.inference_type = lite_constants.FLOAT if flags.dump_graphviz_dir: converter.dump_graphviz_dir = flags.dump_graphviz_dir if flags.dump_graphviz_video: converter.dump_graphviz_vode = flags.dump_graphviz_video if flags.ev_quant: converter.ev_quant = flags.ev_quant # Convert model. output_data = converter.convert() with open(flags.output_file, "wb") as f: f.write(output_data)