def testModelFnExportModeExecute(self):
  """Verifies model can be exported to savedmodel and tflite model.

  Runs a short train/eval cycle, exports the latest checkpoint through
  `launcher.export`, loads the result and calls its 'serving_default'
  signature, then converts it with `launcher.export_tflite` and feeds the
  same input through the TFLite interpreter.
  """
  self.params['encoder_type'] = FLAGS.encoder_type
  self.params['num_predictions'] = FLAGS.num_predictions
  train_input_fn = launcher.InputFn(FLAGS.training_data_filepattern,
                                    FLAGS.batch_size)
  eval_input_fn = launcher.InputFn(FLAGS.testing_data_filepattern,
                                   FLAGS.batch_size)
  model = launcher.build_keras_model(self.params, FLAGS.learning_rate,
                                     FLAGS.gradient_clip_norm)
  launcher.train_and_eval(
      model=model,
      model_dir=FLAGS.model_dir,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      steps_per_epoch=2,
      epochs=2,
      eval_steps=1)

  # Export the most recent checkpoint found in the model directory.
  export_dir = os.path.join(FLAGS.model_dir, 'export')
  latest_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
  launcher.export(
      checkpoint_path=latest_checkpoint,
      export_dir=export_dir,
      params=self.params,
      max_history_length=FLAGS.max_history_length)
  savedmodel_path = os.path.join(export_dir, 'saved_model.pb')
  self.assertTrue(os.path.exists(savedmodel_path))

  # Reload the export and run its serving signature.
  imported = tf.saved_model.load(export_dir, tags=None)
  infer = imported.signatures['serving_default']
  context = tf.range(10)
  predictions = infer(context)
  # NOTE(review): the expected [10] output shape presumably follows from
  # FLAGS.num_predictions -- confirm against the flag's value in this test.
  self.assertAllEqual([10], predictions['top_prediction_ids'].shape)
  self.assertAllEqual([10], predictions['top_prediction_scores'].shape)

  # Convert to TFLite and check the model file was written.
  launcher.export_tflite(export_dir)
  tflite_model_path = os.path.join(export_dir, 'model.tflite')
  self.assertTrue(os.path.exists(tflite_model_path))

  # Read the model inside a context manager so the file handle is released
  # even if a later assertion fails.
  with open(tflite_model_path, 'rb') as f:
    tflite_model_content = f.read()
  interpreter = tf.lite.Interpreter(model_content=tflite_model_content)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  # The reported 'shape' is array-like rather than a plain list, so compare
  # element-wise instead of relying on the truthiness of `==`.
  self.assertAllEqual([10], input_details[0]['shape'])
  self.assertEqual('serving_default_context:0', input_details[0]['name'])

  interpreter.set_tensor(input_details[0]['index'], context)
  interpreter.invoke()
  tflite_top_predictions_ids = interpreter.get_tensor(
      output_details[0]['index'])
  tflite_top_prediction_scores = interpreter.get_tensor(
      output_details[1]['index'])
  self.assertAllEqual([10], tflite_top_predictions_ids.shape)
  self.assertAllEqual([10], tflite_top_prediction_scores.shape)
def testModelFnTrainModeExecute(self):
  """Checks that a short train/eval run leaves the expected directories.

  Builds the Keras model, runs `launcher.train_and_eval` with a minimal
  number of steps, and then asserts that `self.test_model_dir` and its
  'summaries' subdirectory both exist.
  """
  self.params['encoder_type'] = FLAGS.encoder_type

  training_inputs = launcher.InputFn(FLAGS.training_data_filepattern,
                                     FLAGS.batch_size)
  evaluation_inputs = launcher.InputFn(FLAGS.testing_data_filepattern,
                                       FLAGS.batch_size)
  keras_model = launcher.build_keras_model(
      self.params, FLAGS.learning_rate, FLAGS.gradient_clip_norm)

  # NOTE(review): training writes to FLAGS.model_dir while the assertions
  # below inspect self.test_model_dir; presumably both name the same path
  # (set up outside this method) -- confirm.
  run_kwargs = dict(
      model=keras_model,
      model_dir=FLAGS.model_dir,
      train_input_fn=training_inputs,
      eval_input_fn=evaluation_inputs,
      steps_per_epoch=2,
      epochs=2,
      eval_steps=1,
  )
  launcher.train_and_eval(**run_kwargs)

  model_dir = self.test_model_dir
  self.assertTrue(os.path.exists(model_dir))
  self.assertTrue(os.path.exists(os.path.join(model_dir, 'summaries')))