コード例 #1
0
def main(_):
    """Randomize the weights of a tflite model and write the result.

    Command-line flags are parsed from ``sys.argv`` by argparse; the
    positional argument passed to ``main`` is ignored.

      --input_tflite_file: path of the .tflite model to read (required).
      --output_tflite_file: path the randomized model is written to (required).
      --random_seed: integer seed forwarded to
        ``flatbuffer_utils.randomize_weights`` (default 0).
    """
    parser = argparse.ArgumentParser(
        description='Randomize weights in a tflite file.')
    parser.add_argument('--input_tflite_file',
                        type=str,
                        required=True,
                        help='Full path name to the input tflite file.')
    parser.add_argument(
        '--output_tflite_file',
        type=str,
        required=True,
        help='Full path name to the output randomized tflite file.')
    # Parse the seed as an int so an explicit "--random_seed 0" yields the same
    # value (and type) as the int default 0; with type=str it would arrive as
    # the string '0'.
    parser.add_argument(
        '--random_seed',
        type=int,
        required=False,
        default=0,
        help='Input to the random number generator. The default value is 0.')
    args = parser.parse_args()

    # Read the model
    input_model = flatbuffer_utils.read_model(args.input_tflite_file)
    # Invoke the randomize weights function
    randomized_model = flatbuffer_utils.randomize_weights(input_model,
                                                          args.random_seed)
    # NOTE(review): randomize_weights may work purely in place and return
    # None; in that case write the (mutated) input model rather than passing
    # None to write_model.
    output_model = input_model if randomized_model is None else randomized_model
    # Write the model
    flatbuffer_utils.write_model(output_model, args.output_tflite_file)
コード例 #2
0
    def testWriteReadModel(self):
        """Round-trip a mock model through write_model and read_model.

        The mock model is written to a temporary .tflite file and read back.
        The test then checks that these survive unchanged:
          * the description;
          * the first subgraph's name, inputs, outputs and operator opcode
            indices;
          * its tensors (name, type, buffer index, shape);
          * the contents of buffer 1.
        """
        # 1. SETUP
        # Define the initial model
        initial_model = test_utils.build_mock_model_python_object()
        # Define temporary files
        tmp_dir = self.get_temp_dir()
        model_filename = os.path.join(tmp_dir, 'model.tflite')

        # 2. INVOKE
        # Invoke the write_model and read_model functions
        flatbuffer_utils.write_model(initial_model, model_filename)
        final_model = flatbuffer_utils.read_model(model_filename)

        # 3. VALIDATE
        # Validate that the initial and final models are the same
        # Validate the description
        self.assertEqual(initial_model.description, final_model.description)
        # Validate the main subgraph's name, inputs, outputs, operators and tensors
        initial_subgraph = initial_model.subgraphs[0]
        final_subgraph = final_model.subgraphs[0]
        self.assertEqual(initial_subgraph.name, final_subgraph.name)
        # Compare whole sequences rather than indexing over the initial length
        # only, so extra or missing entries in the round-tripped model fail the
        # assertion instead of being ignored or raising IndexError.
        self.assertEqual(list(initial_subgraph.inputs),
                         list(final_subgraph.inputs))
        self.assertEqual(list(initial_subgraph.outputs),
                         list(final_subgraph.outputs))
        self.assertEqual(
            [op.opcodeIndex for op in initial_subgraph.operators],
            [op.opcodeIndex for op in final_subgraph.operators])
        initial_tensors = initial_subgraph.tensors
        final_tensors = final_subgraph.tensors
        self.assertEqual(len(initial_tensors), len(final_tensors))
        for initial_tensor, final_tensor in zip(initial_tensors, final_tensors):
            self.assertEqual(initial_tensor.name, final_tensor.name)
            self.assertEqual(initial_tensor.type, final_tensor.type)
            self.assertEqual(initial_tensor.buffer, final_tensor.buffer)
            self.assertEqual(list(initial_tensor.shape),
                             list(final_tensor.shape))
        # Validate the first valid buffer (index 0 is always None).
        # Compare the full contents, including length, rather than indexing
        # only up to the initial buffer's size.
        initial_buffer = initial_model.buffers[1].data
        final_buffer = final_model.buffers[1].data
        self.assertEqual(list(initial_buffer), list(final_buffer))
コード例 #3
0
ファイル: strip_strings.py プロジェクト: zyl1984/tensorflow
def main(_):
    """Strip nonessential strings from a tflite model and save the result.

    Both flags are required and are parsed from ``sys.argv`` by argparse;
    the positional argument passed to ``main`` is ignored.

      --input_tflite_file: path of the .tflite model to read.
      --output_tflite_file: path the stripped model is written to.
    """
    arg_parser = argparse.ArgumentParser(
        description='Strips all nonessential strings from a tflite file.')
    # Both flags share the same shape: a required string path.
    path_flags = (
        ('--input_tflite_file',
         'Full path name to the input tflite file.'),
        ('--output_tflite_file',
         'Full path name to the stripped output tflite file.'),
    )
    for flag_name, flag_help in path_flags:
        arg_parser.add_argument(flag_name,
                                type=str,
                                required=True,
                                help=flag_help)
    parsed = arg_parser.parse_args()

    tflite_model = flatbuffer_utils.read_model(parsed.input_tflite_file)
    # strip_strings is called for its effect on the loaded model; its return
    # value is not used, and the same model object is written back out.
    flatbuffer_utils.strip_strings(tflite_model)
    flatbuffer_utils.write_model(tflite_model, parsed.output_tflite_file)
コード例 #4
0
def main(_):
    """Strip strings from FLAGS.input_tflite_file into FLAGS.output_tflite_file.

    The positional argument is ignored. ``strip_strings`` is called for its
    effect on the loaded model; the same model object is then written out.
    """
    tflite_model = flatbuffer_utils.read_model(FLAGS.input_tflite_file)
    flatbuffer_utils.strip_strings(tflite_model)
    destination = FLAGS.output_tflite_file
    flatbuffer_utils.write_model(tflite_model, destination)
コード例 #5
0
def main(_):
    """Randomize the weights of FLAGS.input_tflite_file and save the result.

    The seed comes from FLAGS.random_seed and the result is written to
    FLAGS.output_tflite_file. The positional argument is ignored.
    ``randomize_weights`` is called for its effect on the loaded model; its
    return value is not used.
    """
    tflite_model = flatbuffer_utils.read_model(FLAGS.input_tflite_file)
    seed = FLAGS.random_seed
    flatbuffer_utils.randomize_weights(tflite_model, seed)
    destination = FLAGS.output_tflite_file
    flatbuffer_utils.write_model(tflite_model, destination)