def main(_):
  """Randomize the weights of a TFLite model file and write the result.

  Command-line options are parsed here with ``argparse`` (from ``sys.argv``);
  the positional argument passed to ``main`` is ignored.

  Options:
    --input_tflite_file: full path of the TFLite file to read (required).
    --output_tflite_file: full path of the randomized TFLite file to write
      (required).
    --random_seed: integer seed forwarded to
      ``flatbuffer_utils.randomize_weights`` (default 0).
  """
  parser = argparse.ArgumentParser(
      description='Randomize weights in a tflite file.')
  parser.add_argument(
      '--input_tflite_file',
      type=str,
      required=True,
      help='Full path name to the input tflite file.')
  parser.add_argument(
      '--output_tflite_file',
      type=str,
      required=True,
      help='Full path name to the output randomized tflite file.')
  # Parse the seed as an int so it matches the int default. With type=str an
  # explicit `--random_seed=0` would arrive as the string '0' while omitting
  # the flag yields the int 0, and a seeded RNG may treat those differently.
  parser.add_argument(
      '--random_seed',
      type=int,
      required=False,
      default=0,
      help='Input to the random number generator. The default value is 0.')
  args = parser.parse_args()

  # Read the model
  model = flatbuffer_utils.read_model(args.input_tflite_file)
  # Invoke the randomize weights function (mutates `model` in place; the
  # return value is not used here)
  flatbuffer_utils.randomize_weights(model, args.random_seed)
  # Write the model
  flatbuffer_utils.write_model(model, args.output_tflite_file)
    def testRandomizeSomeWeights(self):
        """Randomizing with a skip list leaves the skipped buffer intact.

        Deep-copies the mock model and randomizes the copy while skipping
        buffer ``_SKIPPED_BUFFER_INDEX``. Then checks that the description,
        the main subgraph's name/inputs/outputs, the operator opcode indices,
        the tensor name/type/buffer/shape, and the skipped buffer's data are
        all unchanged.
        """
        # 1. SETUP
        # Define the initial model
        # NOTE(review): testRandomizeWeights builds its model with
        # build_mock_model_python_object(); confirm build_mock_model() also
        # returns a mutable model object that randomize_weights accepts.
        initial_model = test_utils.build_mock_model()
        final_model = copy.deepcopy(initial_model)

        # 2. INVOKE
        # Invoke the randomize_weights function, skipping the buffer at
        # _SKIPPED_BUFFER_INDEX
        flatbuffer_utils.randomize_weights(
            final_model, buffers_to_skip=[_SKIPPED_BUFFER_INDEX])

        # 3. VALIDATE
        # Validate that the initial and final models are the same, except that
        # the weights in the non-skipped buffers may have been randomized.
        # Whole-sequence comparisons are used (instead of index loops over the
        # initial model only) so that extra or missing elements in the final
        # model are also detected.
        # Validate the description
        self.assertEqual(initial_model.description, final_model.description)
        # Validate the main subgraph's name, inputs, outputs, operators and tensors
        initial_subgraph = initial_model.subgraphs[0]
        final_subgraph = final_model.subgraphs[0]
        self.assertEqual(initial_subgraph.name, final_subgraph.name)
        # list() so vectors compare element-wise even if they are numpy arrays
        # (assumption: the object API may store them that way — TODO confirm).
        self.assertEqual(list(initial_subgraph.inputs),
                         list(final_subgraph.inputs))
        self.assertEqual(list(initial_subgraph.outputs),
                         list(final_subgraph.outputs))
        self.assertEqual(
            [op.opcodeIndex for op in initial_subgraph.operators],
            [op.opcodeIndex for op in final_subgraph.operators])
        initial_tensors = initial_subgraph.tensors
        final_tensors = final_subgraph.tensors
        self.assertEqual(len(initial_tensors), len(final_tensors))
        for initial_tensor, final_tensor in zip(initial_tensors, final_tensors):
            self.assertEqual(initial_tensor.name, final_tensor.name)
            self.assertEqual(initial_tensor.type, final_tensor.type)
            self.assertEqual(initial_tensor.buffer, final_tensor.buffer)
            self.assertEqual(list(initial_tensor.shape),
                             list(final_tensor.shape))
        # Validate that the skipped buffer is unchanged, byte for byte.
        initial_buffer = initial_model.buffers[_SKIPPED_BUFFER_INDEX].data
        final_buffer = final_model.buffers[_SKIPPED_BUFFER_INDEX].data
        self.assertEqual(list(initial_buffer), list(final_buffer))
    def testRandomizeWeights(self):
        """Randomizing weights changes buffer data but not model structure.

        Deep-copies the mock model and randomizes the copy. Then checks that
        the description, the main subgraph's name/inputs/outputs, the operator
        opcode indices, and the tensor name/type/buffer/shape are unchanged.
        Finally checks that buffer 1 keeps its size but has different contents.
        """
        # 1. SETUP
        # Define the initial model
        initial_model = test_utils.build_mock_model_python_object()
        final_model = copy.deepcopy(initial_model)

        # 2. INVOKE
        # Invoke the randomize_weights function
        flatbuffer_utils.randomize_weights(final_model)

        # 3. VALIDATE
        # Validate that the initial and final models are the same, except that
        # the weights in the model buffer have been modified (i.e, randomized).
        # Whole-sequence comparisons are used (instead of index loops over the
        # initial model only) so that extra or missing elements in the final
        # model are also detected.
        # Validate the description
        self.assertEqual(initial_model.description, final_model.description)
        # Validate the main subgraph's name, inputs, outputs, operators and tensors
        initial_subgraph = initial_model.subgraphs[0]
        final_subgraph = final_model.subgraphs[0]
        self.assertEqual(initial_subgraph.name, final_subgraph.name)
        # list() so vectors compare element-wise even if they are numpy arrays
        # (assumption: the object API may store them that way — TODO confirm).
        self.assertEqual(list(initial_subgraph.inputs),
                         list(final_subgraph.inputs))
        self.assertEqual(list(initial_subgraph.outputs),
                         list(final_subgraph.outputs))
        self.assertEqual(
            [op.opcodeIndex for op in initial_subgraph.operators],
            [op.opcodeIndex for op in final_subgraph.operators])
        initial_tensors = initial_subgraph.tensors
        final_tensors = final_subgraph.tensors
        self.assertEqual(len(initial_tensors), len(final_tensors))
        for initial_tensor, final_tensor in zip(initial_tensors, final_tensors):
            self.assertEqual(initial_tensor.name, final_tensor.name)
            self.assertEqual(initial_tensor.type, final_tensor.type)
            self.assertEqual(initial_tensor.buffer, final_tensor.buffer)
            self.assertEqual(list(initial_tensor.shape),
                             list(final_tensor.shape))
        # Validate the first valid buffer (index 0 is always None)
        initial_bytes = list(initial_model.buffers[1].data)
        final_bytes = list(final_model.buffers[1].data)
        # Randomization must keep the buffer size.
        self.assertEqual(len(initial_bytes), len(final_bytes))
        # The buffer as a whole must change. Asserting that every element
        # differs is too strict: a randomly drawn value can legitimately
        # coincide with the original one.
        self.assertNotEqual(initial_bytes, final_bytes)
def main(_):
    """Load a TFLite model, scramble its weights, and save the result.

    All inputs come from the module-level ``FLAGS`` object: the source path
    (``input_tflite_file``), the seed (``random_seed``) and the destination
    path (``output_tflite_file``). The positional argument is unused.
    """
    utils = flatbuffer_utils

    source_path = FLAGS.input_tflite_file
    tflite_model = utils.read_model(source_path)

    seed = FLAGS.random_seed
    utils.randomize_weights(tflite_model, seed)

    destination_path = FLAGS.output_tflite_file
    utils.write_model(tflite_model, destination_path)