Example #1
def answer():
    if request.method == "POST":
        modelName = 'bert-large-uncased-whole-word-masking-finetuned-squad'
        tokenizer = BertTokenizer.from_pretrained(modelName)
        model = TFBertForQuestionAnswering.from_pretrained(modelName)

        with open(r'D:\Yeni klasör\text.txt', 'r', encoding='utf-8') as f:
            text = f.read()
        question = request.form.get("question")

        input_text = question + "[SEP]" + text
        input_ids = tokenizer.encode(input_text)

        input = tf.constant(input_ids)[None, :]

        token_type_ids = [
            0 if i <= input_ids.index(102) else 1
            for i in range(len(input_ids))
        ]

        answer = model(input,
                       token_type_ids=tf.convert_to_tensor([token_type_ids]))

        startScores = answer.start_logits
        endScores = answer.end_logits

        input_tokens = tokenizer.convert_ids_to_tokens(input_ids)

        startIdx = tf.math.argmax(startScores[0], 0).numpy()
        endIdx = tf.math.argmax(endScores[0], 0).numpy() + 1
        x = (" ".join(input_tokens[startIdx:endIdx]))
        return render_template("answer.html", data=x)
    else:
        return render_template("index.html")
def semantic_search(corpus_path, sentence):
    """
    Returns: a string containing the answer
    * If no answer is found, return None
    """
    files = os.listdir(corpus_path)
    files = [elem for elem in files if '.md' in elem]
    all_text = []
    for file in files:
        with open(os.path.join(corpus_path, file), 'r', encoding='UTF-8') as f:
            f_line = f.read()
        all_text.append(f_line)
    url = "bert-large-uncased-whole-word-masking-finetuned-squad"
    tokenizer = BertTokenizer.from_pretrained(url)
    model = TFBertForQuestionAnswering.from_pretrained(url, return_dict=True)
    result = []
    for i in range(len(all_text)):
        r = modelResult(model, tokenizer, sentence, all_text[i], 0)
        if type(r) is list:
            r.append(files[i])
            result.append(r)
    aux_sort = sorted(result, key=lambda x: x[0], reverse=True)
    best_5 = aux_sort[:5]
    new_scores = []
    for elem in best_5:
        r = modelResult(model, tokenizer, sentence, elem[2], 1)
        if type(r) is list:
            r.append(elem[1])
            r.append(elem[3])
            new_scores.append(r)
    aux_sort = sorted(new_scores, key=lambda x: abs(x[0]), reverse=True)
    return aux_sort[0][1] if aux_sort else None
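Both semantic_search snippets call a modelResult helper that is not shown here. The following is only a hypothetical sketch of such a helper (its name, signature, and scoring scheme are assumptions, not the original implementation); it returns [score, answer, text] in the first pass (mode 0) and [score, answer] in the re-scoring pass (mode 1), matching how the results are indexed above:

import tensorflow as tf


def modelResult(model, tokenizer, question, text, mode):
    # Hypothetical helper: score `text` as a source for answering `question`.
    inputs = tokenizer(question, text, return_tensors="tf",
                       truncation=True, max_length=512)
    outputs = model(inputs)
    start = int(tf.math.argmax(outputs.start_logits[0]))
    end = int(tf.math.argmax(outputs.end_logits[0])) + 1
    if end <= start:
        return None  # no plausible answer span found
    score = float(outputs.start_logits[0][start] +
                  outputs.end_logits[0][end - 1])
    answer = tokenizer.decode(inputs["input_ids"][0][start:end])
    return [score, answer, text] if mode == 0 else [score, answer]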
Example #3
 def __init__(self, bert_squad_model='bert-large-uncased-whole-word-masking-finetuned-squad',
              bert_emb_model='bert-base-uncased'):
     self.model_name = bert_squad_model
     self.model = TFBertForQuestionAnswering.from_pretrained(self.model_name)
     self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
     self.maxlen = 512
     self.te = tpp.TransformerEmbedding(bert_emb_model, layers=[-2])
Example #4
 def test_TFBertModel(self):
     from transformers import BertTokenizer, TFBertForQuestionAnswering
     tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
     model = TFBertForQuestionAnswering.from_pretrained('bert-base-cased')
     question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
     input_dict = tokenizer(question, text, return_tensors='tf')
     spec, input_dict = self.spec_and_pad(input_dict)
     self.run_test(model, input_dict, input_signature=spec)
Example #5
 def test_TFBertFineTunedSquadModel(self):
     from transformers import BertTokenizer, TFBertForQuestionAnswering
     name = "bert-large-uncased-whole-word-masking-finetuned-squad"
     tokenizer = BertTokenizer.from_pretrained(name)
     model = TFBertForQuestionAnswering.from_pretrained(name)
     question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
     input_dict = tokenizer(question, text, return_tensors='tf')
     spec, input_dict = self.spec_and_pad(input_dict)
     self.run_test(model, input_dict, input_signature=spec)
Example #6
 def test_TFBertForQuestionAnswering(self):
     from transformers import BertTokenizer, TFBertForQuestionAnswering
     pretrained_weights = 'bert-base-uncased'
     tokenizer = BertTokenizer.from_pretrained(pretrained_weights)
     text, inputs, inputs_onnx = self._prepare_inputs(tokenizer)
     model = TFBertForQuestionAnswering.from_pretrained(pretrained_weights)
     predictions = model.predict(inputs)
     onnx_model = keras2onnx.convert_keras(model, model.name)
     self.assertTrue(
         run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx,
                          predictions, self.model_files))
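For reference, a minimal sketch of running the converted graph with onnxruntime directly, roughly what the run_onnx_runtime helper checks against the Keras predictions (the file name here is illustrative):

import keras2onnx
import onnxruntime as ort

keras2onnx.save_model(onnx_model, "tf_bert_qa.onnx")
sess = ort.InferenceSession("tf_bert_qa.onnx")
# inputs_onnx is the dict of numpy inputs prepared above; its keys must match
# the graph's input names.
onnx_predictions = sess.run(None, inputs_onnx)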
Example #7
    def copy_model_files(self):
        modified = False

        src_path = self.checkpoint_path

        d = None
        try:
            if not (self.git_path / "tf_model.h5").exists() or not (
                    self.git_path / "pytorch_model.bin").exists():
                if task.startswith("squad"):
                    d = TemporaryDirectory()
                    model = QASparseXP.compile_model(src_path,
                                                     dest_path=d.name)
                    model = optimize_model(model, "heads")
                    model.save_pretrained(d.name)
                    src_path = d.name
                else:
                    raise Exception(f"Unknown task {task}")

            if not (self.git_path / "tf_model.h5").exists():
                with TemporaryDirectory() as d2:
                    if task.startswith("squad"):
                        QASparseXP.final_fine_tune_bertarize(
                            src_path, d2, remove_head_pruning=True)
                    else:
                        raise Exception(f"Unknown task {task}")

                    tf_model = TFBertForQuestionAnswering.from_pretrained(
                        d2, from_pt=True)
                    tf_model.save_pretrained(self.git_path)
                    modified = True

            if not (self.git_path / "pytorch_model.bin").exists():
                model = BertForQuestionAnswering.from_pretrained(src_path)
                model.save_pretrained(self.git_path)
                modified = True

            FILES = "special_tokens_map.json", "tokenizer_config.json", "vocab.txt"
            for file in FILES:
                if not (self.git_path / file).exists():
                    shutil.copyfile(str(Path(src_path) / file),
                                    str(self.git_path / file))
                    modified = True

        finally:
            if d is not None:
                d.cleanup()

        # Reload the config, this may have been changed by compilation / optimization (pruned_heads, gelu_patch, layer_norm_patch)
        with (self.git_path / "config.json").open() as f:
            self.checkpoint_info["config"] = json.load(f)

        return modified
Example #8
    def init_bert(self):
        """ Runs data processing scripts to turn raw data from (../raw) into
            cleaned data ready to be analyzed (saved in ../processed).
        """
        self.logger = logging.getLogger(__name__)
        self.logger.info('Initializing BERT model and tokenizer.')

        tokenizer = BertTokenizer.from_pretrained(
            'bert-large-uncased-whole-word-masking-finetuned-squad')
        model = TFBertForQuestionAnswering.from_pretrained(
            'bert-large-uncased-whole-word-masking-finetuned-squad')

        return model, tokenizer
def semantic_search(corpus_path, sentence):
    """
    performs semantic search on a corpus of documents
    Question: string containing the question to answer
    Reference: string containing the reference document
      from which to find the answer
    Returns: a string containing the answer
    If no answer is found, return None
    Your function should use the bert-uncased-tf2-qa model
      from the tensorflow-hub library
    Your function should use the pre-trained BertTokenizer,
      bert-large-uncased-whole-word-masking-finetuned-squad,
      from the transformers library
    """

    files = os.listdir(corpus_path)
    files = [elem for elem in files if '.md' in elem]
    all_text = []
    for file in files:
        with open(os.path.join(corpus_path, file), 'r', encoding='UTF-8') as f:
            f_line = f.read()
        all_text.append(f_line)

    url = "bert-large-uncased-whole-word-masking-finetuned-squad"
    tokenizer = BertTokenizer.from_pretrained(url)
    model = TFBertForQuestionAnswering.from_pretrained(url)
    result = []
    for i in range(len(all_text)):
        r = modelResult(model, tokenizer, sentence, all_text[i], 0)
        if type(r) is list:
            r.append(files[i])
            result.append(r)

    aux_sort = sorted(result, key=lambda x: x[0], reverse=True)
    best_5 = aux_sort[:5]
    new_scores = []
    for elem in best_5:
        r = modelResult(model, tokenizer, sentence, elem[2], 1)
        if type(r) is list:
            r.append(elem[1])
            r.append(elem[3])
            new_scores.append(r)

    aux_sort = sorted(new_scores, key=lambda x: abs(x[0]), reverse=True)
    return aux_sort[0][1] if aux_sort else None
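A hypothetical invocation (the corpus directory matches the ZendeskArticles folder read above; the question is made up for illustration):

if __name__ == '__main__':
    print(semantic_search('ZendeskArticles', 'What does the onboarding process look like?'))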
Example #10
 def __init__(self, modelName="bert-large-uncased-whole-word-masking-finetuned-squad"):
     self.modelName = modelName
     self.model = TFBertForQuestionAnswering.from_pretrained(self.modelName)
     self.tokenizer = BertTokenizer.from_pretrained(modelName)
Example #11
 def __init__(self):
     self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
     self.model = TFBertForQuestionAnswering.from_pretrained(
         'bert-large-uncased-whole-word-masking-finetuned-squad')
     print('QA init done')
Example #12
import os
from transformers import TFBertForQuestionAnswering, BertTokenizer
import tensorflow as tf
from rake_nltk import Rake
import json
from .clauses import Clause
from .conditionmaps import conditions
from .column_types import get, Number, FuzzyString, Categorical, String
from .data_utils import data_utils

qa_model = TFBertForQuestionAnswering.from_pretrained(
    'bert-large-uncased-whole-word-masking-finetuned-squad')
qa_tokenizer = BertTokenizer.from_pretrained(
    'bert-large-uncased-whole-word-masking-finetuned-squad', padding=True)

import nltk
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()
lem = lemmatizer.lemmatize


def extract_keywords_from_doc(doc, phrases=True, return_scores=False):
    if phrases:
        r = Rake()
        if isinstance(doc, (list, tuple)):
            r.extract_keywords_from_sentences(doc)
        else:
            r.extract_keywords_from_text(doc)
        if return_scores:
Example #13
import tensorflow as tf
from transformers import AutoTokenizer, TFBertForQuestionAnswering
import matplotlib.pyplot as plt
import numpy as np
tokenizer = AutoTokenizer.from_pretrained("deepset/bert-base-cased-squad2",
                                          use_fast=True)
model = TFBertForQuestionAnswering.from_pretrained(
    "deepset/bert-base-cased-squad2", from_pt=True)


def get_answer_span(question, context, model, tokenizer):
    inputs = tokenizer.encode_plus(question,
                                   context,
                                   return_tensors="tf",
                                   add_special_tokens=True,
                                   max_length=512)
    # Older transformers releases return a (start_logits, end_logits) tuple here.
    answer_start_scores, answer_end_scores = model(inputs)
    # tf.argmax takes no dtype argument; it returns integer indices.
    answer_start = tf.argmax(answer_start_scores, axis=1).numpy()[0]
    answer_end = (tf.argmax(answer_end_scores, axis=1) + 1).numpy()[0]
    # decode() expects token ids (convert_tokens_to_string expects tokens).
    print(tokenizer.decode(inputs["input_ids"][0][answer_start:answer_end]))
    return answer_start, answer_end
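
# Illustrative usage sketch (not from the original source); the question and
# context reuse the Jim Henson example from the test snippets above.
example_start, example_end = get_answer_span(
    "Who was Jim Henson?", "Jim Henson was a nice puppet", model, tokenizer)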


def clean_tokens(gradients, tokens, token_types):
    """
  Clean the tokens and gradients gradients
  Remove "[CLS]","[CLR]", "[SEP]" tokens
  Reduce (mean) gradients values for tokens that are split ##
Example #14
    def copy_model_files(self):
        modified = False
        from pytorch_block_sparse.util import BertHeadsPruner

        if not (self.git_path / "tf_model.h5").exists():
            tf_model = TFBertForQuestionAnswering.from_pretrained(
                self.src_path, from_pt=True)
            tf_model.save_pretrained(self.git_path)
            modified = True

        devel = True
        if not (self.git_path / "pytorch_model.bin").exists() or devel:
            model = BertForQuestionAnswering.from_pretrained(self.src_path)
            to_prune, head_count = BertHeadsPruner(model).get_pruned_heads()
            model.prune_heads(to_prune)
            config = model.config

            config.pruned_heads = to_prune
            self.report["sparsity"]["pruned_heads"] = to_prune

            config.block_size = [
                config.mask_block_rows, config.mask_block_cols
            ]
            KEYS_TO_DELETE = [
                "pruning_submethod", "shuffling_method", "in_shuffling_group",
                "out_shuffling_group"
            ]
            KEYS_TO_DELETE += [
                "ampere_mask_init", "ampere_pruning_method",
                "ampere_mask_scale", "mask_init", "mask_scale"
            ]
            KEYS_TO_DELETE += [
                "pruning_method", "mask_block_rows", "mask_block_cols",
                "gradient_checkpointing"
            ]
            KEYS_TO_DELETE += [
                "initializer_range", "intermediate_size",
                "hidden_dropout_prob", "layer_norm_eps"
            ]

            for key in KEYS_TO_DELETE:
                delattr(config, key)

            config.architectures = ["BertForQuestionAnswering"]
            config.name_or_path = f"{self.model_owner_name}/{self.model_name}"
            model.save_pretrained(self.git_path)
            modified = True
            self.report["sparsity"]["total_pruned_attention_heads"] = sum(
                [len(t) for t in to_prune.values()])
            self.report["sparsity"]["total_attention_heads"] = self.report[
                "config"]["num_attention_heads"] * self.report["config"][
                    "num_hidden_layers"]

        self.report["packaging"] = {}
        self.report["packaging"]["pytorch_final_file_size"] = os.stat(
            self.git_path / "pytorch_model.bin").st_size
        self.report["packaging"]["model_owner"] = self.model_owner_name
        self.report["packaging"][
            "model_name"] = f"{self.model_owner_name}/{self.model_name}"

        #PRODUCED_PATHES = ["dev-v1.1.json", "nbest_predictions_.json", "predictions_.json"]
        FILES = "special_tokens_map.json", "tokenizer_config.json", "vocab.txt"  #, "report.json"
        for file in FILES:
            if not (self.git_path / file).exists():
                shutil.copyfile(self.src_path / file, self.git_path / file)
                modified = True

        if not (self.git_path / "model_meta.json").exists() or devel:
            with (self.git_path / "model_meta.json").open("w") as file:
                report_string = pretty_json(self.report)
                file.write(report_string)
        else:
            self.report = json.loads(
                (self.git_path / "model_meta.json").open().read())

        return modified
def remove_none_values(example):
  # Keep only examples whose start and end positions contain no None values.
  return None not in example["start_positions"] and None not in example["end_positions"]

train_tf_dataset = train_tf_dataset.filter(remove_none_values, load_from_cache_file=False)
columns = ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions']
train_tf_dataset.set_format(type='tensorflow', columns=columns)
features = {x: train_tf_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.max_len]) for x in columns[:3]} 
labels = {"output_1": train_tf_dataset["start_positions"].to_tensor(default_value=0, shape=[None, 1])}
labels["output_2"] = train_tf_dataset["end_positions"].to_tensor(default_value=0, shape=[None, 1])
tfdataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(8)

# Let's load a pretrained TF2 Bert model and a simple optimizer
from transformers import TFBertForQuestionAnswering

model = TFBertForQuestionAnswering.from_pretrained("bert-base-cased")
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE, from_logits=True)
opt = tf.keras.optimizers.Adam(learning_rate=3e-5)
model.compile(optimizer=opt,
              loss={'output_1': loss_fn, 'output_2': loss_fn},
              loss_weights={'output_1': 1., 'output_2': 1.},
              metrics=['accuracy'])

# Now let's train our model

model.fit(tfdataset, epochs=1, steps_per_epoch=3)

"""# Metrics API

`nlp` also provides easy access and sharing of metrics.
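As a minimal sketch of that metrics API (the id and answer values below are made up for illustration), the SQuAD metric takes predictions and references in the same format the last example in this listing builds:

from nlp import load_metric

squad_metric = load_metric("squad")
predictions = [{"id": "1", "prediction_text": "a nice puppet"}]
references = [{"id": "1",
               "answers": {"text": ["a nice puppet"], "answer_start": [15]}}]
print(squad_metric.compute(predictions=predictions, references=references))
# {'exact_match': 100.0, 'f1': 100.0}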
Example #16
def fine_tune_squad(**config):
    # Get checkpoint from Hugging Face's models repo or from our own checkpoint.
    bert_model_name = config.get("bert_model_name", None)
    pretrained_ckpt_path = config.get("pretrained_ckpt_path", None)  # This can be passed from both cli and config file.
    assert bert_model_name is not None or pretrained_ckpt_path is not None, \
        "SQuAD requires a pretrained model, either `bert_model_name` (via config file) or `pretrained_ckpt_path` " \
        "(via config file or `--pretrained-ckpt-path` command line argument) but none provided."
    assert (bert_model_name is not None and pretrained_ckpt_path is None) \
        or (bert_model_name is None and pretrained_ckpt_path is not None), \
        f"Only one checkpoint is accepted, but two provided: `bert_model_name`={bert_model_name}, " \
        f"and `pretrained_ckpt_path`={pretrained_ckpt_path}."
    if pretrained_ckpt_path is not None:
        bert_config_params = config["bert_config"]

    # Get required options
    micro_batch_size = config["micro_batch_size"]
    num_epochs = config["num_epochs"]
    optimizer_opts = config["optimizer_opts"]
    learning_rate = config['learning_rate']
    replicas = config["replicas"]
    grad_acc_steps_per_replica = config["grad_acc_steps_per_replica"]
    wandb_opts = config["wandb_opts"]
    use_outlining = config["use_outlining"]
    replace_layers = config["replace_layers"]
    enable_recomputation = config["enable_recomputation"]
    embedding_serialization_factor = config["embedding_serialization_factor"]
    optimizer_state_offchip = config["optimizer_state_offchip"]
    matmul_available_memory_proportion_per_pipeline_stage = config[
        "matmul_available_memory_proportion_per_pipeline_stage"]
    matmul_partials_type = config["matmul_partials_type"]
    pipeline_stages = config["pipeline_stages"]
    device_mapping = config["device_mapping"]
    global_batches_per_log = config["global_batches_per_log"]
    seed = config["seed"]
    cache_dir = config["cache_dir"]
    output_dir = config["output_dir"]

    # Get optional options
    save_ckpt_path = config.get("save_ckpt_path", Path(__file__).parent.joinpath("checkpoints").absolute())
    ckpt_every_n_steps_per_execution = config.get("ckpt_every_n_steps_per_execution", 2000)

    universal_run_name = config.get("name", f"{Path(config['config']).stem}-{wandb_opts['init']['name']}")
    universal_run_name += f"-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    print(f"Universal name for run: {universal_run_name}")
    set_random_seeds(seed)
    num_pipeline_stages = len(device_mapping)
    num_ipus_per_replicas = max(device_mapping) + 1
    num_ipus = replicas * num_ipus_per_replicas

    # Load training and validation data
    # =================================
    train_dataset, eval_dataset, num_train_samples, num_eval_samples, raw_datasets = get_squad_data(
        micro_batch_size,
        cache_dir
    )
    train_batch_config = BatchConfig(micro_batch_size=micro_batch_size,
                                     total_num_train_samples=num_epochs * num_train_samples.numpy(),
                                     num_replicas=replicas,
                                     gradient_accumulation_count=grad_acc_steps_per_replica,
                                     dataset_size=num_train_samples.numpy(),
                                     global_batches_per_log=global_batches_per_log,
                                     task=Task.OTHER)

    # Create model
    # ============
    policy = tf.keras.mixed_precision.Policy("float16")
    tf.keras.mixed_precision.set_global_policy(policy)
    strategy = create_ipu_strategy(num_ipus, enable_recomputation=enable_recomputation)
    with strategy.scope():
        # Instantiate the pretrained model given in the config.
        if bert_model_name is not None:
            model = TFBertForQuestionAnswering.from_pretrained(bert_model_name)
        else:
            bert_config = BertConfig(**bert_config_params, hidden_act=ipu.nn_ops.gelu)
            model = TFBertForQuestionAnswering(config=bert_config)

        # Convert subclass model to functional, expand main layers to enable pipelining, and replace some layers to
        # optimise performance.
        model = convert_tf_bert_model(
            model,
            train_dataset,
            post_process_bert_input_layer,
            replace_layers=replace_layers,
            use_outlining=use_outlining,
            embedding_serialization_factor=embedding_serialization_factor,
            rename_outputs={'tf.compat.v1.squeeze': 'start_positions', 'tf.compat.v1.squeeze_1': 'end_positions'}
        )
        # Load from pretrained checkpoint if requested.
        if pretrained_ckpt_path is not None:
            print(f"Attempting to load pretrained checkpoint from path {pretrained_ckpt_path}. "
                  f"This will overwrite the current weights")
            load_checkpoint_into_model(model, pretrained_ckpt_path)

        # Configure pipeline stages
        # =========================
        if num_pipeline_stages > 1:
            pipeline_assigner = PipelineStagesAssigner(PIPELINE_ALLOCATE_PREVIOUS, PIPELINE_NAMES)
            assignments = model.get_pipeline_stage_assignment()
            assignments = pipeline_assigner.assign_pipeline_stages(assignments, pipeline_stages)
            model.set_pipeline_stage_assignment(assignments)
            model.print_pipeline_stage_assignment_summary()
            poplar_options_per_pipeline_stage = get_poplar_options_per_pipeline_stage(
                num_ipus_per_replicas,
                device_mapping,
                matmul_available_memory_proportion_per_pipeline_stage,
                matmul_partials_type
            )
            model.set_pipelining_options(
                gradient_accumulation_steps_per_replica=grad_acc_steps_per_replica,
                pipeline_schedule=ipu.ops.pipelining_ops.PipelineSchedule.Grouped,
                device_mapping=device_mapping,
                offload_weight_update_variables=optimizer_state_offchip,
                forward_propagation_stages_poplar_options=poplar_options_per_pipeline_stage,
                backward_propagation_stages_poplar_options=poplar_options_per_pipeline_stage,
                recomputation_mode=ipu.pipelining_ops.RecomputationMode.RecomputeAndBackpropagateInterleaved,
            )

        # Compile the model for training
        # ==============================
        # Wrap loss in an out-feed queue.
        loss_outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue()
        qa_loss = wrap_loss_in_enqueuer(QuestionAnsweringLossFunction,
                                        loss_outfeed_queue,
                                        ["end_positions_loss", "start_positions_loss"])()
        # Define optimiser with polynomial decay learning rate.
        learning_rate['lr_schedule_params']['total_steps'] = train_batch_config.num_train_steps
        lr_outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue(outfeed_mode=ipu.ipu_outfeed_queue.IPUOutfeedMode.LAST)
        lr_scheduler = get_lr_scheduler(scheduler_name=learning_rate["lr_schedule"],
                                        schedule_params=learning_rate["lr_schedule_params"],
                                        queue=lr_outfeed_queue)
        # Prepare optimizer.
        outline_optimizer_apply_gradients = use_outlining
        optimizer = get_optimizer(
            optimizer_opts["name"],
            grad_acc_steps_per_replica,
            replicas,
            lr_scheduler,
            outline_optimizer_apply_gradients,
            weight_decay_rate=optimizer_opts["params"]["weight_decay_rate"],
        )
        # Compile the model.
        model.compile(
            optimizer=optimizer,
            loss={"end_positions": qa_loss, "start_positions": qa_loss},
            metrics='accuracy',
            steps_per_execution=train_batch_config.steps_per_execution
        )

        # Train the model
        # ===============
        # Set up callbacks
        callbacks = CallbackFactory.get_callbacks(
            universal_run_name=universal_run_name,
            batch_config=train_batch_config,
            model=model,
            checkpoint_path=save_ckpt_path,
            ckpt_every_n_steps_per_execution=ckpt_every_n_steps_per_execution,
            outfeed_queues=[lr_outfeed_queue, loss_outfeed_queue],
            config=config,
        )
        # Print configs to be logged in wandb's terminal.
        print(config)
        print(f"Training batch config:\n{train_batch_config}")
        # Train the model
        history = model.fit(
            train_dataset,
            steps_per_epoch=train_batch_config.num_micro_batches_per_epoch,
            epochs=num_epochs,
            callbacks=callbacks
        )

    # Evaluate the model on the validation set
    # ========================================
    # Prepare the dataset to be evaluated in the IPU.
    eval_batch_config = BatchConfig(micro_batch_size=micro_batch_size,
                                    total_num_train_samples=num_eval_samples.numpy(),
                                    num_replicas=replicas,
                                    gradient_accumulation_count=grad_acc_steps_per_replica,
                                    dataset_size=num_eval_samples.numpy(),
                                    task=Task.OTHER)
    max_eval_samples = eval_batch_config.micro_batch_size * eval_batch_config.num_micro_batches_per_epoch
    eval_pred_dataset = get_prediction_dataset(eval_dataset, max_eval_samples)
    with strategy.scope():
        # Re-compile the model for prediction if needed.
        if train_batch_config.steps_per_execution != eval_batch_config.steps_per_execution:
            model.compile(steps_per_execution=eval_batch_config.steps_per_execution)
        # Get predictions for the validation data.
        print(f"Running inference:\nGenerating predictions on the validation data...")
        predictions = model.predict(
            eval_pred_dataset,
            batch_size=eval_batch_config.micro_batch_size
        )
    # The predictions for the end positions go first in the model outputs tuple (note the output of model.summary()).
    end_predictions, start_predictions = predictions
    # Match the predictions to answers in the original context.
    # This will also write out the predictions to a json file in the directory given by `output_dir`.
    final_predictions = postprocess_qa_predictions(
        list(raw_datasets["validation"].as_numpy_iterator()),
        list(eval_dataset.unbatch().as_numpy_iterator()),
        (start_predictions, end_predictions),
        output_dir=output_dir
    )
    # Format the predictions and the actual labels as expected by the metric.
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
    formatted_labels = format_raw_data_for_metric(raw_datasets["validation"])
    metric = load_metric("squad")
    metrics = metric.compute(predictions=formatted_predictions, references=formatted_labels)
    print("Evaluation metrics:")
    for key, value in metrics.items():
        print(f"{key}: {value:.3f}")

    return history