예제 #1
0
def mrc():
    """Answer one machine-reading-comprehension request.

    Reads the request payload with ``getData()``, preprocesses it, converts it
    to model features (kept in memory and also written to ``eval.tf_record``
    under ``mrc_inference_config["output_dir"]``), runs ``estimator.predict``
    over that record file and post-processes the logits with
    ``mainfile.write_predictions``.

    Returns:
        The value of ``sendResponse`` for ``{"Answer": ...}``, where the answer
        is looked up in the predictions by the request's ``qas_id``.
    """
    data_from_post = getData()
    data = preprocess_data(data_from_post)

    eval_writer = mainfile.FeatureWriter(filename=os.path.join(
        mrc_inference_config["output_dir"], "eval.tf_record"),
                                         is_training=False)
    eval_features = []

    def append_feature(feature):
        # Keep the feature for post-processing and also serialize it to the
        # record file that the estimator reads from.
        eval_features.append(feature)
        eval_writer.process_feature(feature)

    try:
        mainfile.convert_examples_to_features(
            examples=data,
            tokenizer=tokenizer,
            max_seq_length=mrc_inference_config["max_seq_length"],
            doc_stride=mrc_inference_config["doc_stride"],
            max_query_length=mrc_inference_config["max_query_length"],
            is_training=False,
            output_fn=append_feature)
    finally:
        # Close the writer even when conversion fails so the record file
        # handle is not leaked from request to request.
        eval_writer.close()

    predict_input_fn = mainfile.input_fn_builder(
        input_file=eval_writer.filename,
        seq_length=mrc_inference_config["max_seq_length"],
        is_training=False,
        drop_remainder=False)

    all_results = []
    for result in estimator.predict(predict_input_fn,
                                    yield_single_examples=True):
        unique_id = int(result["unique_ids"])
        start_logits = [float(x) for x in result["start_logits"].flat]
        end_logits = [float(x) for x in result["end_logits"].flat]
        all_results.append(
            mainfile.RawResult(unique_id=unique_id,
                               start_logits=start_logits,
                               end_logits=end_logits))

    # NOTE(review): 20 / True / None, None, None are presumably n_best_size,
    # do_lower_case and the three output-file paths -- confirm against
    # mainfile.write_predictions.
    answer = mainfile.write_predictions(
        data, eval_features, all_results, 20,
        mrc_inference_config["max_answer_length"], True, None, None, None)
    return sendResponse({"Answer": answer.get(data_from_post.get("qas_id"))})
