示例#1
0
 def testModelFnExportModeExecute(self):
     """Verifies model can be exported to savedmodel and tflite model.

     Trains the Keras model for a few steps, exports the latest checkpoint to a
     SavedModel under `<model_dir>/export`, and checks the `serving_default`
     signature's output shapes on a length-10 context. It then converts that
     export to `model.tflite` and checks the TFLite interpreter's input
     name/shape and output shapes on the same context.
     """
     self.params['encoder_type'] = FLAGS.encoder_type
     self.params['num_predictions'] = FLAGS.num_predictions
     train_input_fn = launcher.InputFn(FLAGS.training_data_filepattern,
                                       FLAGS.batch_size)
     eval_input_fn = launcher.InputFn(FLAGS.testing_data_filepattern,
                                      FLAGS.batch_size)
     model = launcher.build_keras_model(self.params, FLAGS.learning_rate,
                                        FLAGS.gradient_clip_norm)
     launcher.train_and_eval(model=model,
                             model_dir=FLAGS.model_dir,
                             train_input_fn=train_input_fn,
                             eval_input_fn=eval_input_fn,
                             steps_per_epoch=2,
                             epochs=2,
                             eval_steps=1)

     # --- SavedModel export from the most recent training checkpoint. ---
     export_dir = os.path.join(FLAGS.model_dir, 'export')
     latest_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
     # latest_checkpoint returns None when no checkpoint was written; fail here
     # with a clear message rather than inside launcher.export.
     self.assertIsNotNone(latest_checkpoint)
     launcher.export(checkpoint_path=latest_checkpoint,
                     export_dir=export_dir,
                     params=self.params,
                     max_history_length=FLAGS.max_history_length)
     savedmodel_path = os.path.join(export_dir, 'saved_model.pb')
     self.assertTrue(os.path.exists(savedmodel_path))
     imported = tf.saved_model.load(export_dir, tags=None)
     infer = imported.signatures['serving_default']
     context = tf.range(10)
     predictions = infer(context)
     self.assertAllEqual([10], predictions['top_prediction_ids'].shape)
     self.assertAllEqual([10], predictions['top_prediction_scores'].shape)

     # --- TFLite conversion of the same export. ---
     launcher.export_tflite(export_dir)
     tflite_model_path = os.path.join(export_dir, 'model.tflite')
     self.assertTrue(os.path.exists(tflite_model_path))
     # Read the flatbuffer inside a context manager so the handle is closed.
     with open(tflite_model_path, 'rb') as f:
         tflite_model_content = f.read()
     interpreter = tf.lite.Interpreter(model_content=tflite_model_content)
     interpreter.allocate_tensors()
     input_details = interpreter.get_input_details()
     output_details = interpreter.get_output_details()
     # 'shape' is a numpy array, so compare element-wise with assertAllEqual
     # instead of relying on `list == ndarray` truthiness.
     self.assertAllEqual([10], input_details[0]['shape'])
     self.assertEqual('serving_default_context:0', input_details[0]['name'])
     # Pass an explicit numpy array, the value type set_tensor documents.
     interpreter.set_tensor(input_details[0]['index'], context.numpy())
     interpreter.invoke()
     # NOTE(review): assumes output 0 holds the ids and output 1 the scores;
     # TFLite output ordering is not guaranteed to follow the signature, so
     # consider selecting outputs by name — confirm.
     tflite_top_prediction_ids = interpreter.get_tensor(
         output_details[0]['index'])
     tflite_top_prediction_scores = interpreter.get_tensor(
         output_details[1]['index'])
     self.assertAllEqual([10], tflite_top_prediction_ids.shape)
     self.assertAllEqual([10], tflite_top_prediction_scores.shape)
示例#2
0
 def testModelFnTrainModeExecute(self):
     """Checks that the Keras model survives a short train-and-eval run.

     Builds the model from `self.params` (with the encoder type taken from
     FLAGS) and runs `launcher.train_and_eval` for two epochs of two steps each
     with a single eval step. It then expects both `self.test_model_dir` and its
     `summaries` sub-directory to exist on disk.
     """
     self.params['encoder_type'] = FLAGS.encoder_type
     batch = FLAGS.batch_size
     training_fn = launcher.InputFn(FLAGS.training_data_filepattern, batch)
     testing_fn = launcher.InputFn(FLAGS.testing_data_filepattern, batch)
     keras_model = launcher.build_keras_model(
         self.params, FLAGS.learning_rate, FLAGS.gradient_clip_norm)
     # Keep the run tiny: 2 epochs x 2 steps, plus one evaluation step.
     schedule = dict(steps_per_epoch=2, epochs=2, eval_steps=1)
     launcher.train_and_eval(model=keras_model,
                             model_dir=FLAGS.model_dir,
                             train_input_fn=training_fn,
                             eval_input_fn=testing_fn,
                             **schedule)
     # Both the model dir and its 'summaries' sub-directory must be present.
     summaries_dir = os.path.join(self.test_model_dir, 'summaries')
     for expected_path in (self.test_model_dir, summaries_dir):
         self.assertTrue(os.path.exists(expected_path))