Example #1
def main(flags):
    # TODO: what's the input dimensionality?
    _INPUT_DIM = 4000

    # define a base config that has all the universal
    # properties
    base_config = model_config.ModelConfig(
        # max_batch_size=flags['batch_size'],
        input=[
            model_config.ModelInput(
                name='input',
                data_type=model_config.TYPE_FP32,
                dims=[flags['batch_size'], flags['in_channels'], _INPUT_DIM])
        ],
        output=[
            model_config.ModelOutput(name='output',
                                     data_type=model_config.TYPE_FP32,
                                     dims=[flags['batch_size'], _INPUT_DIM])
        ],
        instance_group=[
            model_config.ModelInstanceGroup(
                count=1, kind=model_config.ModelInstanceGroup.KIND_GPU)
        ])

    # do onnx export
    # start by copying config and setting up
    # modelstore directory
    onnx_config = model_config.ModelConfig(name='deepclean_onnx',
                                           platform='onnxruntime_onnx')
    onnx_config.MergeFrom(base_config)

    onnx_dir = os.path.join(flags['model_store_dir'], onnx_config.name)
    soft_makedirs(os.path.join(onnx_dir, '0'))

    # create dummy input and network and use to export onnx
    dummy_input = torch.randn(*base_config.input[0].dims)
    #   dynamic_axes = {}
    #   for i in onnx_config.input:
    #     dynamic_axes[i.name] = {0: 'batch'}
    #   for i in onnx_config.output:
    #     dynamic_axes[i.name] = {0: 'batch'}

    model = DeepClean()
    torch.onnx.export(
        model,
        dummy_input,
        os.path.join(onnx_dir, '0', 'model.onnx'),
        verbose=True,
        input_names=[i.name for i in onnx_config.input],
        output_names=[i.name for i in onnx_config.output],
        #     dynamic_axes=dynamic_axes
    )

    # write config
    with open(os.path.join(onnx_dir, 'config.pbtxt'), 'w') as f:
        f.write(str(onnx_config))

    # do trt conversion
    convert_to_tensorrt(flags['model_store_dir'], base_config, use_fp16=False)
    convert_to_tensorrt(flags['model_store_dir'], base_config, use_fp16=True)
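This example relies on a few pieces that are not shown in the listing: the `flags` mapping, the `soft_makedirs` helper, and the `DeepClean` network class from the surrounding project. A minimal sketch of the assumed pieces, just to make the expected inputs concrete (the key names mirror the lookups in `main()` above; the values and the call itself are hypothetical):

import os

def soft_makedirs(path):
    # assumed helper: create the directory if it does not already exist
    os.makedirs(path, exist_ok=True)

# hypothetical flags dict; only these keys are read by main() above
flags = {
    'batch_size': 1,
    'in_channels': 21,
    'model_store_dir': '/tmp/modelstore',
}
# main(flags)  # would run the ONNX export and both TensorRT conversions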
Example #2
def convert_to_tensorrt(model_store_dir, base_config, use_fp16=False):
    trt_config = model_config.ModelConfig(name='deepclean_trt' +
                                          ('_fp16' if use_fp16 else '_fp32'),
                                          platform='tensorrt_plan')
    trt_config.MergeFrom(base_config)

    trt_dir = os.path.join(model_store_dir, trt_config.name)
    soft_makedirs(os.path.join(trt_dir, '0'))

    # set up a plan builder
    TRT_LOGGER = trt.Logger()
    builder = trt.Builder(TRT_LOGGER)
    builder.max_workspace_size = 1 << 28  # 256 MiB
    builder.max_batch_size = 1  # flags['batch_size']
    if use_fp16:
        builder.fp16_mode = True
        builder.strict_type_constraints = True

    #   config = builder.create_builder_config()
    #   profile = builder.create_optimization_profile()
    #   min_shape = tuple([1] + onnx_config.input[0].dims[1:])
    #   max_shape = tuple([8] + onnx_config.input[0].dims[1:])

    #   optimal_shape = max_shape
    #   profile.set_shape('input', min_shape, optimal_shape, max_shape)
    #   config.add_optimization_profile(profile)

    # initialize a parser with a network and fill in that
    # network with the ONNX file we just saved
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, TRT_LOGGER)
    onnx_path = os.path.join(model_store_dir, 'deepclean_onnx', '0',
                             'model.onnx')
    with open(onnx_path, 'rb') as f:
        parser.parse(f.read())


    #   last_layer = network.get_layer(network.num_layers - 1)
    #   if not last_layer.get_output(0):
    #     network.mark_output(last_layer.get_output(0))

    # build an engine from that network
    engine = builder.build_cuda_engine(network)
    with open(os.path.join(trt_dir, '0', 'model.plan'), 'wb') as f:
        f.write(engine.serialize())

    # export config
    with open(os.path.join(trt_dir, 'config.pbtxt'), 'w') as f:
        f.write(str(trt_config))
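Note that `parser.parse()` returns False on failure, but the example ignores the return value, so a malformed ONNX file only surfaces later as a `None` engine. A minimal sketch of checking the result using the `num_errors` / `get_error()` accessors of `trt.OnnxParser`, assuming the same `parser` and `onnx_path` as above:

with open(onnx_path, 'rb') as f:
    if not parser.parse(f.read()):
        # surface whatever the parser recorded before giving up
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError('failed to parse ONNX model: ' + onnx_path)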
Example #3
def run_client():
    """
    Ask a question of context on TRTIS.
    :param context: str
    :param question: str
    :param question_id: int
    :return:
    """

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    eval_examples = read_squad_examples(
        input_file=FLAGS.predict_file,
        is_training=False,
        version_2_with_negative=FLAGS.version_2_with_negative)

    eval_features = []

    def append_feature(feature):
        eval_features.append(feature)

    convert_examples_to_features(examples=eval_examples[0:],
                                 tokenizer=tokenizer,
                                 max_seq_length=FLAGS.max_seq_length,
                                 doc_stride=FLAGS.doc_stride,
                                 max_query_length=FLAGS.max_query_length,
                                 is_training=False,
                                 output_fn=append_feature)

    protocol_str = 'grpc'  # http or grpc
    url = FLAGS.trtis_server_url
    verbose = True
    model_name = FLAGS.trtis_model_name
    model_version = FLAGS.trtis_model_version
    batch_size = FLAGS.predict_batch_size

    protocol = ProtocolType.from_str(protocol_str)  # or 'grpc'

    ctx = InferContext(url, protocol, model_name, model_version, verbose)

    channel = grpc.insecure_channel(url)

    stub = grpc_service_pb2_grpc.GRPCServiceStub(channel)

    prof_request = grpc_service_pb2.server__status__pb2.model__config__pb2.ModelConfig(
    )

    prof_response = stub.Profile(prof_request)

    status_ctx = ServerStatusContext(url,
                                     protocol,
                                     model_name=model_name,
                                     verbose=verbose)

    model_config_pb2.ModelConfig()

    status_result = status_ctx.get_server_status()

    outstanding = {}
    max_outstanding = 20

    sent_prog = tqdm.tqdm(desc="Send Requests", total=len(eval_features))
    recv_prog = tqdm.tqdm(desc="Recv Requests", total=len(eval_features))

    def process_outstanding(do_wait):

        if (len(outstanding) == 0):
            return

        ready_id = ctx.get_ready_async_request(do_wait)

        if (ready_id is None):
            return

        # If we are here, we got an id
        result = ctx.get_async_run_results(ready_id, False)
        stop = time.time()

        if (result is None):
            raise ValueError(
                "Context returned null for async id marked as done")

        outResult = outstanding.pop(ready_id)

        time_list.append(stop - outResult.start_time)

        batch_count = len(outResult.inputs[label_id_key])

        for i in range(batch_count):
            unique_id = int(outResult.inputs[label_id_key][i][0])
            start_logits = [float(x) for x in result["start_logits"][i].flat]
            end_logits = [float(x) for x in result["end_logits"][i].flat]
            all_results.append(
                RawResult(unique_id=unique_id,
                          start_logits=start_logits,
                          end_logits=end_logits))

        recv_prog.update(n=batch_count)

    all_results = []
    time_list = []

    print("Starting Sending Requests....\n")

    all_results_start = time.time()

    for inputs_dict in batch(eval_features, batch_size):

        present_batch_size = len(inputs_dict[label_id_key])

        outputs_dict = {
            'start_logits': InferContext.ResultFormat.RAW,
            'end_logits': InferContext.ResultFormat.RAW
        }

        start = time.time()
        async_id = ctx.async_run(inputs_dict,
                                 outputs_dict,
                                 batch_size=present_batch_size)

        outstanding[async_id] = PendingResult(async_id=async_id,
                                              start_time=start,
                                              inputs=inputs_dict)

        sent_prog.update(n=present_batch_size)

        # Try to process at least one response per request
        process_outstanding(len(outstanding) >= max_outstanding)

    tqdm.tqdm.write(
        "All Requests Sent! Waiting for responses. Outstanding: {}.\n".format(
            len(outstanding)))

    # Now process all outstanding requests
    while (len(outstanding) > 0):
        process_outstanding(True)

    all_results_end = time.time()
    all_results_total = (all_results_end - all_results_start) * 1000.0

    print("-----------------------------")
    print("Individual Time Runs - Ignoring first two iterations")
    print("Total Time: {} ms".format(all_results_total))
    print("-----------------------------")

    print("-----------------------------")
    print("Total Inference Time = %0.2f for"
          "Sentences processed = %d" % (sum(time_list), len(eval_features)))
    print("Throughput Average (sentences/sec) = %0.2f" %
          (len(eval_features) / all_results_total * 1000.0))
    print("-----------------------------")

    time_list.sort()

    avg = np.mean(time_list)
    cf_95 = max(time_list[:int(len(time_list) * 0.95)])
    cf_99 = max(time_list[:int(len(time_list) * 0.99)])
    cf_100 = max(time_list[:int(len(time_list) * 1)])
    print("-----------------------------")
    print("Summary Statistics")
    print("Batch size =", FLAGS.predict_batch_size)
    print("Sequence Length =", FLAGS.max_seq_length)
    print("Latency Confidence Level 95 (ms) =", cf_95 * 1000)
    print("Latency Confidence Level 99 (ms)  =", cf_99 * 1000)
    print("Latency Confidence Level 100 (ms)  =", cf_100 * 1000)
    print("Latency Average (ms)  =", avg * 1000)
    print("-----------------------------")

    output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
    output_nbest_file = os.path.join(FLAGS.output_dir,
                                     "nbest_predictions.json")
    output_null_log_odds_file = os.path.join(FLAGS.output_dir,
                                             "null_odds.json")

    write_predictions(eval_examples, eval_features, all_results,
                      FLAGS.n_best_size, FLAGS.max_answer_length,
                      FLAGS.do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file)
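The `PendingResult` record that the client stores for each in-flight request is not defined in this listing. A minimal definition consistent with how it is used above (fields `async_id`, `start_time`, `inputs`) would be:

import collections

# assumed bookkeeping record for one in-flight async request
PendingResult = collections.namedtuple(
    'PendingResult', ['async_id', 'start_time', 'inputs'])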
Example #4
def run_client():
    """
    Ask a question of context on TRTIS.
    :param context: str
    :param question: str
    :param question_id: int
    :return:
    """

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    # -------------------------------------------------------------
    # Creation of examples here
    # -------------------------------------------------------------
    paragraph = """The koala (Phascolarctos cinereus, or, inaccurately, koala bear[a]) is an arboreal herbivorous marsupial native to Australia. It is the only extant representative of the family Phascolarctidae and its closest living relatives are the wombats, which comprise the family Vombatidae. The koala is found in coastal areas of the mainland's eastern and southern regions, inhabiting Queensland, New South Wales, Victoria, and South Australia. It is easily recognisable by its stout, tailless body and large head with round, fluffy ears and large, spoon-shaped nose. The koala has a body length of 60–85 cm (24–33 in) and weighs 4–15 kg (9–33 lb). Fur colour ranges from silver grey to chocolate brown. Koalas from the northern populations are typically smaller and lighter in colour than their counterparts further south. These populations possibly are separate subspecies, but this is disputed.
    """
    question_text = "Who is Koala?"
    examples = []
    example = SquadExample(
        qas_id=1,
        question_text=question_text,
        doc_tokens=convert_doc_tokens(paragraph_text=paragraph))
    for iterator in range(30):
        examples.append(example)

    # Switching from predict_file read to api-read
    # eval_examples = read_squad_examples(
    #     input_file=FLAGS.predict_file, is_training=False,
    #     version_2_with_negative=FLAGS.version_2_with_negative)
    eval_examples = examples
    eval_features = []

    def append_feature(feature):
        eval_features.append(feature)

    convert_examples_to_features(examples=eval_examples[0:],
                                 tokenizer=tokenizer,
                                 max_seq_length=FLAGS.max_seq_length,
                                 doc_stride=FLAGS.doc_stride,
                                 max_query_length=FLAGS.max_query_length,
                                 is_training=False,
                                 output_fn=append_feature)

    protocol_str = 'grpc'  # http or grpc
    url = FLAGS.trtis_server_url
    verbose = True
    model_name = FLAGS.trtis_model_name
    model_version = FLAGS.trtis_model_version
    batch_size = FLAGS.predict_batch_size

    protocol = ProtocolType.from_str(protocol_str)  # or 'grpc'

    ctx = InferContext(url, protocol, model_name, model_version, verbose)

    channel = grpc.insecure_channel(url)

    stub = grpc_service_pb2_grpc.GRPCServiceStub(channel)

    prof_request = grpc_service_pb2.server__status__pb2.model__config__pb2.ModelConfig(
    )

    prof_response = stub.Profile(prof_request)

    status_ctx = ServerStatusContext(url,
                                     protocol,
                                     model_name=model_name,
                                     verbose=verbose)

    model_config_pb2.ModelConfig()

    status_result = status_ctx.get_server_status()

    outstanding = {}
    max_outstanding = 20

    sent_prog = tqdm.tqdm(desc="Send Requests", total=len(eval_features))
    recv_prog = tqdm.tqdm(desc="Recv Requests", total=len(eval_features))

    def process_outstanding(do_wait):

        if (len(outstanding) == 0):
            return

        ready_id = ctx.get_ready_async_request(do_wait)

        if (ready_id is None):
            return

        # If we are here, we got an id
        result = ctx.get_async_run_results(ready_id, False)
        stop = time.time()

        if (result is None):
            raise ValueError(
                "Context returned null for async id marked as done")

        outResult = outstanding.pop(ready_id)

        time_list.append(stop - outResult.start_time)

        batch_count = len(outResult.inputs[label_id_key])

        for i in range(batch_count):
            unique_id = int(outResult.inputs[label_id_key][i][0])
            start_logits = [float(x) for x in result["start_logits"][i].flat]
            end_logits = [float(x) for x in result["end_logits"][i].flat]
            all_results.append(
                RawResult(unique_id=unique_id,
                          start_logits=start_logits,
                          end_logits=end_logits))

        recv_prog.update(n=batch_count)

    all_results = []
    time_list = []

    print("Starting Sending Requests....\n")

    all_results_start = time.time()

    for inputs_dict in batch(eval_features, batch_size):

        present_batch_size = len(inputs_dict[label_id_key])

        outputs_dict = {
            'start_logits': InferContext.ResultFormat.RAW,
            'end_logits': InferContext.ResultFormat.RAW
        }

        start = time.time()
        async_id = ctx.async_run(inputs_dict,
                                 outputs_dict,
                                 batch_size=present_batch_size)

        outstanding[async_id] = PendingResult(async_id=async_id,
                                              start_time=start,
                                              inputs=inputs_dict)

        sent_prog.update(n=present_batch_size)

        # Try to process at least one response per request
        process_outstanding(len(outstanding) >= max_outstanding)

    tqdm.tqdm.write(
        "All Requests Sent! Waiting for responses. Outstanding: {}.\n".format(
            len(outstanding)))

    # Now process all outstanding requests
    while (len(outstanding) > 0):
        process_outstanding(True)

    all_results_end = time.time()
    all_results_total = (all_results_end - all_results_start) * 1000.0

    print("-----------------------------")
    print("Individual Time Runs - Ignoring first two iterations")
    print("Total Time: {} ms".format(all_results_total))
    print("-----------------------------")

    print("-----------------------------")
    print("Total Inference Time = %0.2f for"
          "Sentences processed = %d" % (sum(time_list), len(eval_features)))
    print("Throughput Average (sentences/sec) = %0.2f" %
          (len(eval_features) / all_results_total * 1000.0))
    print("-----------------------------")

    time_list.sort()

    avg = np.mean(time_list)
    cf_95 = max(time_list[:int(len(time_list) * 0.95)])
    cf_99 = max(time_list[:int(len(time_list) * 0.99)])
    cf_100 = max(time_list[:int(len(time_list) * 1)])
    print("-----------------------------")
    print("Summary Statistics")
    print("Batch size =", FLAGS.predict_batch_size)
    print("Sequence Length =", FLAGS.max_seq_length)
    print("Latency Confidence Level 95 (ms) =", cf_95 * 1000)
    print("Latency Confidence Level 99 (ms)  =", cf_99 * 1000)
    print("Latency Confidence Level 100 (ms)  =", cf_100 * 1000)
    print("Latency Average (ms)  =", avg * 1000)
    print("-----------------------------")
Example #5
    def generate_trt_config(
            save_path: Path,
            inputs: List[IOShape],
            outputs: List[IOShape],
            arch_name: str = 'model',
            platform: TensorRTPlatform = TensorRTPlatform.TENSORRT_PLAN,
            max_batch_size: int = 32,
            instance_group: List[ModelInstanceGroup] = None
    ):
        """Generate and save TensorRT inference server model configuration file: `model.pbtxt`.

        see here for more detailed configuration
            https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/protobuf_api/model_config.proto.html
        Arguments:
            save_path (Path): Model saving path name, generated by `modelci.hub.utils.generate_path`.
            inputs (List[IOShape]): Input tensors shape definition.
            outputs (List[IOShape]): Output tensors shape definition.
            arch_name (str): Model architecture name.
            platform (TensorRTPlatform): TensorRT platform name.
            max_batch_size (int): Maximum batch size. Only used when the first dimension of the input shape
                is -1 (a dynamic batch dimension); defaults to 32. When the input shape has a fixed first
                dimension, this argument is ignored: `max_batch_size` is set to 0 in the generated config and
                the batch dimension is kept as part of the input shape.
            instance_group (List[ModelInstanceGroup]): Model instance group (workers) definition. Default is to
                create a single instance loading on the first available CUDA device.
        """
        from tensorrtserver.api import model_config_pb2

        # assert batch size
        batch_sizes = list(map(lambda x: x.shape[0], inputs))
        if not all(batch_size == batch_sizes[0] for batch_size in batch_sizes):
            raise ValueError('Batch sizes for inputs (i.e. the first dimension of `input.shape`) are not consistent.')
        if batch_sizes[0] != -1:
            max_batch_size = 0
            remove_batch_dim = False
        else:
            remove_batch_dim = True

        inputs = TRTConverter.build_model_inputs(inputs, remove_batch_dim=remove_batch_dim)
        outputs = TRTConverter.build_model_outputs(outputs, remove_batch_dim=remove_batch_dim)

        if instance_group is None:
            instance_group = [ModelInstanceGroup(kind=ModelInstanceGroupKind.KIND_GPU, count=1, gpus=[0])]

        config = ModelConfig(
            name=str(arch_name),
            platform=platform.name.lower(),
            version_policy=ModelVersionPolicy(),
            max_batch_size=max_batch_size,
            input=inputs,
            output=outputs,
            instance_group=instance_group,
        )

        with open(str(save_path / 'config.pbtxt'), 'w') as cfg:
            # to dict
            config_dict = config.to_dict(casing=Casing.SNAKE)
            # to pbtxt format string
            model_config_message = model_config_pb2.ModelConfig()
            pbtxt_str = str(json_format.ParseDict(config_dict, model_config_message))
            cfg.write(pbtxt_str)
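The file written at the end is the standard Triton / TensorRT Inference Server `config.pbtxt` text format. Purely for illustration (the model name, dtypes, and shapes below are invented), the generated file for a single-input, single-output model would look roughly like this:

name: "model"
platform: "tensorrt_plan"
max_batch_size: 32
input {
  name: "input"
  data_type: TYPE_FP32
  dims: [3, 224, 224]
}
output {
  name: "output"
  data_type: TYPE_FP32
  dims: [1000]
}
instance_group {
  count: 1
  kind: KIND_GPU
  gpus: [0]
}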
Example #6
def main(_):
    """
    Ask a question of context on Triton.
    :param context: str
    :param question: str
    :param question_id: int
    :return:
    """
    # lazy compilation causes memory fragmentation for BERT, leading to OOM
    os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    # Get the Data
    if FLAGS.predict_file:
        eval_examples = read_squad_examples(
            input_file=FLAGS.predict_file,
            is_training=False,
            version_2_with_negative=FLAGS.version_2_with_negative)
    elif FLAGS.question and FLAGS.context:
        input_data = [{
            "paragraphs": [{
                "context": FLAGS.context,
                "qas": [{
                    "id": 0,
                    "question": FLAGS.question
                }]
            }]
        }]

        eval_examples = read_squad_examples(
            input_file=None,
            is_training=False,
            version_2_with_negative=FLAGS.version_2_with_negative,
            input_data=input_data)
    else:
        raise ValueError(
            "Either predict_file or question+context needs to be defined")

    # Get Eval Features = Preprocessing
    eval_features = []

    def append_feature(feature):
        eval_features.append(feature)

    convert_examples_to_features(examples=eval_examples[0:],
                                 tokenizer=tokenizer,
                                 max_seq_length=FLAGS.max_seq_length,
                                 doc_stride=FLAGS.doc_stride,
                                 max_query_length=FLAGS.max_query_length,
                                 is_training=False,
                                 output_fn=append_feature)

    protocol_str = 'grpc'  # http or grpc
    url = FLAGS.triton_server_url
    verbose = True
    model_name = FLAGS.triton_model_name
    model_version = FLAGS.triton_model_version
    batch_size = FLAGS.predict_batch_size

    protocol = ProtocolType.from_str(protocol_str)  # or 'grpc'

    ctx = InferContext(url, protocol, model_name, model_version, verbose)

    status_ctx = ServerStatusContext(url,
                                     protocol,
                                     model_name=model_name,
                                     verbose=verbose)

    model_config_pb2.ModelConfig()

    status_result = status_ctx.get_server_status()
    user_data = UserData()

    max_outstanding = 20
    # Number of outstanding requests
    outstanding = 0

    sent_prog = tqdm.tqdm(desc="Send Requests", total=len(eval_features))
    recv_prog = tqdm.tqdm(desc="Recv Requests", total=len(eval_features))

    def process_outstanding(do_wait, outstanding):

        if (outstanding == 0 or do_wait is False):
            return outstanding

        # Wait for deferred items from callback functions
        (infer_ctx, ready_id, idx, start_time,
         inputs) = user_data._completed_requests.get()

        if (ready_id is None):
            return outstanding

        # If we are here, we got an id
        result = ctx.get_async_run_results(ready_id)
        stop = time.time()

        if (result is None):
            raise ValueError(
                "Context returned null for async id marked as done")

        outstanding -= 1

        time_list.append(stop - start_time)

        batch_count = len(inputs[label_id_key])

        for i in range(batch_count):
            unique_id = int(inputs[label_id_key][i][0])
            start_logits = [float(x) for x in result["start_logits"][i].flat]
            end_logits = [float(x) for x in result["end_logits"][i].flat]
            all_results.append(
                RawResult(unique_id=unique_id,
                          start_logits=start_logits,
                          end_logits=end_logits))

        recv_prog.update(n=batch_count)
        return outstanding

    all_results = []
    time_list = []

    print("Starting Sending Requests....\n")

    all_results_start = time.time()
    idx = 0
    for inputs_dict in batch(eval_features, batch_size):

        present_batch_size = len(inputs_dict[label_id_key])

        outputs_dict = {
            'start_logits': InferContext.ResultFormat.RAW,
            'end_logits': InferContext.ResultFormat.RAW
        }

        start_time = time.time()
        ctx.async_run(partial(completion_callback, user_data, idx, start_time,
                              inputs_dict),
                      inputs_dict,
                      outputs_dict,
                      batch_size=present_batch_size)
        outstanding += 1
        idx += 1

        sent_prog.update(n=present_batch_size)

        # Try to process at least one response per request
        outstanding = process_outstanding(outstanding >= max_outstanding,
                                          outstanding)

    tqdm.tqdm.write(
        "All Requests Sent! Waiting for responses. Outstanding: {}.\n".format(
            outstanding))

    # Now process all outstanding requests
    while (outstanding > 0):
        outstanding = process_outstanding(True, outstanding)

    all_results_end = time.time()
    all_results_total = (all_results_end - all_results_start) * 1000.0

    print("-----------------------------")
    print("Total Time: {} ms".format(all_results_total))
    print("-----------------------------")

    print("-----------------------------")
    print("Total Inference Time = %0.2f for"
          "Sentences processed = %d" % (sum(time_list), len(eval_features)))
    print("Throughput Average (sentences/sec) = %0.2f" %
          (len(eval_features) / all_results_total * 1000.0))
    print("-----------------------------")

    if FLAGS.output_dir and FLAGS.predict_file:
        # When inferencing on a dataset, get inference statistics and write results to json file
        time_list.sort()

        avg = np.mean(time_list)
        cf_95 = max(time_list[:int(len(time_list) * 0.95)])
        cf_99 = max(time_list[:int(len(time_list) * 0.99)])
        cf_100 = max(time_list[:int(len(time_list) * 1)])
        print("-----------------------------")
        print("Summary Statistics")
        print("Batch size =", FLAGS.predict_batch_size)
        print("Sequence Length =", FLAGS.max_seq_length)
        print("Latency Confidence Level 95 (ms) =", cf_95 * 1000)
        print("Latency Confidence Level 99 (ms)  =", cf_99 * 1000)
        print("Latency Confidence Level 100 (ms)  =", cf_100 * 1000)
        print("Latency Average (ms)  =", avg * 1000)
        print("-----------------------------")

        output_prediction_file = os.path.join(FLAGS.output_dir,
                                              "predictions.json")
        output_nbest_file = os.path.join(FLAGS.output_dir,
                                         "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(FLAGS.output_dir,
                                                 "null_odds.json")

        write_predictions(eval_examples, eval_features, all_results,
                          FLAGS.n_best_size, FLAGS.max_answer_length,
                          FLAGS.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file,
                          FLAGS.version_2_with_negative, FLAGS.verbose_logging)
    else:
        # When inferencing on a single example, write best answer to stdout
        all_predictions, all_nbest_json, scores_diff_json = get_predictions(
            eval_examples, eval_features, all_results, FLAGS.n_best_size,
            FLAGS.max_answer_length, FLAGS.do_lower_case,
            FLAGS.version_2_with_negative, FLAGS.verbose_logging)
        print(
            "Context is: %s \n\nQuestion is: %s \n\nPredicted Answer is: %s" %
            (FLAGS.context, FLAGS.question, all_predictions[0]))
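The asynchronous path above relies on a `UserData` holder and a `completion_callback` that are defined elsewhere in the client. A minimal sketch consistent with how they are used here (the client library invokes the callback with the context and request id; everything else is pre-bound via `functools.partial` at submit time):

import queue

class UserData:
    def __init__(self):
        # completed requests are handed back to the main loop through a queue
        self._completed_requests = queue.Queue()

def completion_callback(user_data, idx, start_time, inputs, infer_ctx, request_id):
    # the client library supplies (infer_ctx, request_id); user_data, idx,
    # start_time, and inputs were bound with functools.partial when submitting
    user_data._completed_requests.put(
        (infer_ctx, request_id, idx, start_time, inputs))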