def main(_):
    """Randomize the weights of a tflite file and write the result.

    Parses --input_tflite_file, --output_tflite_file and --random_seed from
    the command line, reads the input model, hands it together with the seed
    to flatbuffer_utils.randomize_weights, and writes the model returned by
    that call to the output path.

    Args:
        _: Unused positional argument.
    """
    parser = argparse.ArgumentParser(
        description='Randomize weights in a tflite file.')
    parser.add_argument('--input_tflite_file',
                        type=str,
                        required=True,
                        help='Full path name to the input tflite file.')
    parser.add_argument(
        '--output_tflite_file',
        type=str,
        required=True,
        help='Full path name to the output randomized tflite file.')
    parser.add_argument(
        '--random_seed',
        # Parse the seed as an integer so that a value given on the command
        # line has the same type as the integer default (previously a
        # user-supplied seed arrived as a string while the default was 0).
        type=int,
        required=False,
        default=0,
        help='Input to the random number generator. The default value is 0.')
    args = parser.parse_args()

    # Read the model
    input_model = flatbuffer_utils.read_model(args.input_tflite_file)
    # Invoke the randomize weights function
    # NOTE(review): this relies on randomize_weights returning the randomized
    # model rather than only mutating its argument - confirm against the
    # flatbuffer_utils version in use.
    output_model = flatbuffer_utils.randomize_weights(input_model,
                                                      args.random_seed)
    # Write the model
    flatbuffer_utils.write_model(output_model, args.output_tflite_file)
    def testXxdOutputToBytes(self):
        """Check that an xxd dump of a written model converts back to its bytes.

        Writes a mock model to a temporary file, dumps that file with
        `xxd -i` into a C source file, converts the dump back with
        flatbuffer_utils.xxd_output_to_bytes and compares the result with
        the bytearray obtained directly from the model object.
        """
        # 1. SETUP
        # Define the initial model
        initial_model = test_utils.build_mock_model()
        initial_bytes = flatbuffer_utils.convert_object_to_bytearray(
            initial_model)

        # Define temporary files
        tmp_dir = self.get_temp_dir()
        model_filename = os.path.join(tmp_dir, 'model.tflite')

        # 2. Write model to temporary file (will be used as input for xxd)
        flatbuffer_utils.write_model(initial_model, model_filename)

        # 3. DUMP WITH xxd
        input_cc_file = os.path.join(tmp_dir, 'model.cc')

        # Run xxd with an argument list and an explicit stdout redirect
        # instead of a shell command string: paths containing spaces or shell
        # metacharacters are passed through verbatim, and check_call raises
        # if xxd is missing or exits with a non-zero status.
        with open(input_cc_file, 'wb') as cc_file:
            subprocess.check_call(['xxd', '-i', model_filename],
                                  stdout=cc_file)

        # 4. VALIDATE
        final_bytes = flatbuffer_utils.xxd_output_to_bytes(input_cc_file)

        # Validate that the initial and final bytearray are the same
        self.assertEqual(initial_bytes, final_bytes)
Exemple #3
0
def tflite_graph_rewrite(tflite_model_path,
                         saved_model_dir,
                         custom_op_registerers=None):
  """Rename TFLite graph inputs/outputs after the SavedModel signature keys.

  The "serving_default" signature of the SavedModel's "serve" meta graph
  provides a map from tensor name to signature key. Every input and output
  tensor of the first TFLite subgraph is renamed to the signature key that
  matches the name the interpreter reports for it, and the model is written
  back to the same path.

  Arguments:
    tflite_model_path: The path to the exported TFLite graph, which is
      overwritten with the rewritten graph.
    saved_model_dir: Directory of the SavedModel whose signature meta data is
      looked up.
    custom_op_registerers: Value forwarded to the interpreter as its
      `custom_op_registerers` argument.
  """
  # Signature key for every tensor name listed in the serving signature.
  serving_signature = saved_model_utils.get_meta_graph_def(
      saved_model_dir, "serve").signature_def["serving_default"]
  signature_key_by_tensor = {}
  for endpoints in (serving_signature.inputs, serving_signature.outputs):
    for signature_key, tensor_info in endpoints.items():
      signature_key_by_tensor[tensor_info.name] = signature_key

  # Tensor names as the interpreter reports them, in input/output order.
  with tf.io.gfile.GFile(tflite_model_path, "rb") as model_file:
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_content=model_file.read(),
        custom_op_registerers=custom_op_registerers)
  input_names = [
      detail["name"] for detail in interpreter.get_input_details()
  ]
  output_names = [
      detail["name"] for detail in interpreter.get_output_details()
  ]

  # Rename the tensors of the first subgraph and store the model again.
  model = flatbuffer_utils.read_model_with_mutable_tensors(tflite_model_path)
  main_subgraph = model.subgraphs[0]
  for position, tensor_name in enumerate(input_names):
    new_name = signature_key_by_tensor[tensor_name]
    main_subgraph.tensors[main_subgraph.inputs[position]].name = new_name
  for position, tensor_name in enumerate(output_names):
    new_name = signature_key_by_tensor[tensor_name]
    main_subgraph.tensors[main_subgraph.outputs[position]].name = new_name
  flatbuffer_utils.write_model(model, tflite_model_path)
