def getPrediction(in_sentences):
    """Run the trained estimator over a list of sentences and return its raw predictions."""
    ret = []
    try:
        labels = [0, 1]
        # guid="" and label=0 are dummy placeholders; labels are ignored at prediction time.
        input_examples = [run_classifier.InputExample(guid="", text_a=x, text_b=None, label=0)
                          for x in in_sentences]
        input_features = run_classifier.file_based_convert_examples_to_features(
            input_examples, label_list, MAX_SEQ_LENGTH, tokenizer, 'predict_features.TFRecord')
        # The original notebook ran Colab shell magic here (!gsutil cp ...);
        # tf.gfile.Copy performs the same upload from plain Python.
        tf.gfile.Copy('predict_features.TFRecord',
                      DATA_CACHE + '/predict_features.TFRecord', overwrite=True)
        predict_input_fn = run_classifier.file_based_input_fn_builder(
            input_file=DATA_CACHE + '/predict_features.TFRecord',
            seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=True)
        predictions = estimator.predict(predict_input_fn)
        #ret = [(sentence, prediction['probabilities'], labels[prediction['labels']]) for sentence, prediction in zip(in_sentences, predictions)]
        for p in predictions:
            ret.append(p)
    except IndexError:
        # If iterating the predictions raises IndexError, return what was collected so far.
        return ret
    return ret
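
A minimal usage sketch for getPrediction, assuming the notebook has already defined estimator, tokenizer, label_list, MAX_SEQ_LENGTH, and DATA_CACHE (the function reads them as globals); the example sentences are hypothetical:

sample_sentences = [
    "This is a clearly positive sentence.",
    "This is a clearly negative sentence.",
]
# Each returned element is the estimator's raw prediction dict; the
# "probabilities" entry holds one softmax score per label.
for sentence, pred in zip(sample_sentences, getPrediction(sample_sentences)):
    print(sentence, pred["probabilities"])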
Example #2
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "ylilauta": YlilautaProcessor,
        "yle": YleProcessor,
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True."
        )

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size *
            FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = model_fn_builder(bert_config=bert_config,
                                num_labels=len(label_list),
                                init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=FLAGS.learning_rate,
                                num_train_steps=num_train_steps,
                                num_warmup_steps=num_warmup_steps,
                                use_tpu=FLAGS.use_tpu,
                                use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        file_based_convert_examples_to_features(train_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, train_file)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        num_actual_eval_examples = len(eval_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on. These do NOT count towards the metric (all tf.metrics
            # support a per-instance weight, and these get a weight of 0.0).
            while len(eval_examples) % FLAGS.eval_batch_size != 0:
                eval_examples.append(PaddingInputExample())

        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        file_based_convert_examples_to_features(eval_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, eval_file)

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(eval_examples), num_actual_eval_examples,
                        len(eval_examples) - num_actual_eval_examples)
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            assert len(eval_examples) % FLAGS.eval_batch_size == 0
            eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

        eval_drop_remainder = True if FLAGS.use_tpu else False
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
                # output hyperparameter values and results
                print('\t'.join([
                    str(i) for i in [
                        'DEV-RESULT', 'init_checkpoint', FLAGS.init_checkpoint,
                        'data_dir', FLAGS.data_dir, 'max_seq_length',
                        FLAGS.max_seq_length, 'batch_size',
                        FLAGS.train_batch_size, 'learning_rate',
                        FLAGS.learning_rate, 'num_train_epochs',
                        FLAGS.num_train_epochs, key, result[key]
                    ]
                ]))

    if FLAGS.do_predict:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_predict_examples = len(predict_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on.
            while len(predict_examples) % FLAGS.predict_batch_size != 0:
                predict_examples.append(PaddingInputExample())

        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        file_based_convert_examples_to_features(predict_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, predict_file)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = True if FLAGS.use_tpu else False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        result = estimator.predict(input_fn=predict_input_fn)

        output_predict_file = os.path.join(FLAGS.output_dir,
                                           "test_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as writer:
            num_written_lines = 0
            tf.logging.info("***** Predict results *****")
            for (i, prediction) in enumerate(result):
                probabilities = prediction["probabilities"]
                if i >= num_actual_predict_examples:
                    break
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in probabilities) + "\n"
                writer.write(output_line)
                num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples
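
A script built around a main like this usually ends with a TF1-style entry point. A minimal sketch, where the exact set of required flags is an assumption inferred from the flags this main reads:

if __name__ == "__main__":
    # tf.flags is TF1's alias for absl.flags; the flag names below mirror
    # the FLAGS referenced above, and marking them required is an assumption.
    tf.flags.mark_flag_as_required("data_dir")
    tf.flags.mark_flag_as_required("task_name")
    tf.flags.mark_flag_as_required("vocab_file")
    tf.flags.mark_flag_as_required("bert_config_file")
    tf.flags.mark_flag_as_required("output_dir")
    tf.app.run()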
Example #3
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "cola": run_classifier.ColaProcessor,
        "mnli": run_classifier.MnliProcessor,
        "mrpc": run_classifier.MrpcProcessor,
    }

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = create_tokenizer_from_hub_module(FLAGS.bert_hub_module_handle)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size *
            FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = model_fn_builder(
        num_labels=len(label_list),
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        bert_hub_module_handle=FLAGS.bert_hub_module_handle)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        train_features = run_classifier.convert_examples_to_features(
            train_examples, label_list, FLAGS.max_seq_length, tokenizer)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = run_classifier.input_fn_builder(
            features=train_features,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        eval_features = run_classifier.convert_examples_to_features(
            eval_examples, label_list, FLAGS.max_seq_length, tokenizer)

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d", len(eval_examples))
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            # Eval will be slightly WRONG on the TPU because it will truncate
            # the last batch.
            eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)

        eval_drop_remainder = True if FLAGS.use_tpu else False
        eval_input_fn = run_classifier.input_fn_builder(
            features=eval_features,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    if FLAGS.do_predict:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        if FLAGS.use_tpu:
            # Discard batch remainder if running on TPU
            n = len(predict_examples)
            predict_examples = predict_examples[:(
                n - n % FLAGS.predict_batch_size)]

        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        run_classifier.file_based_convert_examples_to_features(
            predict_examples, label_list, FLAGS.max_seq_length, tokenizer,
            predict_file)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d", len(predict_examples))
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_input_fn = run_classifier.file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=FLAGS.use_tpu)

        result = estimator.predict(input_fn=predict_input_fn)

        output_predict_file = os.path.join(FLAGS.output_dir,
                                           "test_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as writer:
            tf.logging.info("***** Predict results *****")
            for prediction in result:
                probabilities = prediction["probabilities"]
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in probabilities) + "\n"
                writer.write(output_line)
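
This example depends on a create_tokenizer_from_hub_module helper that is not shown. A sketch of what it can look like, following the pattern in BERT's run_classifier_with_tfhub.py; it assumes tensorflow_hub is imported as hub and BERT's tokenization module is available:

def create_tokenizer_from_hub_module(bert_hub_module_handle):
    """Fetch the vocab file and casing info from the TF-Hub module."""
    with tf.Graph().as_default():
        bert_module = hub.Module(bert_hub_module_handle)
        tokenization_info = bert_module(signature="tokenization_info",
                                        as_dict=True)
        with tf.Session() as sess:
            vocab_file, do_lower_case = sess.run(
                [tokenization_info["vocab_file"],
                 tokenization_info["do_lower_case"]])
    return tokenization.FullTokenizer(vocab_file=vocab_file,
                                      do_lower_case=do_lower_case)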
Example #4
def classify(
    data_dir:str,#The input data dir. Should contain the .tsv files (or other data files) 
    bert_config_file:str,#The config json file corresponding to the pre-trained BERT model.
    vocab_file:str,#The vocabulary file that the BERT model was trained on.
    output_dir:str,#The output directory where the model checkpoints will be written.

    # optional parameters
    task_name:str='customer_task',
    labels:list=None,#a list of all labels
    init_checkpoint:str=None,#Initial checkpoint (usually from a pre-trained BERT model).
    do_lower_case:bool=True,#Whether to lower case the input text. 
    #Should be True for uncased models and False for cased models.

    max_seq_length:int=128,#The maximum total input sequence length after WordPiece tokenization. 
    #Sequences longer than this will be truncated, and sequences shorter than this will be padded.

    do_train:bool=False,#Whether to run training.
    do_eval:bool=False,#Whether to run eval on the dev set.
    do_predict:bool=False,#Whether to run the model in inference mode on the test set.
    train_batch_size:int=32,#Total batch size for training.
    eval_batch_size:int=8,#Total batch size for eval.
    predict_batch_size:int=5,#Total batch size for predict.
    learning_rate:float=5e-5,#The initial learning rate for Adam.
    num_train_epochs:float=3.0,#Total number of training epochs to perform.
    warmup_proportion:float=0.1,#Proportion of training to perform linear learning rate warmup for. 
    #E.g., 0.1 = 10% of training.

    save_checkpoints_steps:int=1000,#How often to save the model checkpoint.
    iterations_per_loop:int=1000,#How many steps to make in each estimator call.
    use_tpu:bool=False,#Whether to use TPU or GPU/CPU.
    tpu_name:str=None,#The Cloud TPU to use for training. This should be either the name 
    #used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.

    tpu_zone:str=None,#[Optional] GCE zone where the Cloud TPU is located in. If not 
    #specified, we will attempt to automatically detect the GCE project from metadata.

    gcp_project:str=None,#[Optional] Project name for the Cloud TPU-enabled project. If not 
    #specified, we will attempt to automatically detect the GCE project from metadata.

    master:str=None,#[Optional] TensorFlow master URL.
    num_tpu_cores:int=8#Only used if `use_tpu` is True. Total number of TPU cores to use.
    ):
    tf.logging.set_verbosity(tf.logging.INFO)
    processors={
        "cola": brc.ColaProcessor,
        "mnli": brc.MnliProcessor,
        "mrpc": brc.MrpcProcessor,
        "xnli": brc.XnliProcessor,
    }
    # Validate that do_lower_case matches the casing of init_checkpoint
    bert.tokenization.validate_case_matches_checkpoint(do_lower_case,init_checkpoint)
    # At least one of do_train, do_eval, and do_predict must be True
    if not do_train and not do_eval and not do_predict:
        raise ValueError("At least one of 'do_train', 'do_eval' or 'do_predict' must be True.")
    # Load the pre-trained BERT model's configuration file
    bert_config=bert.modeling.BertConfig.from_json_file(bert_config_file)
    # Check the maximum sequence length
    if max_seq_length>bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model was only trained up to sequence length %d"%
            (max_seq_length,bert_config.max_position_embeddings))
    # Make sure the output directory exists
    tf.gfile.MakeDirs(output_dir)
    # Pick the processor class matching the task name
    taskName=task_name.lower()
    if taskName=='customer_task':
        processor=CUserLabelTaskProcessor(labels)
    else:
        if taskName not in processors:
            raise ValueError('Task not found: %s'%(taskName))
        processor=processors[taskName]()
    label_list=processor.get_labels()
    #print(label_list)
    # Set up the tokenizer
    tokenizer=bert.tokenization.FullTokenizer(vocab_file=vocab_file,do_lower_case=do_lower_case)
    # Accelerator (TPU) setup
    tpu_cluster_resolver=None
    if use_tpu and tpu_name:
        tpu_cluster_resolver=tf.contrib.cluster_resolver.TPUClusterResolver(
            tpu_name,zone=tpu_zone,project=gcp_project)
    is_per_host=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config=tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=master,
        model_dir=output_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=iterations_per_loop,
            num_shards=num_tpu_cores,
            per_host_input_for_training=is_per_host))
    # Training parameters
    train_examples=None
    num_train_steps=None
    num_warmup_steps=None
    if do_train:
        train_examples=processor.get_train_examples(data_dir)
        num_train_steps=int(len(train_examples)/train_batch_size*num_train_epochs)
        num_warmup_steps=int(num_train_steps*warmup_proportion)
    model_fn=brc.model_fn_builder(
        bert_config=bert_config,
        num_labels=len(label_list),
        init_checkpoint=init_checkpoint,
        learning_rate=learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=use_tpu,
        use_one_hot_embeddings=use_tpu)
    # Use the TPU when requested.
    # If TPU is not available, this will fall back to normal Estimator on CPU or GPU.
    estimator=tf.contrib.tpu.TPUEstimator(
        use_tpu=use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=train_batch_size,
        eval_batch_size=eval_batch_size,
        predict_batch_size=predict_batch_size)
    if do_train:
        train_file=os.path.join(output_dir,"train.tf_record")
        brc.file_based_convert_examples_to_features(train_examples, label_list, max_seq_length, tokenizer, train_file)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn=brc.file_based_input_fn_builder(input_file=train_file,seq_length=max_seq_length,is_training=True,drop_remainder=True)
        estimator.train(input_fn=train_input_fn,max_steps=num_train_steps)
    if do_eval:
        eval_examples=processor.get_dev_examples(data_dir)
        num_actual_eval_examples=len(eval_examples)
        if use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on. These do NOT count towards the metric (all tf.metrics
            # support a per-instance weight, and these get a weight of 0.0).
            while len(eval_examples)%eval_batch_size!=0:
                eval_examples.append(brc.PaddingInputExample())
        eval_file=os.path.join(output_dir,"eval.tf_record")
        brc.file_based_convert_examples_to_features(eval_examples, label_list, max_seq_length, tokenizer, eval_file)
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",len(eval_examples), num_actual_eval_examples,len(eval_examples) - num_actual_eval_examples)
        tf.logging.info("  Batch size = %d", eval_batch_size)
        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the number of steps.
        if use_tpu:
            assert len(eval_examples)%eval_batch_size==0
            eval_steps=int(len(eval_examples)//eval_batch_size)
        eval_drop_remainder = True if use_tpu else False
        eval_input_fn=brc.file_based_input_fn_builder(
            input_file=eval_file,seq_length=max_seq_length,is_training=False,drop_remainder=eval_drop_remainder)
        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
        output_eval_file = os.path.join(output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
    if do_predict:
        predict_examples = processor.get_test_examples(data_dir)
        num_actual_predict_examples = len(predict_examples)
        if use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on.
            while len(predict_examples) % predict_batch_size != 0:
                predict_examples.append(brc.PaddingInputExample())
        predict_file = os.path.join(output_dir, "predict.tf_record")
        brc.file_based_convert_examples_to_features(predict_examples,label_list,max_seq_length,tokenizer,predict_file)
        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",len(predict_examples),
                       num_actual_predict_examples,len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", predict_batch_size)
        predict_drop_remainder = True if use_tpu else False
        predict_input_fn = brc.file_based_input_fn_builder(
            input_file=predict_file,seq_length=max_seq_length,is_training=False,drop_remainder=predict_drop_remainder)
        result = estimator.predict(input_fn=predict_input_fn)
        output_predict_file = os.path.join(output_dir, "test_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as writer:
            num_written_lines = 0
            tf.logging.info("***** Predict results *****")
            for (i, prediction) in enumerate(result):
                probabilities = prediction["probabilities"]
                if i >= num_actual_predict_examples:
                    break
                output_line = "\t".join(str(class_probability)for class_probability in probabilities) + "\n"
                writer.write(output_line)
                num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples
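
A usage sketch for classify; every path below is a hypothetical placeholder, and the label list is only needed for the custom task handled by CUserLabelTaskProcessor:

classify(
    data_dir="/path/to/data",                      # hypothetical; holds the task's .tsv files
    bert_config_file="/path/to/bert_config.json",  # hypothetical
    vocab_file="/path/to/vocab.txt",               # hypothetical
    output_dir="/path/to/output",                  # hypothetical
    task_name="customer_task",
    labels=["0", "1"],                             # assumed label set
    init_checkpoint="/path/to/bert_model.ckpt",    # hypothetical
    do_train=True,
    do_eval=True,
)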
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "cola": run_classifier.ColaProcessor,
        "mnli": run_classifier.MnliProcessor,
        "mrpc": run_classifier.MrpcProcessor,
        "xnli": run_classifier.XnliProcessor,
        "type": TypeProcessor,
    }
    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                FLAGS.init_checkpoint)
    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError("At least one of `do_train`, `do_eval` or `do_predict' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
                "Cannot use sequence length %d because the BERT model "
                "was only trained up to sequence length %d" %
                (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(
            vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
                len(train_examples) / FLAGS.batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = model_fn_builder(
            bert_config=bert_config,
            num_labels=len(label_list),
            init_checkpoint=FLAGS.init_checkpoint,
            learning_rate=FLAGS.learning_rate,
            num_train_steps=num_train_steps,
            num_warmup_steps=num_warmup_steps,
            use_tpu=FLAGS.use_tpu,
            use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir,
        save_summary_steps=FLAGS.save_summary_steps,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps)
    
    # Unlike TPUEstimator, tf.estimator.Estimator takes no per-mode batch
    # sizes; the batch size is passed to the input functions via `params`.
    estimator = tf.estimator.Estimator(
            model_fn=model_fn,
            config=run_config,
            params={"batch_size" : FLAGS.batch_size})
    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        run_classifier.file_based_convert_examples_to_features(
                train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = run_classifier.file_based_input_fn_builder(
                input_file=train_file,
                seq_length=FLAGS.max_seq_length,
                is_training=True,
                drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        num_actual_eval_examples = len(eval_examples)

        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        run_classifier.file_based_convert_examples_to_features(
                eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                                        len(eval_examples), num_actual_eval_examples,
                                        len(eval_examples) - num_actual_eval_examples)
        tf.logging.info("  Batch size = %d", FLAGS.batch_size)

        # eval_steps = None tells the estimator to run through the entire set.
        # Only TPU evaluation would require an explicit number of steps, and
        # this script runs on a plain Estimator.
        eval_steps = None

        eval_drop_remainder = True if FLAGS.use_tpu else False
        eval_input_fn = run_classifier.file_based_input_fn_builder(
                input_file=eval_file,
                seq_length=FLAGS.max_seq_length,
                is_training=False,
                drop_remainder=eval_drop_remainder)

        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    if FLAGS.do_predict:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_predict_examples = len(predict_examples)

        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        run_classifier.file_based_convert_examples_to_features(
                predict_examples, label_list, FLAGS.max_seq_length, tokenizer,
                predict_file)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                                        len(predict_examples), num_actual_predict_examples,
                                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.batch_size)

        predict_drop_remainder = True if FLAGS.use_tpu else False
        predict_input_fn = run_classifier.file_based_input_fn_builder(
                input_file=predict_file,
                seq_length=FLAGS.max_seq_length,
                is_training=False,
                drop_remainder=predict_drop_remainder)

        result = estimator.predict(input_fn=predict_input_fn)

        output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as writer:
            num_written_lines = 0
            tf.logging.info("***** Predict results *****")
            for (i, prediction) in enumerate(result):
                probabilities = prediction["probabilities"]
                if i >= num_actual_predict_examples:
                    break
                output_line = "\t".join(
                        str(class_probability)
                        for class_probability in probabilities) + "\n"
                writer.write(output_line)
                num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples
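
The switch from TPUEstimator to tf.estimator.Estimator above works because the params dict given to the Estimator constructor is forwarded to any input_fn that declares a params argument, which is how file_based_input_fn_builder reads its batch size. A minimal sketch of that contract; make_input_fn and build_dataset are illustrative names, not library API:

def make_input_fn(build_dataset):
    def input_fn(params):
        # tf.estimator.Estimator passes its constructor's `params` dict here,
        # so FLAGS.batch_size arrives as params["batch_size"].
        return build_dataset().batch(params["batch_size"])
    return input_fn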