def main(_):
    """Launcher entry point: trains/evaluates and/or exports the model.

    Behavior is selected by FLAGS.run_mode:
      * 'train_and_eval': build input pipelines, train and evaluate the Keras
        model, then export the latest checkpoint (if one was written).
      * 'export': export from FLAGS.checkpoint_path, falling back to the
        latest checkpoint under FLAGS.model_dir.
    Any other value is logged as an error and nothing is run.

    Args:
      _: unused; absl.app passes the leftover argv here.
    """
    logger = tf.get_logger()
    # Make sure both output directories exist before anything tries to write.
    if not tf.io.gfile.exists(FLAGS.model_dir):
        tf.io.gfile.mkdir(FLAGS.model_dir)

    if not tf.io.gfile.exists(FLAGS.export_dir):
        tf.io.gfile.mkdir(FLAGS.export_dir)

    input_config = load_input_config()
    model_config = prepare_model_config()

    logger.info('Setting up train and eval input datasets.')
    train_input_dataset = input_pipeline.get_input_dataset(
        data_filepattern=FLAGS.training_data_filepattern,
        input_config=input_config,
        vocab_file_dir=FLAGS.vocab_dir,
        batch_size=FLAGS.batch_size)
    eval_input_dataset = input_pipeline.get_input_dataset(
        data_filepattern=FLAGS.testing_data_filepattern,
        input_config=input_config,
        vocab_file_dir=FLAGS.vocab_dir,
        batch_size=FLAGS.batch_size)

    logger.info('Build keras model for mode: {}.'.format(FLAGS.run_mode))
    model = build_keras_model(input_config=input_config,
                              model_config=model_config)

    if FLAGS.run_mode == 'train_and_eval':
        train_and_eval(model=model,
                       model_dir=FLAGS.model_dir,
                       train_input_dataset=train_input_dataset,
                       eval_input_dataset=eval_input_dataset,
                       steps_per_epoch=FLAGS.steps_per_epoch,
                       epochs=FLAGS.num_epochs,
                       eval_steps=FLAGS.num_eval_steps)
        # Export is best-effort: if training produced no checkpoint, skip it
        # silently rather than fail.
        latest_checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
        if latest_checkpoint_path:
            export(checkpoint_path=latest_checkpoint_path,
                   input_config=input_config,
                   model_config=model_config,
                   export_dir=FLAGS.export_dir)
    elif FLAGS.run_mode == 'export':
        # An explicit checkpoint flag wins over the latest one in model_dir.
        checkpoint_path = (FLAGS.checkpoint_path if FLAGS.checkpoint_path else
                           tf.train.latest_checkpoint(FLAGS.model_dir))
        export(checkpoint_path=checkpoint_path,
               input_config=input_config,
               model_config=model_config,
               export_dir=FLAGS.export_dir)
    else:
        # Fixed message typo: "run model" -> "run mode" (refers to FLAGS.run_mode).
        logger.error('Unsupported launcher run mode {}.'.format(
            FLAGS.run_mode))
 def test_get_input_dataset(self):
     """Verifies the input pipeline yields batches with the expected feature keys."""
     batched = input_pipeline.get_input_dataset(
         data_filepattern=self.test_input_data_file,
         input_config=TEST_INPUT_CONFIG,
         vocab_file_dir=self.tmp_dir,
         batch_size=1).take(1)
     expected_feature_keys = [
         'context_movie_id', 'context_movie_rating', 'context_movie_genre',
         'label_movie_id'
     ]
     self.assertCountEqual(expected_feature_keys,
                           batched.element_spec[0].keys())
 def testModelTrainEvalExport(self):
     """Verifies that model can be trained and evaluated.

     End-to-end check: train/eval the Keras model, export a SavedModel, run
     the serving signature, convert to TFLite, and run the TFLite
     interpreter on the same inputs.
     """
     tf.io.gfile.mkdir(FLAGS.model_dir)
     input_config = launcher.load_input_config()
     model_config = launcher.prepare_model_config()
     dataset = input_pipeline.get_input_dataset(
         data_filepattern=self.test_input_data_file,
         input_config=input_config,
         vocab_file_dir=self.tmp_dir,
         batch_size=8)
     model = launcher.build_keras_model(input_config, model_config)
     launcher.train_and_eval(model=model,
                             model_dir=FLAGS.model_dir,
                             train_input_dataset=dataset,
                             eval_input_dataset=dataset,
                             steps_per_epoch=2,
                             epochs=2,
                             eval_steps=1)
     self.assertTrue(os.path.exists(self.test_model_dir))
     summaries_dir = os.path.join(self.test_model_dir, 'summaries')
     self.assertTrue(os.path.exists(summaries_dir))
     # Export the latest checkpoint as a SavedModel and exercise its
     # serving signature directly.
     export_dir = os.path.join(FLAGS.model_dir, 'export')
     latest_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
     launcher.save_model(checkpoint_path=latest_checkpoint,
                         export_dir=export_dir,
                         input_config=input_config,
                         model_config=model_config)
     savedmodel_path = os.path.join(export_dir, 'saved_model.pb')
     self.assertTrue(os.path.exists(savedmodel_path))
     imported = tf.saved_model.load(export_dir, tags=None)
     infer = imported.signatures['serving_default']
     context_movie_id = tf.range(5, dtype=tf.int32)
     context_movie_rating = tf.range(5, dtype=tf.float32)
     context_movie_genre = tf.range(8, dtype=tf.int32)
     predictions = infer(context_movie_id=context_movie_id,
                         context_movie_rating=context_movie_rating,
                         context_movie_genre=context_movie_genre)
     self.assertAllEqual([5], predictions['top_prediction_ids'].shape)
     self.assertAllEqual([5], predictions['top_prediction_scores'].shape)
     # Convert to TFLite and verify the signature metadata.
     launcher.export_tflite(export_dir)
     tflite_model_path = os.path.join(export_dir, 'model.tflite')
     self.assertTrue(os.path.exists(tflite_model_path))
     # Fixed resource leak: the original opened the file without closing it.
     with open(tflite_model_path, 'rb') as f:
         interpreter = tf.lite.Interpreter(model_content=f.read())
     interpreter.allocate_tensors()
     inference_signature = interpreter.get_signature_list(
     )['serving_default']
     self.assertAllEqual([
         'context_movie_genre', 'context_movie_id', 'context_movie_rating'
     ], inference_signature['inputs'])
     self.assertAllEqual(['top_prediction_ids', 'top_prediction_scores'],
                         inference_signature['outputs'])
     # Map the serving tensor names to the concrete input tensors so we can
     # feed the interpreter by tensor index.
     serving_name_to_tensors = {
         'serving_default_context_movie_id:0': context_movie_id,
         'serving_default_context_movie_rating:0': context_movie_rating,
         'serving_default_context_movie_genre:0': context_movie_genre
     }
     input_details = interpreter.get_input_details()
     output_details = interpreter.get_output_details()
     index_to_tensor = {}
     for input_detail in input_details:
         index_to_tensor[input_detail['index']] = serving_name_to_tensors[
             input_detail['name']]
     for index, tensor in index_to_tensor.items():
         interpreter.set_tensor(index, tensor)
     interpreter.invoke()
     tflite_top_predictions_ids = interpreter.get_tensor(
         output_details[0]['index'])
     tflite_top_prediction_scores = interpreter.get_tensor(
         output_details[1]['index'])
     self.assertAllEqual([5], tflite_top_predictions_ids.shape)
     self.assertAllEqual([5], tflite_top_prediction_scores.shape)