def _validate_squad(args, model, tokenizer):
    """Evaluate ``model`` on the SQuAD file ``args.predict_file``.

    Reads and featurizes the examples, runs the model batch by batch on the
    GPU without gradients, writes the predictions to scratch files and scores
    them with ``_calc_metric_squad``.

    Args:
        args: Namespace providing ``predict_file``, ``max_seq_length``,
            ``doc_stride``, ``max_query_length``, ``predict_batch_size``,
            ``n_best_size``, ``max_answer_length``, ``do_lower_case``,
            ``verbose_logging``, ``version_2_with_negative`` and
            ``null_score_diff_threshold``.
        model: Model called as ``model(input_ids, segment_ids, input_mask)``
            and returning ``(start_logits, end_logits)`` for the batch.
        tokenizer: Tokenizer passed to
            ``run_squad.convert_examples_to_features``.

    Returns:
        The result of ``_calc_metric_squad``; per the original inline note
        this is ``{'exact_match': exact_match, 'f1': f1}``.
    """
    # Local import: only needed for the scratch prediction files below.
    import tempfile

    eval_examples = run_squad.read_squad_examples(
        input_file=args.predict_file,
        is_training=False,
        version_2_with_negative=args.version_2_with_negative)

    eval_features = run_squad.convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=False)

    run_squad.logger.info("***** Running predictions *****")
    run_squad.logger.info("  Num orig examples = %d", len(eval_examples))
    run_squad.logger.info("  Num split examples = %d", len(eval_features))
    run_squad.logger.info("  Batch size = %d", args.predict_batch_size)

    all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                   dtype=torch.long)
    # Position of each feature in eval_features, carried through the loader so
    # every output row can be mapped back to its feature.
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    eval_data = run_squad.TensorDataset(all_input_ids, all_input_mask,
                                        all_segment_ids, all_example_index)
    # Run prediction for full data
    eval_sampler = run_squad.SequentialSampler(eval_data)
    eval_dataloader = run_squad.DataLoader(eval_data,
                                           sampler=eval_sampler,
                                           batch_size=args.predict_batch_size)

    model.eval()
    all_results = []
    run_squad.logger.info("Start evaluating")
    for input_ids, input_mask, segment_ids, example_indices in run_squad.tqdm(
            eval_dataloader, desc="Evaluating"):
        if len(all_results) % 1000 == 0:
            run_squad.logger.info("Processing example: %d", len(all_results))
        input_ids = input_ids.cuda()
        input_mask = input_mask.cuda()
        segment_ids = segment_ids.cuda()
        with torch.no_grad():
            batch_start_logits, batch_end_logits = model(
                input_ids, segment_ids, input_mask)
        for i, example_index in enumerate(example_indices):
            start_logits = batch_start_logits[i].detach().cpu().tolist()
            end_logits = batch_end_logits[i].detach().cpu().tolist()
            eval_feature = eval_features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            all_results.append(
                run_squad.RawResult(unique_id=unique_id,
                                    start_logits=start_logits,
                                    end_logits=end_logits))

    # The prediction files are only an intermediate for the metric script.
    # Writing them to a private temporary directory avoids overwriting (and
    # then deleting) same-named files in the working directory, keeps
    # concurrent runs apart, and guarantees cleanup even if scoring raises.
    with tempfile.TemporaryDirectory() as scratch_dir:
        output_prediction_file = os.path.join(scratch_dir, "predictions.json")
        output_nbest_file = os.path.join(scratch_dir,
                                         "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(scratch_dir,
                                                 "null_odds.json")
        run_squad.write_predictions(
            eval_examples, eval_features, all_results, args.n_best_size,
            args.max_answer_length, args.do_lower_case,
            output_prediction_file, output_nbest_file,
            output_null_log_odds_file, args.verbose_logging,
            args.version_2_with_negative, args.null_score_diff_threshold)

        result = _calc_metric_squad(args.predict_file, output_prediction_file)
    return result  # {'exact_match': exact_match, 'f1': f1}
