# (total_loss, logits, trans, pred_ids) = ...  # tuple returned by a sequence-labeling (CRF) variant of create_model
    (total_loss, per_example_loss, logits,
     probabilities) = create_model(bert_config=bert_config,
                                   is_training=False,
                                   input_ids=input_ids_p,
                                   input_mask=input_mask_p,
                                   segment_ids=None,
                                   labels=None,
                                   num_labels=num_labels,
                                   use_one_hot_embeddings=False)
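    # NB: input_ids_p and input_mask_p are assumed to be int32 placeholders of
    # shape [None, max_seq_length]; their definitions (and the enclosing
    # graph/session setup) are assumptions of this fragment.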

    # restore the fine-tuned weights from the latest checkpoint
    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(model_dir))

tokenizer = tokenization.FullTokenizer(
    vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=True)


@app.route('/class_predict_service', methods=['GET', 'POST'])
def class_predict_service():
    """
    do online prediction. each time make prediction for one instance.
    you can change to a batch if you want.
    :param line: a list. element is: [dummy_label,text_a,text_b]
    :return:
    """
    def convert(line):
        feature = convert_single_example(0, line, label_list, max_seq_length,
                                         tokenizer)
        # reshape to batch size 1 for single-instance prediction
        input_ids = np.reshape([feature.input_ids], (1, max_seq_length))
        input_mask = np.reshape([feature.input_mask], (1, max_seq_length))
        return input_ids, input_mask
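
    # A minimal sketch of the rest of the handler, assuming the text arrives
    # as a form/query field named 'text_a' and that flask's request and
    # jsonify are imported (both are assumptions of this sketch):
    text_a = request.values.get('text_a', '')
    input_ids, input_mask = convert(['0', text_a, None])
    probs = sess.run(probabilities,
                     feed_dict={input_ids_p: input_ids,
                                input_mask_p: input_mask})
    return jsonify({'probabilities': probs[0].tolist()})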


# Hypothetical enclosing definition for this corpus-preprocessing fragment;
# the signature is reconstructed from the names the fragment uses.
def dump_tokenized_corpus(df_data, file_set_type, vocab_file):
    out_dir = os.path.join(DATA_OUTPUT_DIR, file_set_type)
    text_f = open(os.path.join(out_dir, "text.txt"), "w", encoding='utf-8')
    token_in_f = open(os.path.join(out_dir, "token_in.txt"), "w",
                      encoding='utf-8')
    token_in_not_UNK_f = open(os.path.join(out_dir, "token_in_not_UNK.txt"),
                              "w", encoding='utf-8')

    # Processing
    bert_tokenizer = tokenization.FullTokenizer(
        vocab_file=vocab_file, do_lower_case=True)  # initialize the BERT tokenizer
    # feature
    text = '\n'.join(df_data.item)
    text_tokened = df_data.item.apply(bert_tokenizer.tokenize)
    text_tokened = '\n'.join([' '.join(row) for row in text_tokened])
    text_tokened_not_UNK = df_data.item.apply(bert_tokenizer.tokenize_not_UNK)
    text_tokened_not_UNK = '\n'.join(
        [' '.join(row) for row in text_tokened_not_UNK])
    # Labels: keep only the first three (高中 / 学科 / 一级知识点, i.e. high
    # school / subject / first-level knowledge point); remove the list slice
    # if you want all labels.
    predicate_list = df_data.labels.apply(lambda x: x.split())
    predicate_list_str = '\n'.join([' '.join(row) for row in predicate_list])

    print(f'datasize: {len(df_data)}')
    text_f.write(text)
    token_in_f.write(text_tokened)
    token_in_not_UNK_f.write(text_tokened_not_UNK)
    for f_ in (text_f, token_in_f, token_in_not_UNK_f):
        f_.close()
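

# For reference: FullTokenizer.tokenize applies basic tokenization
# (lower-casing here, since do_lower_case=True) followed by WordPiece, which
# splits out-of-vocabulary words into subword pieces; per the
# WordpieceTokenizer docstring, "unaffable" -> ["un", "##aff", "##able"].
# tokenize_not_UNK is specific to this repo's tokenization module (not stock
# BERT); judging by its name and the file it feeds, it presumably keeps only
# tokens that do not map to [UNK].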
def predict_bert(m_type, data_path, train_data_file, test_data_file):

    parameter_setting(m_type=m_type,
                      data_path=data_path,
                      train_file=train_data_file,
                      test_file=test_data_file)

    checkpoint_file = tf.train.latest_checkpoint(FLAGS.model_output_dir)

    processor = DomainClf_Processor()
    label_list = processor.get_labels()

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=None,
        master=FLAGS.master,
        model_dir=FLAGS.model_output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host),
        session_config=config)  # `config` (a tf.ConfigProto) is assumed to be defined elsewhere in this module

    model_fn = model_fn_builder(bert_config=bert_config,
                                num_labels=len(label_list),
                                init_checkpoint=checkpoint_file,
                                learning_rate=FLAGS.learning_rate,
                                num_train_steps=10000,  # only used in training mode
                                num_warmup_steps=1000,  # only used in training mode
                                use_tpu=FLAGS.use_tpu,
                                use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    predict_examples = processor.get_test_examples(FLAGS.data_dir)
    num_actual_predict_examples = len(predict_examples)

    # overwrites any previous predict file
    predict_file = os.path.join(FLAGS.model_output_dir, "predict.tf_record")
    file_based_convert_examples_to_features(predict_examples, label_list,
                                            FLAGS.max_seq_length, tokenizer,
                                            predict_file)

    tf.logging.info("***** Running prediction*****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    len(predict_examples), num_actual_predict_examples,
                    len(predict_examples) - num_actual_predict_examples)
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_drop_remainder = FLAGS.use_tpu
    predict_input_fn = file_based_input_fn_builder(
        input_file=predict_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=predict_drop_remainder)

    result = estimator.predict(input_fn=predict_input_fn)

    output_predict_file = os.path.join(FLAGS.model_output_dir,
                                       FLAGS.test_file.split('.')[0] + ".tsv")
    output_label_file = os.path.join(
        FLAGS.model_output_dir,
        FLAGS.test_file.split('.')[0] + "_labels.tsv")

    s = set()
    result_list = []
    with tf.gfile.GFile(output_predict_file, "w") as writer, \
            open(output_label_file, "w") as f:
        num_written_lines = 0
        tf.logging.info("***** Predict results *****")
        for (i, prediction) in enumerate(result):
            probabilities = prediction["probabilities"]
            assert len(probabilities) == len(label_list)
            if i >= num_actual_predict_examples:
                break
            output_line = "\t".join(
                str(class_probability)
                for class_probability in probabilities) + "\n"

            result_list.append(output_line.strip())

            writer.write(output_line)
            num_written_lines += 1

            # labels
            lbl_id = np.argmax(np.asarray(probabilities))
            f.write(label_list[lbl_id] + "\n")
            s.update([label_list[lbl_id]])

    write_list_to_file(
        os.path.join(base_dir, 'result/stacking/bert',
                     str(FLAGS.train_file.split('.')[0]),
                     FLAGS.test_file.split('.')[0] + '.tsv'), result_list)
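

# Example invocation of predict_bert (a sketch; the m_type value and the
# paths/filenames are placeholders, not taken from the source):
#
#   predict_bert(m_type='bert',
#                data_path='data/domain_clf/',
#                train_data_file='train.txt',
#                test_data_file='test.txt')
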
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "baidu_95": Baidu_95_Multi_Label_Classification_Processor,
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name]()

    label_list = processor.get_labels()
    label_length = len(label_list)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
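        # Worked example (synthetic numbers): 10,000 training examples,
        # train_batch_size=32 and num_train_epochs=3 give
        #   num_train_steps = int(10000 / 32 * 3) = 937,
        # and with warmup_proportion=0.1,
        #   num_warmup_steps = int(937 * 0.1) = 93.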

    model_fn = model_fn_builder(
        bert_config=bert_config,
        num_labels=len(label_list),
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        file_based_convert_examples_to_features(
            train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            label_length=label_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_eval:
        eval_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_eval_examples = len(eval_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on. These do NOT count towards the metric (all tf.metrics
            # support a per-instance weight, and these get a weight of 0.0).
            while len(eval_examples) % FLAGS.eval_batch_size != 0:
                eval_examples.append(PaddingInputExample())
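            # e.g. 1,003 eval examples with eval_batch_size=8: five
            # PaddingInputExamples are appended so that 1,008 % 8 == 0.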

        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        file_based_convert_examples_to_features(
            eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(eval_examples), num_actual_eval_examples,
                        len(eval_examples) - num_actual_eval_examples)
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            assert len(eval_examples) % FLAGS.eval_batch_size == 0
            eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

        eval_drop_remainder = FLAGS.use_tpu
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            seq_length=FLAGS.max_seq_length,
            label_length=label_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    if FLAGS.do_predict:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_predict_examples = len(predict_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on.
            while len(predict_examples) % FLAGS.predict_batch_size != 0:
                predict_examples.append(PaddingInputExample())

        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        file_based_convert_examples_to_features(predict_examples, label_list,
                                                FLAGS.max_seq_length, tokenizer,
                                                predict_file)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = FLAGS.use_tpu
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            label_length=label_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        result = estimator.predict(input_fn=predict_input_fn)

        output_score_value_file = os.path.join(FLAGS.output_dir,
                                               "predicate_score_value.txt")
        output_predicate_predict_file = os.path.join(
            FLAGS.output_dir, "predicate_predict.txt")
        with tf.gfile.GFile(output_score_value_file, "w") as score_value_writer:
            with tf.gfile.GFile(output_predicate_predict_file, "w") as predicate_predict_writer:
                num_written_lines = 0
                tf.logging.info("***** Predict results *****")
                for (i, prediction) in enumerate(result):
                    probabilities = prediction["probabilities"]
                    if i >= num_actual_predict_examples:
                        break
                    output_line_score_value = " ".join(
                        str(class_probability)
                        for class_probability in probabilities) + "\n"
                    predicate_predict = []
                    for idx, class_probability in enumerate(probabilities):
                        if class_probability > 0.5:
                            predicate_predict.append(label_list[idx])
                    output_line_predicate_predict = " ".join(predicate_predict) + "\n"
                    predicate_predict_writer.write(output_line_predicate_predict)
                    score_value_writer.write(output_line_score_value)
                    num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples
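

# Minimal illustration of the multi-label decision rule used in do_predict
# above: every class whose probability exceeds 0.5 is emitted. This helper is
# hypothetical (for illustration only); the loop above inlines the same logic.
def labels_above_threshold(probabilities, label_list, threshold=0.5):
    return [label for prob, label in zip(probabilities, label_list)
            if prob > threshold]
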
def train_eval_bert(m_type, data_path, train_data_file, dev_data_file, seq_len,
                    epoch):
    """ parameter setting """
    parameter_setting(m_type=m_type,
                      data_path=data_path,
                      train_file=train_data_file,
                      dev_file=dev_data_file,
                      max_seq_length=seq_len,
                      EPOCH=epoch)
    """ logging setting """
    tf.logging.set_verbosity(tf.logging.INFO)  # DEBUG, INFO, WARN, ERROR, FATAL
    """ processor """
    processor = DomainClf_Processor()
    """ tokenizer """
    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)
    """ bert config """
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    """ build output folder """
    tf.gfile.MakeDirs(FLAGS.model_output_dir)
    """ get label """
    label_list = processor.get_labels()
    """ tokenizer """
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)
    """ train examples"""
    train_examples = processor.get_train_examples(FLAGS.data_dir)

    num_batch_single_epoch = int(len(train_examples) / FLAGS.train_batch_size)
    num_warmup_steps = int(num_batch_single_epoch * FLAGS.warmup_proportion)
    print("num_warmup_steps", num_warmup_steps)
    """
    Train Setting
    """
    train_file = os.path.join(FLAGS.model_output_dir, "train.tf_record")
    file_based_convert_examples_to_features(train_examples, label_list,
                                            FLAGS.max_seq_length, tokenizer,
                                            train_file)

    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num examples = %d", len(train_examples))
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_batch_single_epoch)

    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    # training is driven by tf.estimator.train_and_evaluate below rather than
    # a bare estimator.train call
    """
    run setting
    """
    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=None,
        master=FLAGS.master,
        model_dir=FLAGS.model_output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host),
        session_config=config)  # `config` (a tf.ConfigProto) is assumed to be defined elsewhere in this module
    """ model fn for estimator"""
    model_fn = model_fn_builder(bert_config=bert_config,
                                num_labels=len(label_list),
                                init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=FLAGS.learning_rate,
                                num_train_steps=num_batch_single_epoch,
                                num_warmup_steps=num_warmup_steps,
                                use_tpu=FLAGS.use_tpu,
                                use_one_hot_embeddings=FLAGS.use_tpu)
    """estimator setting """
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)
    """
    Eval Setting
    """
    """ eval examples """
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_eval_examples = len(eval_examples)

    eval_file = os.path.join(FLAGS.model_output_dir, "eval.tf_record")
    file_based_convert_examples_to_features(eval_examples, label_list,
                                            FLAGS.max_seq_length, tokenizer,
                                            eval_file)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    len(eval_examples), num_actual_eval_examples,
                    len(eval_examples) - num_actual_eval_examples)
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False)
    """
    early stopping hook
    """
    early_stopping = tf.contrib.estimator.stop_if_no_increase_hook(
        estimator,
        metric_name='eval_accuracy',  # accuracy is maximized, so stop when it stops increasing
        max_steps_without_increase=FLAGS.iterations_per_loop * 3,
        min_steps=num_batch_single_epoch)  # e.g. num_batch_single_epoch * 5 to force more epochs before stopping
    """
    train and eval
    """
    train_spec = tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping])
    eval_spec = tf.estimator.EvalSpec(
        eval_input_fn, exporters=None,
        throttle_secs=1)  # throttle_secs: minimum seconds between consecutive evaluations
    result_eval = tf.estimator.train_and_evaluate(estimator, train_spec,
                                                  eval_spec)

    print('\n\nresult eval:', result_eval)
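

# Example invocation of train_eval_bert (a sketch; m_type and the
# paths/filenames are placeholders, not taken from the source):
#
#   train_eval_bert(m_type='bert',
#                   data_path='data/domain_clf/',
#                   train_data_file='train.txt',
#                   dev_data_file='dev.txt',
#                   seq_len=128,
#                   epoch=3)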