示例#1
0
def main(argv):
    """Train and evaluate the TF GRU-attention multiclass model, then export it.

    Steps:
      1. Build a text preprocessor from the embeddings at --embeddings_path.
      2. Tokenize training text with NLTK's word tokenizer.
      3. Train with evaluation on the TFRecord input.
      4. Export through a serving input fn keyed by base_model.EXAMPLE_KEY.

    Args:
      argv: Command-line arguments; ignored.
    """
    del argv  # Not used.

    emb_path = FLAGS.embeddings_path
    text_prep = text_preprocessor.TextPreprocessor(emb_path)

    # NLTK's word_tokenize relies on the Punkt models, so fetch them first.
    nltk.download("punkt")
    tokenize_fn = text_prep.train_preprocess_fn(nltk.word_tokenize)
    input_data = tfrecord_input.TFRecordInputWithTokenizer(
        train_preprocess_fn=tokenize_fn)

    # TODO: Move embedding *into* Keras model.
    rnn_model = tf_gru_attention_multiclass.TFRNNModel(input_data.labels())
    embedded_model = text_prep.add_embedding_to_model(
        rnn_model, base_model.TOKENS_FEATURE_KEY)

    model_runner = model_trainer.ModelTrainer(input_data, embedded_model)
    model_runner.train_with_eval()

    # The serving input is built from the same preprocessor's vocabulary
    # (_word_to_idx / _unknown_token) that was used for training.
    export_input_fn = serving_input.create_serving_input_fn(
        word_to_idx=text_prep._word_to_idx,
        unknown_token=text_prep._unknown_token,
        text_feature_name=base_model.TOKENS_FEATURE_KEY,
        example_key_name=base_model.EXAMPLE_KEY)
    model_runner.export(export_input_fn, base_model.EXAMPLE_KEY)
示例#2
0
def main(argv):
  """Train and evaluate the Keras GRU-attention model, then export it.

  Flags read: embeddings_path, text_feature_name, key_name, train_path,
  validate_path, batch_size, train_steps, eval_period, eval_steps.

  Args:
    argv: Command-line arguments; ignored.
  """
  del argv  # Not used.

  # Flags reused across several steps below.
  emb_path = FLAGS.embeddings_path
  feature_name = FLAGS.text_feature_name
  example_key = FLAGS.key_name

  text_prep = text_preprocessor.TextPreprocessor(emb_path)

  # NLTK's word_tokenize relies on the Punkt models, so fetch them first.
  nltk.download("punkt")
  tokenize_fn = text_prep.train_preprocess_fn(nltk.word_tokenize)
  input_data = tfrecord_input.TFRecordInput(
      train_path=FLAGS.train_path,
      validate_path=FLAGS.validate_path,
      text_feature=feature_name,
      labels=LABELS,
      train_preprocess_fn=tokenize_fn,
      batch_size=FLAGS.batch_size)

  # TODO: Move embedding *into* Keras model.
  rnn_model = keras_gru_attention.KerasRNNModel(
      set(LABELS.keys()), text_prep._embedding_size)
  embedded_model = text_prep.add_embedding_to_model(rnn_model, feature_name)

  model_runner = model_trainer.ModelTrainer(input_data, embedded_model)
  model_runner.train_with_eval(
      FLAGS.train_steps, FLAGS.eval_period, FLAGS.eval_steps)

  # The serving input is built from the same preprocessor's vocabulary
  # (_word_to_idx / _unknown_token) that was used for training.
  export_input_fn = serving_input.create_serving_input_fn(
      word_to_idx=text_prep._word_to_idx,
      unknown_token=text_prep._unknown_token,
      text_feature_name=feature_name,
      key_name=example_key)
  model_runner.export(export_input_fn)