Example #1
# Imports required by both examples. The repo-local helpers referenced below
# (FLAGS, SenpairProcessor, BaseConfig, tokenization, model_fn_builder,
# load_embedding_table, file_based_convert_examples_to_features,
# file_based_input_fn_builder, pretrain_input_fn_builder) are assumed to be
# defined elsewhere in the surrounding project.
import os

import numpy as np
import tensorflow as tf

def main(_):
    tf.logging.set_verbosity(tf.logging.DEBUG)
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_id
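    # SenpairProcessor reads the raw (sentence-pair) examples consumed below.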
    processor = SenpairProcessor()

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True."
        )

    config = BaseConfig.from_json_file(FLAGS.config_file)
    tf.gfile.MakeDirs(FLAGS.output_dir)

    tokenizer = tokenization.Tokenizer(vocab_file=FLAGS.vocab_file,
                                       stop_words_file=FLAGS.stop_words_file,
                                       use_pos=False)

    run_config = None
    num_train_steps = 0
    num_warmup_steps = 0
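    # Training schedule: the total step count is derived from the dataset
    # size, so one checkpoint is saved per epoch (num_train_steps / epochs).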
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.input_file)
        num_train_steps = int(
            len(train_examples) / FLAGS.batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = FLAGS.num_warmup_steps

        run_config = tf.estimator.RunConfig(
            save_summary_steps=100,
            save_checkpoints_steps=int(num_train_steps / FLAGS.num_train_epochs),
            keep_checkpoint_max=5,
        )

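    # Optionally warm-start the embedding layer from a pretrained table aligned
    # with the vocabulary file; it is kept frozen (trainable=False) below.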
    embedding_table = None
    if FLAGS.embedding_table is not None:
        embedding_table = load_embedding_table(FLAGS.embedding_table,
                                               FLAGS.vocab_file)

    model_fn = model_fn_builder(config=config,
                                learning_rate=FLAGS.learning_rate,
                                task=FLAGS.task_type,
                                single_text=FLAGS.single_text,
                                init_checkpoint=FLAGS.init_checkpoint,
                                num_train_steps=num_train_steps,
                                num_warmup_steps=num_warmup_steps,
                                embedding_table_value=embedding_table,
                                embedding_table_trainable=False,
                                model_name=FLAGS.model_name)

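    # The Estimator forwards `params` (here just the batch size) to model_fn
    # and the input functions.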
    params = {"batch_size": FLAGS.batch_size}
    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=FLAGS.output_dir,
                                       config=run_config,
                                       params=params)

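    # Training: serialize the examples to a TFRecord file once (or reuse a
    # cached one), then stream it through the input pipeline.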
    if FLAGS.do_train:
        if FLAGS.cached_tfrecord:
            train_file = FLAGS.cached_tfrecord
        else:
            train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        if not os.path.exists(train_file):
            file_based_convert_examples_to_features(train_examples,
                                                    FLAGS.max_seq_length,
                                                    tokenizer,
                                                    train_file,
                                                    do_token=FLAGS.do_token)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        del train_examples  # release the memory held by train_examples
        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
        )

        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
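    # Evaluation: same TFRecord round-trip as training, optionally against a
    # specific checkpoint under output_dir.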
    elif FLAGS.do_eval:
        dev_examples = processor.get_dev_examples(FLAGS.input_file)
        if FLAGS.cached_tfrecord:
            dev_file = FLAGS.cached_tfrecord
        else:
            dev_file = os.path.join(FLAGS.output_dir, "dev.tf_record")
        if not os.path.exists(dev_file):
            file_based_convert_examples_to_features(dev_examples,
                                                    FLAGS.max_seq_length,
                                                    tokenizer, dev_file)
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d", len(dev_examples))
        tf.logging.info("  Batch size = %d", FLAGS.batch_size)
        del dev_examples
        eval_input_fn = file_based_input_fn_builder(
            input_file=dev_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False)

        if FLAGS.eval_model is not None:
            eval_model_path = os.path.join(FLAGS.output_dir, FLAGS.eval_model)
        else:
            eval_model_path = None

        result = estimator.evaluate(input_fn=eval_input_fn,
                                    checkpoint_path=eval_model_path)
        eval_output_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(eval_output_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
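    # Prediction: writes one line per example with the extracted keywords and
    # the tokenized text (only the "a" side is written, even for text pairs).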
    else:
        predict_examples = processor.get_test_examples(FLAGS.input_file)
        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        file_based_convert_examples_to_features(predict_examples,
                                                FLAGS.max_seq_length,
                                                tokenizer,
                                                predict_file,
                                                set_type="test",
                                                label_type="int",
                                                single_text=FLAGS.single_text)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d", len(predict_examples))
        tf.logging.info("  Batch size = %d", FLAGS.batch_size)
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            single_text=FLAGS.single_text)

        if FLAGS.pred_model is not None:
            pred_model_path = os.path.join(FLAGS.output_dir, FLAGS.pred_model)
        else:
            pred_model_path = None

        result = estimator.predict(input_fn=predict_input_fn,
                                   checkpoint_path=pred_model_path)

        output_predict_file = os.path.join(FLAGS.output_dir,
                                           "test_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as writer:
            tf.logging.info("***** Predict results *****")
            for (i, prediction) in enumerate(result):
                sen_a_embedding = prediction["sen_a_embedding"]
                input_ids_a = prediction["input_ids_a"]
                keyword_probs_a = prediction["keyword_probs_a"]
                if not FLAGS.single_text:
                    sen_b_embedding = prediction["sen_b_embedding"]
                    input_ids_b = prediction["input_ids_b"]
                    keyword_probs_b = prediction["keyword_probs_b"]

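                # argsort of the negated probs ranks token indices by
                # descending keyword probability.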
                sorted_keyword_idx_a = np.argsort(-keyword_probs_a)
                extracted_keywords_a = []
                for idx in sorted_keyword_idx_a:
                    word_id = input_ids_a[idx]
                    word_prob = keyword_probs_a[idx]
                    word = tokenizer.convert_ids_to_tokens([word_id])[0]
                    extracted_keywords_a.append([word, word_prob])

                keyword_output_a = " ".join([
                    "%s:%f" % (kw, prob) for kw, prob in extracted_keywords_a
                ])
                text_output_a = " ".join(
                    tokenizer.convert_ids_to_tokens(input_ids_a))

                writer.write("%s\t%s" % (keyword_output_a, text_output_a))
                writer.write("\n")
Example #2
def main(_):
    tf.logging.set_verbosity(tf.logging.DEBUG)
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_id

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError("At least one of `do_train`, `do_eval` or `do_predict' must be True.")

    config = BaseConfig.from_json_file(FLAGS.config_file)
    tf.gfile.MakeDirs(FLAGS.output_dir)

    tokenizer = tokenization.Tokenizer(vocab_file=FLAGS.vocab_file, stop_words_file=FLAGS.stop_words_file, use_pos=False)

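    # Pretraining variant: the step count comes straight from a flag rather
    # than being derived from the dataset size.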
    run_config = None
    num_train_steps = 0
    num_warmup_steps = 0
    if FLAGS.do_train:
        num_train_steps = FLAGS.num_train_steps
        num_warmup_steps = FLAGS.num_warmup_steps
        run_config = tf.estimator.RunConfig(
            save_summary_steps=100,
            save_checkpoints_steps=int(num_train_steps / FLAGS.num_train_epochs),
            keep_checkpoint_max=5,
        )

    embedding_table = None
    if FLAGS.embedding_table is not None:
        embedding_table = load_embedding_table(FLAGS.embedding_table, FLAGS.vocab_file)

    model_fn = model_fn_builder(config=config,
                                learning_rate=FLAGS.learning_rate,
                                task=FLAGS.task_type,
                                init_checkpoint=FLAGS.init_checkpoint,
                                num_train_steps=num_train_steps,
                                num_warmup_steps=num_warmup_steps,
                                embedding_table_value=embedding_table,
                                embedding_table_trainable=False,
                                model_name=FLAGS.model_name)

    params = {"batch_size": FLAGS.batch_size}
    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=FLAGS.output_dir,
                                       config=run_config,
                                       params=params)

    if FLAGS.do_train:
        train_file = FLAGS.input_file
        tf.logging.info("***** Running training *****")
        train_input_fn = pretrain_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            mask_num=FLAGS.mask_num,
            is_training=True
        )
        estimator.train(input_fn=train_input_fn,
                        max_steps=num_train_steps)
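    # Evaluation over the pretraining input pipeline; results are logged and
    # written to eval_results.txt.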
    elif FLAGS.do_eval:
        dev_file = FLAGS.input_file
        eval_input_fn = pretrain_input_fn_builder(
            input_file=dev_file,
            seq_length=FLAGS.max_seq_length,
            mask_num=FLAGS.mask_num,
            is_training=False
        )

        if FLAGS.eval_model is not None:
            eval_model_path = os.path.join(FLAGS.output_dir, FLAGS.eval_model)
        else:
            eval_model_path = None

        result = estimator.evaluate(
            input_fn=eval_input_fn,
            checkpoint_path=eval_model_path
        )
        eval_output_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(eval_output_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))