Exemple #4
0
    def testWriteReadModel(self):
        """Round-trip a mock model through write_model and read_model.

        The model read back is compared with the written one on its
        description, on the first subgraph's name, inputs, outputs, operator
        opcode indices and tensors, and on the data of buffer 1.
        """
        # Build the reference model and pick a temporary target path.
        expected = test_utils.build_mock_model_python_object()
        model_path = os.path.join(self.get_temp_dir(), 'model.tflite')

        # Write the model out and read it straight back.
        flatbuffer_utils.write_model(expected, model_path)
        actual = flatbuffer_utils.read_model(model_path)

        # Model-level description.
        self.assertEqual(expected.description, actual.description)

        # First subgraph: name, inputs, outputs and operators.
        expected_graph = expected.subgraphs[0]
        actual_graph = actual.subgraphs[0]
        self.assertEqual(expected_graph.name, actual_graph.name)
        for idx, graph_input in enumerate(expected_graph.inputs):
            self.assertEqual(graph_input, actual_graph.inputs[idx])
        for idx, graph_output in enumerate(expected_graph.outputs):
            self.assertEqual(graph_output, actual_graph.outputs[idx])
        for idx, operator in enumerate(expected_graph.operators):
            self.assertEqual(operator.opcodeIndex,
                             actual_graph.operators[idx].opcodeIndex)

        # Tensors: name, type, buffer index and every shape entry.
        actual_tensors = actual_graph.tensors
        for idx, tensor in enumerate(expected_graph.tensors):
            counterpart = actual_tensors[idx]
            self.assertEqual(tensor.name, counterpart.name)
            self.assertEqual(tensor.type, counterpart.type)
            self.assertEqual(tensor.buffer, counterpart.buffer)
            for axis, extent in enumerate(tensor.shape):
                self.assertEqual(extent, counterpart.shape[axis])

        # Data of buffer 1 (the first valid buffer; index 0 is always None).
        expected_data = expected.buffers[1].data
        actual_data = actual.buffers[1].data
        for idx in range(expected_data.size):
            self.assertEqual(expected_data.data[idx], actual_data.data[idx])
def main(_):
    """Convert an xxd-dumped C++ source file into a tflite file.

    Reads --input_cc_file and --output_tflite_file from the command line,
    builds a model from the cc file and writes it to the output path.
    """
    arg_parser = argparse.ArgumentParser(
        description='Reverses xxd dump from to binary file')
    # Both flags are mandatory string paths; only name and help text differ.
    path_flags = (
        ('--input_cc_file', 'Full path name to the input cc file.'),
        ('--output_tflite_file',
         'Full path name to the stripped output tflite file.'),
    )
    for flag_name, flag_help in path_flags:
        arg_parser.add_argument(flag_name,
                                type=str,
                                required=True,
                                help=flag_help)
    options = arg_parser.parse_args()

    # Build the model from the xxd output, then store it.
    restored = flatbuffer_utils.xxd_output_to_object(options.input_cc_file)
    flatbuffer_utils.write_model(restored, options.output_tflite_file)
Exemple #6
0
def main(_):
    """Strip strings from a tflite file and write the result.

    Reads --input_tflite_file and --output_tflite_file from the command line,
    loads the input model, calls flatbuffer_utils.strip_strings on it and
    writes the same model object to the output path.
    """
    arg_parser = argparse.ArgumentParser(
        description='Strips all nonessential strings from a tflite file.')
    # Both flags are mandatory string paths; only name and help text differ.
    path_flags = (
        ('--input_tflite_file', 'Full path name to the input tflite file.'),
        ('--output_tflite_file',
         'Full path name to the stripped output tflite file.'),
    )
    for flag_name, flag_help in path_flags:
        arg_parser.add_argument(flag_name,
                                type=str,
                                required=True,
                                help=flag_help)
    options = arg_parser.parse_args()

    # Load, strip, then store the model.
    tflite_model = flatbuffer_utils.read_model(options.input_tflite_file)
    flatbuffer_utils.strip_strings(tflite_model)
    flatbuffer_utils.write_model(tflite_model, options.output_tflite_file)
Exemple #7
0
def main(_):
    """Strip strings from the FLAGS input tflite file into the output file."""
    # Load the model named by the input flag.
    tflite_model = flatbuffer_utils.read_model(FLAGS.input_tflite_file)
    # strip_strings is applied to the loaded object; its return value is
    # not used here.
    flatbuffer_utils.strip_strings(tflite_model)
    # Store the same model object at the output path.
    flatbuffer_utils.write_model(tflite_model, FLAGS.output_tflite_file)
def main(_):
    """Randomize weights of the FLAGS input tflite file into the output file."""
    # Load the model named by the input flag.
    tflite_model = flatbuffer_utils.read_model(FLAGS.input_tflite_file)
    # randomize_weights receives the loaded object and the seed flag; its
    # return value is not used here.
    flatbuffer_utils.randomize_weights(tflite_model, FLAGS.random_seed)
    # Store the same model object at the output path.
    flatbuffer_utils.write_model(tflite_model, FLAGS.output_tflite_file)
def main(_):
    """Convert the FLAGS input cc file (xxd output) into a tflite file."""
    # Build a model object from the xxd output file.
    restored = flatbuffer_utils.xxd_output_to_object(FLAGS.input_cc_file)
    # Store it at the output path.
    flatbuffer_utils.write_model(restored, FLAGS.output_tflite_file)