def __init__(self, args):
        from tensorflow.estimator import RunConfig, Estimator
        # load parameters
        self.layer_indexes = args.layer_indexes
        self.ckpt_name = args.ckpt_name
        self.config_name = args.config_name
        self.vocab_file = args.vocab_file
        self.do_lower_case = args.do_lower_case
        self.batch_size = args.batch_size
        self.max_seq_len = args.max_seq_len
        self.gpu_memory_fraction = args.gpu_memory_fraction
        self.xla = args.xla

        # load bert config & construct
        tf.logging.info("load bert config & construct ...")
        self.bert_config = modeling.BertConfig.from_json_file(self.config_name)
        model_fn = model_fn_builder(bert_config=self.bert_config,
                                    init_checkpoint=self.ckpt_name,
                                    layer_indexes=self.layer_indexes)

        # construct estimator
        tf.logging.info("load estimator ...")
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = self.gpu_memory_fraction
        config.log_device_placement = False
        if self.xla:
            config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
        self.estimator = Estimator(model_fn=model_fn,
                                   config=RunConfig(session_config=config),
                                   params={'batch_size': self.batch_size})

        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=self.vocab_file, do_lower_case=self.do_lower_case)
        tf.logging.info("initialization done.")
Example #2
@contextlib.contextmanager
def _remove_metrics(estimator: tf_estimator.Estimator,
                    metrics_to_remove: Union[List[str], Callable[[str], bool]]):
  """Modifies the Estimator to make its model_fn return less metrics in EVAL.

  Note that this only removes the metrics from the
  EstimatorSpec.eval_metric_ops. It does not remove them from the graph or
  undo any side-effects that they might have had (e.g. modifications to
  METRIC_VARIABLES collections).

  This is useful when you use py_func, streaming metrics, or other metrics that
  are incompatible with TFMA in your trainer. To keep these metrics in your trainer
  (so they still show up in Tensorboard) and still use TFMA, you can call
  remove_metrics on your Estimator before calling export_eval_savedmodel.

  This is a context manager, so it can be used like:
    with _remove_metrics(estimator, ['streaming_auc']):
      tfma.export.export_eval_savedmodel(estimator, ...)

  Args:
    estimator: tf.estimator.Estimator to modify. Will be mutated in place.
    metrics_to_remove: List of names of metrics to remove, or a callable that
      returns True for each metric name that should be removed.

  Yields:
    Nothing.
  """
  old_call_model_fn = estimator._call_model_fn  # pylint: disable=protected-access

  def wrapped_call_model_fn(unused_self, features, labels, mode, config):
    result = old_call_model_fn(features, labels, mode, config)
    if mode == tf_estimator.ModeKeys.EVAL:
      filtered_eval_metric_ops = {}
      for k, v in result.eval_metric_ops.items():
        if isinstance(metrics_to_remove, collections.abc.Iterable):
          if k in metrics_to_remove:
            continue
        elif callable(metrics_to_remove):
          if metrics_to_remove(k):
            continue
        filtered_eval_metric_ops[k] = v
      result = result._replace(eval_metric_ops=filtered_eval_metric_ops)
    return result

  estimator._call_model_fn = types.MethodType(  # pylint: disable=protected-access
      wrapped_call_model_fn, estimator)

  yield

  estimator._call_model_fn = old_call_model_fn  # pylint: disable=protected-access
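
The docstring only demonstrates the list form; since metrics_to_remove may also be a predicate (see the callable branch above), here is a hedged usage sketch with a hypothetical name filter. It assumes tfma is imported and that export_dir_base and eval_input_receiver_fn are defined elsewhere:

# Sketch only, not part of the source above.
with _remove_metrics(estimator, lambda name: name.startswith('streaming_')):
    tfma.export.export_eval_savedmodel(
        estimator=estimator,
        export_dir_base=export_dir_base,
        eval_input_receiver_fn=eval_input_receiver_fn)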
Example #3
def _create_estimator(
        num_gpus=1,
        params=DynamicBatchSizeParams(),
        model=None):
    if model is None:
        model = BertMultiTask(params=params)
    model_fn = model.get_model_fn(warm_start=False)

    dist_strategy = tf.contrib.distribute.MirroredStrategy(
        num_gpus=int(num_gpus),
        cross_tower_ops=tf.contrib.distribute.AllReduceCrossDeviceOps(
            'nccl', num_packs=int(num_gpus)))

    run_config = tf.estimator.RunConfig(
        train_distribute=dist_strategy,
        eval_distribute=dist_strategy,
        log_step_count_steps=params.log_every_n_steps)

    # ws = make_warm_start_setting(params)

    estimator = Estimator(
        model_fn,
        model_dir=params.ckpt_dir,
        params=params,
        config=run_config)
    return estimator
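
A hedged usage sketch for the factory above; train_eval_input_fn and params.train_steps are borrowed from a later example in this listing and may differ in the actual project:

# Sketch only: hypothetical wiring of the estimator to the project's input_fn.
params = DynamicBatchSizeParams()
estimator = _create_estimator(num_gpus=2, params=params)
estimator.train(lambda: train_eval_input_fn(params), max_steps=params.train_steps)
estimator.evaluate(lambda: train_eval_input_fn(params, mode='eval'))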
Example #4
def ImageVectors(session):
    log('Inferring image vectors of registered images.')
    model_fn, input_fn, _ = get_img2vec_fns(session, mode=ModeKeys.PREDICT)
    estimator = Estimator(model_fn,
                          model_dir=str(get_models_home() / 'imgvecs'))
    ids = sorted([int(i) for i in os.listdir(get_data_home() / 'images')])

    for imgid, vec in tqdm(zip(ids, estimator.predict(input_fn)),
                           unit=' vecs',
                           desc='Image vectors',
                           total=len(ids)):
        row = Vector(id=imgid, vec=vec)
        session.merge(row)
        if (imgid % 1800) == 0:
            session.flush()
    session.commit()
Example #5
    def get_estimator(self):
        from tensorflow.estimator import Estimator
        from tensorflow.estimator import RunConfig
        from tensorflow.estimator import EstimatorSpec

        def model_fn(features, labels, mode, params):
            with tf.gfile.GFile(self.graph_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())

            input_names = ['input_ids', 'input_mask', 'segment_ids']

            output = tf.import_graph_def(
                graph_def,
                input_map={k + ':0': features[k]
                           for k in input_names},
                return_elements=['final_encodes:0'])

            return EstimatorSpec(mode=mode, predictions={'encodes': output[0]})

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = self.gpu_memory_fraction
        config.log_device_placement = False
        config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

        return Estimator(model_fn=model_fn,
                         config=RunConfig(session_config=config),
                         params={'batch_size': self.batch_size})
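
Because the model_fn above only replays a frozen graph, the Estimator it returns can usually run predict without a trained checkpoint (variables are freshly initialized, which is harmless here). A hedged sketch of driving it with a tf.data input_fn; make_feature_dicts is a hypothetical generator yielding one dict of 'input_ids', 'input_mask' and 'segment_ids' per example:

# Sketch only: `worker` is assumed to be an instance of the class above.
def predict_input_fn():
    dataset = tf.data.Dataset.from_generator(
        make_feature_dicts,
        output_types={'input_ids': tf.int32,
                      'input_mask': tf.int32,
                      'segment_ids': tf.int32})
    return dataset.batch(32)

for result in worker.get_estimator().predict(predict_input_fn,
                                             yield_single_examples=False):
    print(result['encodes'].shape)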
Example #6
def save(estimator: Estimator, saved_model_dir: str) -> None:
    """Save a Tensorflow estimator"""
    with TemporaryDirectory() as temporary_model_base_dir:
        export_dir = estimator.export_saved_model(temporary_model_base_dir,
                                                  _serving_input_receiver_fn)

        Path(saved_model_dir).mkdir(exist_ok=True)
        export_path = Path(export_dir.decode()).absolute()
        for path in export_path.glob('*'):
            shutil.move(str(path), saved_model_dir)
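
To load what save() exports back for inference, a TF 1.x predictor can be pointed at saved_model_dir. This is only a hedged sketch; the expected feature key depends on _serving_input_receiver_fn, which is defined elsewhere in the source, so 'inputs' and example_batch are placeholders:

# Sketch only, TF 1.x API.
from tensorflow.contrib import predictor

predict_fn = predictor.from_saved_model(saved_model_dir)
outputs = predict_fn({'inputs': example_batch})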
Example #7
    def build(self, model_fn_args, config_args):
        config = self.get_config(**config_args)
        model_fn = self.get_model_fn(**model_fn_args)

        self.estimator = Estimator(
            model_fn=model_fn,
            config=config,
            params={'batch_size': self.batch_size})

        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=self.vocab_file, do_lower_case=self.do_lower_case)
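
A hedged usage sketch for build(); the keyword dictionaries are forwarded to get_model_fn and get_config, whose signatures live elsewhere in the source, so these keys are placeholders modeled on the other examples in this listing:

# Sketch only: `server` is assumed to be an instance of the class above,
# and the argument keys are hypothetical.
server.build(
    model_fn_args={'bert_config': bert_config,
                   'init_checkpoint': ckpt_name,
                   'layer_indexes': [-1, -2]},
    config_args={'gpu_memory_fraction': 0.5})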
Example #8
def ImageEmbedding(session):
    log('Starting training of image embedding.')
    model_fn, input_fn, receiver_fn = get_img2vec_fns(session)
    estimator = Estimator(model_fn,
                          model_dir=str(get_models_home() / 'imgvecs'))
    tf.logging.set_verbosity(tf.logging.INFO)
    estimator.train(input_fn)
    estimator.export_savedmodel(str(get_models_home() / 'imgvecs'),
                                receiver_fn)
Example #9
def _create_estimator(num_gpus=1, params=DynamicBatchSizeParams(), model=None):
    if model is None:
        model = BertMultiTask(params=params)
    model_fn = model.get_model_fn(warm_start=True)

    dist_strategy = tf.distribute.MirroredStrategy()

    run_config = tf.estimator.RunConfig(
        train_distribute=dist_strategy,
        eval_distribute=dist_strategy,
        log_step_count_steps=params.log_every_n_steps)

    # ws = make_warm_start_setting(params)

    estimator = Estimator(model_fn,
                          model_dir=params.ckpt_dir,
                          params=params,
                          config=run_config)
    return estimator
Example #10
    def _initialize_estimator(self, feature_provider, **kwargs):
        self.aux_config["_feature_provider"] = feature_provider.name
        if not self.aux_config.get("class_labels"):
            class_labels = map(str, sorted(feature_provider.class_labels))
            self.aux_config["class_labels"] = list(class_labels)
        self.params["_n_out_classes"] = len(self.aux_config["class_labels"])
        self.params.update(feature_provider.embedding_params)
        #! steps==0 or epochs==0 => repeats indefinitely,
        #! => Need to perform checks against None type, not falsy values
        self.params["epochs"] = (
            kwargs.get("epochs")
            if kwargs.get("epochs") is not None
            else self.params.get("epochs")
        )
        epoch_steps, num_training_samples = feature_provider.steps_per_epoch(
            self.params["batch-size"]
        )
        self.params["epoch_steps"] = (
            self.params.get("epoch_steps") or epoch_steps
        )
        self.params["shuffle_buffer"] = (
            self.params.get("shuffle_buffer") or num_training_samples
        )
        if self.params.get("epochs") is not None:
            steps = (
                self.params["epoch_steps"] * self.params["epochs"]
            ) or None
        elif (
            kwargs.get("steps") is not None
            or self.params.get("steps") is not None
        ):
            steps = (kwargs.get("steps") or self.params.get("steps")) or None
        else:
            raise ValueError("No steps or epochs specified")
        self._estimator = Estimator(
            model_fn=self._model_fn, params=self.params, config=self.run_config
        )
        return steps
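
As an aside, the trailing "or None" above maps epochs == 0 (or steps == 0) to None, which tells Estimator.train to run until the input pipeline is exhausted. A minimal sketch of that precedence with hypothetical values:

# Sketch only: mirrors the arithmetic in _initialize_estimator.
params = {"epoch_steps": 500, "epochs": 4}
steps = (params["epoch_steps"] * params["epochs"]) or None
assert steps == 2000

params = {"epoch_steps": 500, "epochs": 0}    # epochs == 0 => train indefinitely
steps = (params["epoch_steps"] * params["epochs"]) or None
assert steps is None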
Example #11
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True."
        )

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2

    dist_strategy = tf.contrib.distribute.MirroredStrategy(
        num_gpus=FLAGS.n_gpus,
        cross_device_ops=AllReduceCrossDeviceOps('nccl',
                                                 num_packs=FLAGS.n_gpus),
        # cross_device_ops=AllReduceCrossDeviceOps('hierarchical_copy'),
    )
    log_every_n_steps = 8
    run_config = RunConfig(
        train_distribute=dist_strategy,
        # eval_distribute=dist_strategy,
        log_step_count_steps=log_every_n_steps,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps)

    # run_config = tf.contrib.tpu.RunConfig(
    #     cluster=tpu_cluster_resolver,
    #     master=FLAGS.master,
    #     model_dir=FLAGS.output_dir,
    #     save_checkpoints_steps=FLAGS.save_checkpoints_steps,
    #     tpu_config=tf.contrib.tpu.TPUConfig(
    #         iterations_per_loop=FLAGS.iterations_per_loop,
    #         num_shards=FLAGS.num_tpu_cores,
    #         per_host_input_for_training=is_per_host))

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size *
            FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = model_fn_builder(bert_config=bert_config,
                                num_labels=len(label_list),
                                init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=FLAGS.learning_rate,
                                num_train_steps=num_train_steps,
                                num_warmup_steps=num_warmup_steps,
                                use_tpu=FLAGS.use_tpu,
                                use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = Estimator(model_fn=model_fn, params={}, config=run_config)
    # estimator = tf.contrib.tpu.TPUEstimator(
    #     use_tpu=FLAGS.use_tpu,
    #     model_fn=model_fn,
    #     config=run_config,
    #     train_batch_size=FLAGS.train_batch_size,
    #     eval_batch_size=FLAGS.eval_batch_size,
    #     predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        file_based_convert_examples_to_features(train_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, train_file)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        num_actual_eval_examples = len(eval_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on. These do NOT count towards the metric (all tf.metrics
            # support a per-instance weight, and these get a weight of 0.0).
            while len(eval_examples) % FLAGS.eval_batch_size != 0:
                eval_examples.append(PaddingInputExample())

        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        file_based_convert_examples_to_features(eval_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, eval_file)

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(eval_examples), num_actual_eval_examples,
                        len(eval_examples) - num_actual_eval_examples)
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            assert len(eval_examples) % FLAGS.eval_batch_size == 0
            eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

        eval_drop_remainder = True if FLAGS.use_tpu else False
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    if FLAGS.do_predict:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_predict_examples = len(predict_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on.
            while len(predict_examples) % FLAGS.predict_batch_size != 0:
                predict_examples.append(PaddingInputExample())

        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        file_based_convert_examples_to_features(predict_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, predict_file)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = True if FLAGS.use_tpu else False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        result = estimator.predict(input_fn=predict_input_fn)

        output_predict_file = os.path.join(FLAGS.output_dir,
                                           "test_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as writer:
            num_written_lines = 0
            tf.logging.info("***** Predict results *****")
            for (i, prediction) in enumerate(result):
                probabilities = prediction["probabilities"]
                if i >= num_actual_predict_examples:
                    break
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in probabilities) + "\n"
                writer.write(output_line)
                num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples
Example #12
def main(_):

    if not os.path.exists('tmp'):
        os.mkdir('tmp')

    if FLAGS.model_dir:
        base_dir, dir_name = os.path.split(FLAGS.model_dir)
    else:
        base_dir, dir_name = None, None

    params = BaseParams()
    params.assign_problem(FLAGS.problem,
                          gpu=int(FLAGS.gpu),
                          base_dir=base_dir,
                          dir_name=dir_name)

    tf.logging.info('Checkpoint dir: %s' % params.ckpt_dir)
    time.sleep(3)

    model = BertMultiTask(params=params)
    model_fn = model.get_model_fn(warm_start=False)

    dist_strategy = tf.contrib.distribute.MirroredStrategy(
        num_gpus=int(FLAGS.gpu),
        cross_tower_ops=tf.contrib.distribute.AllReduceCrossDeviceOps(
            'nccl', num_packs=int(FLAGS.gpu)))

    run_config = tf.estimator.RunConfig(
        train_distribute=dist_strategy,
        eval_distribute=dist_strategy,
        log_step_count_steps=params.log_every_n_steps)

    # ws = make_warm_start_setting(params)

    estimator = Estimator(model_fn,
                          model_dir=params.ckpt_dir,
                          params=params,
                          config=run_config)

    if FLAGS.schedule == 'train':
        train_hook = RestoreCheckpointHook(params)

        def train_input_fn():
            return train_eval_input_fn(params)

        estimator.train(train_input_fn,
                        max_steps=params.train_steps,
                        hooks=[train_hook])

        def input_fn():
            return train_eval_input_fn(params, mode='eval')

        estimator.evaluate(input_fn=input_fn)
        params.to_json()

    elif FLAGS.schedule == 'eval':

        params.from_json()
        evaluate_func = getattr(metrics, FLAGS.eval_scheme + '_evaluate')
        print(evaluate_func(FLAGS.problem, estimator, params))

    elif FLAGS.schedule == 'predict':

        def input_fn():
            return predict_input_fn([
                '''兰心餐厅\n作为一个无辣不欢的妹子,对上海菜的偏清淡偏甜真的是各种吃不惯。
            每次出门和闺蜜越饭局都是避开本帮菜。后来听很多朋友说上海有几家特别正宗味道做
            的很好的餐厅于是这周末和闺蜜们准备一起去尝一尝正宗的本帮菜。\n进贤路是我在上
            海比较喜欢的一条街啦,这家餐厅就开在这条路上。已经开了三十多年的老餐厅了,地
            方很小,就五六张桌子。但是翻桌率比较快。二楼之前的居民间也改成了餐厅,但是在
            上海的名气却非常大。烧的就是家常菜,普通到和家里烧的一样,生意非常好,外面排
            队的比里面吃的人还要多。'''
            ],
                                    params,
                                    mode=PREDICT)

        pred = estimator.predict(input_fn=input_fn)
        for p in pred:
            print(p)
Example #13
    # Training sub-graph
    global_step = tf.train.get_global_step()
    optimizer = tf.train.AdamOptimizer()
    train = tf.group(optimizer.minimize(loss), tf.assign_add(global_step, 1))
    # ModelFnOps connects subgraphs we built to the
    # appropriate functionality.
    
    return tf.contrib.learn.ModelFnOps(mode=mode,
                                       predictions=y,
                                       loss=loss,
                                       train_op=train)




x_train, y_train, x_test, y_test = load_and_split_data()

feature_columns = [tf.contrib.layers.real_valued_column("x", dimension=9600)]

estimator = Estimator(model_fn=model, model_dir='./model/')

input_fn = tf.contrib.learn.io.numpy_input_fn({"x": x_train}, y_train)
eval_input_fn = tf.contrib.learn.io.numpy_input_fn({"x": x_test}, y_test)

# train
estimator.fit(input_fn=input_fn, steps=2000)
# Here we evaluate how well our model did. 
train_loss = estimator.evaluate(input_fn=input_fn)
eval_loss = estimator.evaluate(input_fn=eval_input_fn)
print("train loss: %r"% train_loss)
print("eval loss: %r"% eval_loss)
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tf.gfile.MakeDirs(FLAGS.output_dir)

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    tf.logging.info("*** Input Files ***")
    for input_file in input_files:
        tf.logging.info("  %s" % input_file)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2

    dist_strategy = tf.contrib.distribute.MirroredStrategy(
        num_gpus=FLAGS.n_gpus,
        cross_device_ops=AllReduceCrossDeviceOps('nccl',
                                                 num_packs=FLAGS.n_gpus),
        # cross_device_ops=AllReduceCrossDeviceOps('hierarchical_copy'),
    )
    log_every_n_steps = 8
    run_config = RunConfig(
        train_distribute=dist_strategy,
        # eval_distribute=dist_strategy,
        log_step_count_steps=log_every_n_steps,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps)

    model_fn = model_fn_builder(bert_config=bert_config,
                                init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=FLAGS.learning_rate,
                                num_train_steps=FLAGS.num_train_steps,
                                num_warmup_steps=FLAGS.num_warmup_steps,
                                use_tpu=FLAGS.use_tpu,
                                use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = Estimator(model_fn=model_fn, params={}, config=run_config)

    if FLAGS.do_train:
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        train_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=True)
        estimator.train(input_fn=train_input_fn,
                        max_steps=FLAGS.num_train_steps)

    if FLAGS.do_eval:
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)

        result = estimator.evaluate(input_fn=eval_input_fn,
                                    steps=FLAGS.max_eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
class BertEncoder:
    def __init__(self, args):
        from tensorflow.estimator import RunConfig, Estimator
        # load parameters
        self.layer_indexes = args.layer_indexes
        self.ckpt_name = args.ckpt_name
        self.config_name = args.config_name
        self.vocab_file = args.vocab_file
        self.do_lower_case = args.do_lower_case
        self.batch_size = args.batch_size
        self.max_seq_len = args.max_seq_len
        self.gpu_memory_fraction = args.gpu_memory_fraction
        self.xla = args.xla

        # load bert config & construct
        tf.logging.info("load bert config & construct ...")
        self.bert_config = modeling.BertConfig.from_json_file(self.config_name)
        model_fn = model_fn_builder(bert_config=self.bert_config,
                                    init_checkpoint=self.ckpt_name,
                                    layer_indexes=self.layer_indexes)

        # construct estimator
        tf.logging.info("load estimator ...")
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = self.gpu_memory_fraction
        config.log_device_placement = False
        if self.xla:
            config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
        self.estimator = Estimator(model_fn=model_fn,
                                   config=RunConfig(session_config=config),
                                   params={'batch_size': self.batch_size})

        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=self.vocab_file, do_lower_case=self.do_lower_case)
        tf.logging.info("initialization done.")

    def encode(self, input_sentences):
        return [
            sen['result']['layer_output_pooler']
            for sen in self._predict(input_sentences)
        ]

    def _predict(self, input_sentences):
        examples = self.read_examples(input_sentences)
        features = convert_examples_to_features(examples=examples,
                                                seq_length=self.max_seq_len,
                                                tokenizer=self.tokenizer)
        unique_id_to_feature = {}
        for feature in features:
            unique_id_to_feature[feature.unique_id] = feature

        input_fn = input_fn_builder(features=features,
                                    seq_length=self.max_seq_len)

        outputs_json = []
        for result in self.estimator.predict(input_fn,
                                             yield_single_examples=True):
            unique_id = int(result["unique_id"])
            feature = unique_id_to_feature[unique_id]
            output_json = collections.OrderedDict()
            output_json["linex_index"] = unique_id
            all_features = []
            for (i, token) in enumerate(feature.tokens):
                all_layers = []
                for (j, layer_index) in enumerate(self.layer_indexes):
                    layer_output = result["layer_output_%d" % j]
                    layers = collections.OrderedDict()
                    layers["index"] = layer_index
                    layers["values"] = [
                        round(float(x), 6)
                        for x in layer_output[i:(i + 1)].flat
                    ]
                    all_layers.append(layers)
                features = collections.OrderedDict()
                features["token"] = token
                features["layers"] = all_layers
                all_features.append(features)
            output_json["features"] = all_features
            output_json["features_pooler"] = result["layer_output_pooler"]
            output_json["result"] = result
            outputs_json.append(output_json)
        return outputs_json

    def read_examples(self, input_sentences):
        """Read a list of `InputExample`s from a list of sentence instead of an input file."""
        examples = []
        unique_id = 0
        for line in input_sentences:
            line = line.strip()
            text_a = None
            text_b = None
            m = re.match(r"^(.*) \|\|\| (.*)$", line)
            if m is None:
                text_a = line
            else:
                text_a = m.group(1)
                text_b = m.group(2)
            examples.append(
                InputExample(unique_id=unique_id, text_a=text_a,
                             text_b=text_b))
            unique_id += 1
        return examples
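
A hedged usage sketch for BertEncoder; every path and parameter below is a placeholder for a real BERT checkpoint, and read_examples also accepts "text_a ||| text_b" pairs:

# Sketch only: hypothetical arguments.
from argparse import Namespace

args = Namespace(layer_indexes=[-1, -2],
                 ckpt_name='/path/to/bert_model.ckpt',
                 config_name='/path/to/bert_config.json',
                 vocab_file='/path/to/vocab.txt',
                 do_lower_case=True,
                 batch_size=32,
                 max_seq_len=128,
                 gpu_memory_fraction=0.5,
                 xla=False)

encoder = BertEncoder(args)
vectors = encoder.encode(['a single sentence', 'premise ||| hypothesis'])
print(len(vectors))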