예제 #3
0
def main(_):
    """Run SQuAD question answering against a Triton inference server.

    All configuration is read from ``FLAGS``. The input is either a single
    ``FLAGS.question`` / ``FLAGS.context`` pair or the SQuAD file
    ``FLAGS.predict_file``. Features are sent in batches with ``async_infer``
    (at most ``max_outstanding`` requests in flight) and the returned logits
    are collected as ``RawResult`` entries. With ``FLAGS.output_dir`` and
    ``FLAGS.predict_file`` set, latency statistics are printed and prediction
    files are written; otherwise the best answer is printed to stdout.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If neither question+context nor predict_file is given, or
            if an inference request completes with an error.
    """
    # Disable XLA lazy compilation: per the original note it causes memory
    # fragmentation for BERT, leading to OOM.
    os.environ[
        "TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"

    tf.compat.v1.logging.info("***** Configuaration *****")
    for key in FLAGS.__flags.keys():
        tf.compat.v1.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
    tf.compat.v1.logging.info("**************************")

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    # Get the Data
    if FLAGS.question and FLAGS.context:
        # Wrap the single question in the nested structure that
        # read_squad_examples receives through input_data.
        input_data = [{
            "paragraphs": [{
                "context": FLAGS.context,
                "qas": [{
                    "id": 0,
                    "question": FLAGS.question
                }]
            }]
        }]
        eval_examples = read_squad_examples(
            input_file=None,
            is_training=False,
            version_2_with_negative=FLAGS.version_2_with_negative,
            input_data=input_data)
    elif FLAGS.predict_file:
        eval_examples = read_squad_examples(
            input_file=FLAGS.predict_file,
            is_training=False,
            version_2_with_negative=FLAGS.version_2_with_negative)
    else:
        raise ValueError(
            "Either predict_file or question+context need to be defined")

    # Get Eval Features = Preprocessing
    eval_features = []

    def append_feature(feature):
        eval_features.append(feature)

    convert_examples_to_features(examples=eval_examples,
                                 tokenizer=tokenizer,
                                 max_seq_length=FLAGS.max_seq_length,
                                 doc_stride=FLAGS.doc_stride,
                                 max_query_length=FLAGS.max_query_length,
                                 is_training=False,
                                 output_fn=append_feature)

    url = FLAGS.triton_server_url
    verbose = False
    model_name = FLAGS.triton_model_name
    model_version = str(FLAGS.triton_model_version)
    batch_size = FLAGS.predict_batch_size

    triton_client = tritongrpcclient.InferenceServerClient(url, verbose)
    # NOTE(review): the two results below are never read; the calls are kept
    # because they presumably fail fast when the model/version is not served.
    model_metadata = triton_client.get_model_metadata(
        model_name=model_name, model_version=model_version)
    model_config = triton_client.get_model_config(model_name=model_name,
                                                  model_version=model_version)

    user_data = UserData()

    max_outstanding = 20
    # Number of outstanding requests
    outstanding = 0

    sent_prog = tqdm.tqdm(desc="Send Requests", total=len(eval_features))
    recv_prog = tqdm.tqdm(desc="Recv Requests", total=len(eval_features))

    def process_outstanding(do_wait, outstanding):
        """Consume one completed request if asked to wait.

        Returns the updated number of outstanding requests.
        """
        if (outstanding == 0 or do_wait is False):
            return outstanding

        # Wait for deferred items from callback functions
        (result, error, idx, start_time,
         inputs) = user_data._completed_requests.get()

        # Check the error before the result: a failed request presumably
        # arrives with result=None, and returning early on that would leave
        # `outstanding` untouched and block the drain loop forever.
        if (error is not None):
            raise ValueError(
                "Inference request {} failed: {}".format(idx, error))

        if (result is None):
            return outstanding

        stop = time.time()

        outstanding -= 1

        time_list.append(stop - start_time)

        batch_count = len(inputs[label_id_key])
        if FLAGS.trt_engine:
            cls_squad_logits = result.as_numpy("cls_squad_logits")
            try:  #when batch size > 1
                start_logits_results = np.array(
                    cls_squad_logits.squeeze()[:, :, 0])
                end_logits_results = np.array(cls_squad_logits.squeeze()[:, :,
                                                                         1])
            except IndexError:
                # With a single example squeeze() also drops the batch axis,
                # so three-axis indexing fails; index two axes and restore
                # the batch axis instead.
                start_logits_results = np.expand_dims(np.array(
                    cls_squad_logits.squeeze()[:, 0]),
                                                      axis=0)
                end_logits_results = np.expand_dims(np.array(
                    cls_squad_logits.squeeze()[:, 1]),
                                                    axis=0)
        else:
            start_logits_results = result.as_numpy("start_logits")
            end_logits_results = result.as_numpy("end_logits")
        for i in range(batch_count):
            unique_id = int(inputs[label_id_key][i][0])
            start_logits = [float(x) for x in start_logits_results[i].flat]
            end_logits = [float(x) for x in end_logits_results[i].flat]
            all_results.append(
                RawResult(unique_id=unique_id,
                          start_logits=start_logits,
                          end_logits=end_logits))

        recv_prog.update(n=batch_count)
        return outstanding

    all_results = []
    time_list = []

    print("Starting Sending Requests....\n")

    all_results_start = time.time()
    idx = 0
    for inputs_dict in batch(eval_features, batch_size):

        present_batch_size = len(inputs_dict[label_id_key])

        # The label/unique-id tensor is only sent when not using a TRT engine.
        if not FLAGS.trt_engine:
            label_ids_data = np.stack(inputs_dict[label_id_key])
        input_ids_data = np.stack(inputs_dict['input_ids'])
        input_mask_data = np.stack(inputs_dict['input_mask'])
        segment_ids_data = np.stack(inputs_dict['segment_ids'])

        inputs = []
        inputs.append(
            tritongrpcclient.InferInput('input_ids', input_ids_data.shape,
                                        "INT32"))
        inputs[0].set_data_from_numpy(input_ids_data)
        inputs.append(
            tritongrpcclient.InferInput('input_mask', input_mask_data.shape,
                                        "INT32"))
        inputs[1].set_data_from_numpy(input_mask_data)
        inputs.append(
            tritongrpcclient.InferInput('segment_ids', segment_ids_data.shape,
                                        "INT32"))
        inputs[2].set_data_from_numpy(segment_ids_data)
        if not FLAGS.trt_engine:
            inputs.append(
                tritongrpcclient.InferInput(label_id_key, label_ids_data.shape,
                                            "INT32"))
            inputs[3].set_data_from_numpy(label_ids_data)

        outputs = []
        if FLAGS.trt_engine:
            outputs.append(
                tritongrpcclient.InferRequestedOutput('cls_squad_logits'))
        else:
            outputs.append(
                tritongrpcclient.InferRequestedOutput('start_logits'))
            outputs.append(tritongrpcclient.InferRequestedOutput('end_logits'))

        start_time = time.time()
        triton_client.async_infer(model_name,
                                  inputs,
                                  partial(completion_callback, user_data, idx,
                                          start_time, inputs_dict),
                                  request_id=str(idx),
                                  model_version=model_version,
                                  outputs=outputs)
        outstanding += 1
        idx += 1

        sent_prog.update(n=present_batch_size)

        # Only block on a response once the in-flight limit is reached.
        outstanding = process_outstanding(outstanding >= max_outstanding,
                                          outstanding)

    tqdm.tqdm.write(
        "All Requests Sent! Waiting for responses. Outstanding: {}.\n".format(
            outstanding))

    # Now process all outstanding requests
    while (outstanding > 0):
        outstanding = process_outstanding(True, outstanding)

    all_results_end = time.time()
    all_results_total = (all_results_end - all_results_start) * 1000.0

    print("-----------------------------")
    print("Total Time: {} ms".format(all_results_total))
    print("-----------------------------")

    print("-----------------------------")
    print("Total Inference Time = %0.2f for"
          "Sentences processed = %d" % (sum(time_list), len(eval_features)))
    print("Throughput Average (sentences/sec) = %0.2f" %
          (len(eval_features) / all_results_total * 1000.0))
    print("-----------------------------")

    if FLAGS.output_dir and FLAGS.predict_file:
        # When inferencing on a dataset, get inference statistics and write results to json file
        time_list.sort()

        def latency_at(fraction):
            # time_list is sorted ascending, so the maximum of its first k
            # entries is entry k - 1. Clamp k to 1 so that runs with very few
            # requests do not take the max of an empty slice.
            if not time_list:
                return float("nan")
            count = max(int(len(time_list) * fraction), 1)
            return time_list[count - 1]

        avg = np.mean(time_list)
        cf_95 = latency_at(0.95)
        cf_99 = latency_at(0.99)
        cf_100 = latency_at(1)
        print("-----------------------------")
        print("Summary Statistics")
        print("Batch size =", FLAGS.predict_batch_size)
        print("Sequence Length =", FLAGS.max_seq_length)
        print("Latency Confidence Level 95 (ms) =", cf_95 * 1000)
        print("Latency Confidence Level 99 (ms)  =", cf_99 * 1000)
        print("Latency Confidence Level 100 (ms)  =", cf_100 * 1000)
        print("Latency Average (ms)  =", avg * 1000)
        print("-----------------------------")

        output_prediction_file = os.path.join(FLAGS.output_dir,
                                              "predictions.json")
        output_nbest_file = os.path.join(FLAGS.output_dir,
                                         "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(FLAGS.output_dir,
                                                 "null_odds.json")

        write_predictions(eval_examples, eval_features, all_results,
                          FLAGS.n_best_size, FLAGS.max_answer_length,
                          FLAGS.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file,
                          FLAGS.version_2_with_negative, FLAGS.verbose_logging)
    else:
        # When inferencing on a single example, write best answer to stdout
        all_predictions, all_nbest_json, scores_diff_json = get_predictions(
            eval_examples, eval_features, all_results, FLAGS.n_best_size,
            FLAGS.max_answer_length, FLAGS.do_lower_case,
            FLAGS.version_2_with_negative, FLAGS.verbose_logging)
        # Key 0 is the id given to the single question built above.
        print(
            "Context is: %s \n\nQuestion is: %s \n\nPredicted Answer is: %s" %
            (FLAGS.context, FLAGS.question, all_predictions[0]))
예제 #4
0
def main(_):
  """Fine-tune BERT on SQuAD and/or run SQuAD prediction with a TPUEstimator.

  Everything is driven by ``FLAGS``: with ``FLAGS.do_train`` the training
  examples are shuffled, written to ``train.tf_record`` and used to train the
  estimator; with ``FLAGS.do_predict`` the evaluation examples are written to
  ``eval.tf_record``, predicted, and the logits (cut to the per-feature length
  returned by ``get_act_seq_len``) are turned into prediction files in
  ``FLAGS.output_dir``.

  Args:
    _: Unused positional argument.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  bert_config = rs.modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  rs.validate_flags_or_throw(bert_config)
  tf.gfile.MakeDirs(FLAGS.output_dir)

  tokenizer = rs.tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  # A cluster resolver is only created when a named TPU is requested.
  if FLAGS.use_tpu and FLAGS.tpu_name:
    cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  else:
    cluster = None

  tpu_settings = tf.contrib.tpu.TPUConfig(
      iterations_per_loop=FLAGS.iterations_per_loop,
      num_shards=FLAGS.num_tpu_cores,
      per_host_input_for_training=(
          tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))
  estimator_config = tf.contrib.tpu.RunConfig(
      cluster=cluster,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=tpu_settings)

  training_set = None
  total_steps = None
  warmup_steps = None
  if FLAGS.do_train:
    training_set = rs.read_squad_examples(
        input_file=FLAGS.train_file, is_training=True)
    batches_per_epoch = len(training_set) / FLAGS.train_batch_size
    total_steps = int(batches_per_epoch * FLAGS.num_train_epochs)
    warmup_steps = int(total_steps * FLAGS.warmup_proportion)

    # Shuffle up front with a fixed seed so the `input_fn` does not need a
    # very large shuffle buffer.
    random.Random(12345).shuffle(training_set)

  build_model = rs.model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

  # Without a TPU this falls back to a normal Estimator on CPU or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=build_model,
      config=estimator_config,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    # Features go to a temporary record file rather than being held in memory
    # as very large constant tensors.
    train_record_path = os.path.join(FLAGS.output_dir, "train.tf_record")
    train_writer = rs.FeatureWriter(
        filename=train_record_path, is_training=True)
    rs.convert_examples_to_features(
        examples=training_set,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=True,
        output_fn=train_writer.process_feature)
    train_writer.close()

    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num orig examples = %d", len(training_set))
    tf.logging.info("  Num split examples = %d", train_writer.num_features)
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", total_steps)
    del training_set

    train_input_fn = rs.input_fn_builder(
        input_file=train_writer.filename,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=total_steps)

  if FLAGS.do_predict:
    eval_examples = rs.read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False)

    act_seq_len = get_act_seq_len(
        eval_examples, tokenizer, FLAGS.max_seq_length, FLAGS.doc_stride,
        FLAGS.max_query_length)

    eval_record_path = os.path.join(FLAGS.output_dir, "eval.tf_record")
    eval_writer = rs.FeatureWriter(
        filename=eval_record_path, is_training=False)
    eval_features = []

    def append_feature(feature):
      # Keep each feature for post-processing and write it to the record file.
      eval_features.append(feature)
      eval_writer.process_feature(feature)

    rs.convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=False,
        output_fn=append_feature)
    eval_writer.close()

    tf.logging.info("***** Running predictions *****")
    tf.logging.info("  Num orig examples = %d", len(eval_examples))
    tf.logging.info("  Num split examples = %d", len(eval_features))
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = rs.input_fn_builder(
        input_file=eval_writer.filename,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False)

    # If running eval on the TPU, you will need to specify the number of
    # steps.
    all_results = []
    prediction_stream = estimator.predict(
        predict_input_fn, yield_single_examples=True)
    for idx, result in enumerate(prediction_stream):
      done_so_far = len(all_results)
      if done_so_far % 1000 == 0:
        tf.logging.info("Processing example: %d" % done_so_far)
      unique_id = int(result["unique_ids"])
      start_logits = list(map(float, result["start_logits"].flat))
      end_logits = list(map(float, result["end_logits"].flat))
      # Cut both logit vectors to the length recorded for this feature.
      keep = act_seq_len[idx]
      all_results.append(
          rs.RawResult(
              unique_id=unique_id,
              start_logits=start_logits[:keep],
              end_logits=end_logits[:keep]))

    output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
    output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")
    output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json")

    rs.write_predictions(eval_examples, eval_features, all_results,
                         FLAGS.n_best_size, FLAGS.max_answer_length,
                         FLAGS.do_lower_case, output_prediction_file,
                         output_nbest_file, output_null_log_odds_file)
예제 #5
0
def main(_):
    """Run SQuAD question answering against a Triton server (InferContext API).

    All configuration is read from ``FLAGS``. The input is either the SQuAD
    file ``FLAGS.predict_file`` or a single ``FLAGS.question`` /
    ``FLAGS.context`` pair. Features are sent in batches with
    ``ctx.async_run`` (at most ``max_outstanding`` requests in flight) and the
    returned logits are collected as ``RawResult`` entries. With
    ``FLAGS.output_dir`` and ``FLAGS.predict_file`` set, latency statistics
    are printed and prediction files are written; otherwise the best answer is
    printed to stdout.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If neither predict_file nor question+context is given, or
            if a completed request yields no result.
    """
    # Disable XLA lazy compilation: per the original note it causes memory
    # fragmentation for BERT, leading to OOM.
    os.environ[
        "TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    # Get the Data
    if FLAGS.predict_file:
        eval_examples = read_squad_examples(
            input_file=FLAGS.predict_file,
            is_training=False,
            version_2_with_negative=FLAGS.version_2_with_negative)
    elif FLAGS.question and FLAGS.context:
        # The payload below is built from the question and the context, so
        # those are the two flags that must be present.
        input_data = [{
            "paragraphs": [{
                "context": FLAGS.context,
                "qas": [{
                    "id": 0,
                    "question": FLAGS.question
                }]
            }]
        }]

        eval_examples = read_squad_examples(
            input_file=None,
            is_training=False,
            version_2_with_negative=FLAGS.version_2_with_negative,
            input_data=input_data)
    else:
        raise ValueError(
            "Either predict_file or question+context need to be defined")

    # Get Eval Features = Preprocessing
    eval_features = []

    def append_feature(feature):
        eval_features.append(feature)

    convert_examples_to_features(examples=eval_examples[0:],
                                 tokenizer=tokenizer,
                                 max_seq_length=FLAGS.max_seq_length,
                                 doc_stride=FLAGS.doc_stride,
                                 max_query_length=FLAGS.max_query_length,
                                 is_training=False,
                                 output_fn=append_feature)

    protocol_str = 'grpc'  # http or grpc
    url = FLAGS.triton_server_url
    verbose = True
    model_name = FLAGS.triton_model_name
    model_version = FLAGS.triton_model_version
    batch_size = FLAGS.predict_batch_size

    protocol = ProtocolType.from_str(protocol_str)  # or 'grpc'

    ctx = InferContext(url, protocol, model_name, model_version, verbose)

    status_ctx = ServerStatusContext(url,
                                     protocol,
                                     model_name=model_name,
                                     verbose=verbose)

    # NOTE(review): the object created here is discarded -- confirm whether
    # this call is still needed.
    model_config_pb2.ModelConfig()

    # NOTE(review): the status is never read; the call is kept because it
    # presumably fails fast when the server or model is unavailable.
    status_result = status_ctx.get_server_status()
    user_data = UserData()

    max_outstanding = 20
    # Number of outstanding requests
    outstanding = 0

    sent_prog = tqdm.tqdm(desc="Send Requests", total=len(eval_features))
    recv_prog = tqdm.tqdm(desc="Recv Requests", total=len(eval_features))

    def process_outstanding(do_wait, outstanding):
        """Consume one completed request if asked to wait.

        Returns the updated number of outstanding requests.
        """
        if (outstanding == 0 or do_wait is False):
            return outstanding

        # Wait for deferred items from callback functions
        (infer_ctx, ready_id, idx, start_time,
         inputs) = user_data._completed_requests.get()

        if (ready_id is None):
            return outstanding

        # If we are here, we got an id
        result = ctx.get_async_run_results(ready_id)
        stop = time.time()

        if (result is None):
            raise ValueError(
                "Context returned null for async id marked as done")

        outstanding -= 1

        time_list.append(stop - start_time)

        batch_count = len(inputs[label_id_key])

        for i in range(batch_count):
            unique_id = int(inputs[label_id_key][i][0])
            start_logits = [float(x) for x in result["start_logits"][i].flat]
            end_logits = [float(x) for x in result["end_logits"][i].flat]
            all_results.append(
                RawResult(unique_id=unique_id,
                          start_logits=start_logits,
                          end_logits=end_logits))

        recv_prog.update(n=batch_count)
        return outstanding

    all_results = []
    time_list = []

    print("Starting Sending Requests....\n")

    all_results_start = time.time()
    idx = 0
    for inputs_dict in batch(eval_features, batch_size):

        present_batch_size = len(inputs_dict[label_id_key])

        outputs_dict = {
            'start_logits': InferContext.ResultFormat.RAW,
            'end_logits': InferContext.ResultFormat.RAW
        }

        start_time = time.time()
        ctx.async_run(partial(completion_callback, user_data, idx, start_time,
                              inputs_dict),
                      inputs_dict,
                      outputs_dict,
                      batch_size=present_batch_size)
        outstanding += 1
        idx += 1

        sent_prog.update(n=present_batch_size)

        # Only block on a response once the in-flight limit is reached.
        outstanding = process_outstanding(outstanding >= max_outstanding,
                                          outstanding)

    tqdm.tqdm.write(
        "All Requests Sent! Waiting for responses. Outstanding: {}.\n".format(
            outstanding))

    # Now process all outstanding requests
    while (outstanding > 0):
        outstanding = process_outstanding(True, outstanding)

    all_results_end = time.time()
    all_results_total = (all_results_end - all_results_start) * 1000.0

    print("-----------------------------")
    print("Total Time: {} ms".format(all_results_total))
    print("-----------------------------")

    print("-----------------------------")
    print("Total Inference Time = %0.2f for"
          "Sentences processed = %d" % (sum(time_list), len(eval_features)))
    print("Throughput Average (sentences/sec) = %0.2f" %
          (len(eval_features) / all_results_total * 1000.0))
    print("-----------------------------")

    if FLAGS.output_dir and FLAGS.predict_file:
        # When inferencing on a dataset, get inference statistics and write results to json file
        time_list.sort()

        def latency_at(fraction):
            # time_list is sorted ascending, so the maximum of its first k
            # entries is entry k - 1. Clamp k to 1 so that runs with very few
            # requests do not take the max of an empty slice.
            if not time_list:
                return float("nan")
            count = max(int(len(time_list) * fraction), 1)
            return time_list[count - 1]

        avg = np.mean(time_list)
        cf_95 = latency_at(0.95)
        cf_99 = latency_at(0.99)
        cf_100 = latency_at(1)
        print("-----------------------------")
        print("Summary Statistics")
        print("Batch size =", FLAGS.predict_batch_size)
        print("Sequence Length =", FLAGS.max_seq_length)
        print("Latency Confidence Level 95 (ms) =", cf_95 * 1000)
        print("Latency Confidence Level 99 (ms)  =", cf_99 * 1000)
        print("Latency Confidence Level 100 (ms)  =", cf_100 * 1000)
        print("Latency Average (ms)  =", avg * 1000)
        print("-----------------------------")

        output_prediction_file = os.path.join(FLAGS.output_dir,
                                              "predictions.json")
        output_nbest_file = os.path.join(FLAGS.output_dir,
                                         "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(FLAGS.output_dir,
                                                 "null_odds.json")

        write_predictions(eval_examples, eval_features, all_results,
                          FLAGS.n_best_size, FLAGS.max_answer_length,
                          FLAGS.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file,
                          FLAGS.version_2_with_negative, FLAGS.verbose_logging)
    else:
        # When inferencing on a single example, write best answer to stdout
        all_predictions, all_nbest_json, scores_diff_json = get_predictions(
            eval_examples, eval_features, all_results, FLAGS.n_best_size,
            FLAGS.max_answer_length, FLAGS.do_lower_case,
            FLAGS.version_2_with_negative, FLAGS.verbose_logging)
        # Key 0 is the id given to the single question built above.
        print(
            "Context is: %s \n\nQuestion is: %s \n\nPredicted Answer is: %s" %
            (FLAGS.context, FLAGS.question, all_predictions[0]))