def main(argv):
  """Trains the TF GRU-attention multiclass model and exports it for serving."""
  del argv  # unused

  # Build the text preprocessor from the pretrained embeddings file.
  preprocessor = text_preprocessor.TextPreprocessor(FLAGS.embeddings_path)

  # Fetch the NLTK tokenizer data, then wire the tokenizer into the
  # training-time preprocessing function.
  nltk.download("punkt")
  preprocess_fn = preprocessor.train_preprocess_fn(nltk.word_tokenize)
  dataset = tfrecord_input.TFRecordInputWithTokenizer(
      train_preprocess_fn=preprocess_fn)

  # TODO: Move embedding *into* Keras model.
  rnn_model = tf_gru_attention_multiclass.TFRNNModel(dataset.labels())
  model = preprocessor.add_embedding_to_model(
      rnn_model, base_model.TOKENS_FEATURE_KEY)

  trainer = model_trainer.ModelTrainer(dataset, model)
  trainer.train_with_eval()

  # NOTE(review): reaches into preprocessor privates (_word_to_idx,
  # _unknown_token) — mirrors the existing project convention.
  serving_input_fn = serving_input.create_serving_input_fn(
      word_to_idx=preprocessor._word_to_idx,
      unknown_token=preprocessor._unknown_token,
      text_feature_name=base_model.TOKENS_FEATURE_KEY,
      example_key_name=base_model.EXAMPLE_KEY)
  trainer.export(serving_input_fn, base_model.EXAMPLE_KEY)
def main(argv):
  """Trains the Keras GRU-attention model and exports it for serving."""
  del argv  # unused

  text_feature = FLAGS.text_feature_name
  preprocessor = text_preprocessor.TextPreprocessor(FLAGS.embeddings_path)

  # Fetch the NLTK tokenizer data, then wire the tokenizer into the
  # training-time preprocessing function.
  nltk.download("punkt")
  preprocess_fn = preprocessor.train_preprocess_fn(nltk.word_tokenize)
  dataset = tfrecord_input.TFRecordInput(
      train_path=FLAGS.train_path,
      validate_path=FLAGS.validate_path,
      text_feature=text_feature,
      labels=LABELS,
      train_preprocess_fn=preprocess_fn,
      batch_size=FLAGS.batch_size)

  # TODO: Move embedding *into* Keras model.
  rnn_model = keras_gru_attention.KerasRNNModel(
      set(LABELS.keys()), preprocessor._embedding_size)
  model = preprocessor.add_embedding_to_model(rnn_model, text_feature)

  trainer = model_trainer.ModelTrainer(dataset, model)
  trainer.train_with_eval(
      FLAGS.train_steps, FLAGS.eval_period, FLAGS.eval_steps)

  # NOTE(review): reaches into preprocessor privates (_word_to_idx,
  # _unknown_token) — mirrors the existing project convention.
  serving_input_fn = serving_input.create_serving_input_fn(
      word_to_idx=preprocessor._word_to_idx,
      unknown_token=preprocessor._unknown_token,
      text_feature_name=text_feature,
      key_name=FLAGS.key_name)
  trainer.export(serving_input_fn)