Example #1
def iter_chunk_tokenize_and_featurize(task,
                                      examples: list,
                                      tokenizer,
                                      feat_spec: FeaturizationSpec,
                                      phase,
                                      verbose=False):
    """Generator of DataRows containing tokenized and featurized examples.

    Args:
        task (Task): Task object
        examples (list[Example]): list of task Examples.
        tokenizer: TODO  (Issue #44)
        feat_spec (FeaturizationSpec): Tokenization-related metadata.
        phase (str): string identifying the data subset (e.g., train, val or test).
        verbose: If True, display progress bar.

    Yields:
        DataRow containing a tokenized and featurized example.

    """
    for example in maybe_tqdm(examples, desc="Tokenizing", verbose=verbose):
        # TODO: Better solution  (Issue #48)
        if task.TASK_TYPE == TaskTypes.SQUAD_STYLE_QA:
            yield from example.to_feature_list(
                tokenizer=tokenizer,
                max_seq_length=feat_spec.max_seq_length,
                doc_stride=task.doc_stride,
                max_query_length=task.max_query_length,
                set_type=phase,
            )
        else:
            yield example.tokenize(tokenizer).featurize(tokenizer, feat_spec)
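All of these examples wrap their loops in maybe_tqdm (and, further down, maybe_trange), which is not reproduced here. A minimal sketch of what such helpers presumably do, assuming the standard tqdm package:

from tqdm import tqdm, trange


def maybe_tqdm(iterable, desc=None, verbose=True, **kwargs):
    # Show a tqdm progress bar only when verbose is requested;
    # otherwise return the iterable unchanged.
    return tqdm(iterable, desc=desc, **kwargs) if verbose else iterable


def maybe_trange(n, desc=None, verbose=True, **kwargs):
    # Same idea for ranges: a progress bar over range(n) only if verbose.
    return trange(n, desc=desc, **kwargs) if verbose else range(n)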
Example #2
    def run_train_context(self, verbose=True):
        train_dataloader_dict = self.get_train_dataloader_dict()
        loss_weights_dict = {}
        train_state = TrainState.from_task_name_list(
            self.jiant_task_container.task_run_config.train_task_list)
        global_train_config = self.jiant_task_container.global_train_config

        losses = []
        for step in maybe_tqdm(
                range(global_train_config.max_steps),
                desc="Training",
                verbose=verbose,
        ):
            regular_log(logger, step, interval=10, tag='train')

            if step == global_train_config.weighted_sampling_start_step:
                train_dataloader_dict = self.get_train_dataloader_dict(
                    do_weighted_sampling=True)
            if step == global_train_config.weighted_loss_start_step:
                loss_weights_dict = self.get_loss_weights_dict()

            loss_per_step = self.run_train_step(
                train_dataloader_dict=train_dataloader_dict,
                train_state=train_state,
                loss_weights_dict=loss_weights_dict,
            )
            losses.append(loss_per_step)

            if step % 100 == 0:
                logger.info('[train] loss: %f', np.mean(losses))
                self.tf_writer.flush()

            yield train_state
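Several examples also call regular_log(logger, step, ...), which is likewise not defined in these snippets. One plausible, purely illustrative implementation (the signature is inferred from the call sites; the default interval is an assumption):

import logging


def regular_log(logger: logging.Logger, step: int, interval: int = 1000, tag: str = None):
    # Emit a progress line only every `interval` steps to avoid flooding the log.
    if step % interval == 0:
        prefix = f"[{tag}] " if tag else ""
        logger.info("%sstep %d", prefix, step)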
Example #3
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    return_preds=False,
    verbose=True,
):
    # Reminder:
    #   val_dataloader contains mostly PyTorch-relevant info
    #   val_labels might contain more detailed information needed for full evaluation
    if not local_rank == -1:
        return
    jiant_model.eval()
    total_eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()

    for step, (batch, batch_metadata) in enumerate(
        maybe_tqdm(val_dataloader, desc=f"Eval ({task.name}, Val)", verbose=verbose)
    ):
        batch = batch.to(device)

        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model, batch=batch, task=task, compute_loss=True,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        batch_loss = model_output.loss.mean().item()
        total_eval_loss += batch_loss
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )

        nb_eval_examples += len(batch)
        nb_eval_steps += 1
    eval_loss = total_eval_loss / nb_eval_steps
    tokenizer = (
        jiant_model.tokenizer
        if not torch_utils.is_data_parallel(jiant_model)
        else jiant_model.module.tokenizer
    )
    output = {
        "accumulator": eval_accumulator,
        "loss": eval_loss,
        "metrics": evaluation_scheme.compute_metrics_from_accumulator(
            task=task, accumulator=eval_accumulator, labels=val_labels, tokenizer=tokenizer,
        ),
    }
    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task, accumulator=eval_accumulator,
        )
    return output
Example #4
def run_test(test_dataloader, jiant_model: JiantModel, task, device, local_rank, verbose=True):
    if not local_rank == -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()

    for step, (batch, batch_metadata) in enumerate(
        maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)
    ):
        batch = batch.to(device)

        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model, batch=batch, task=task, compute_loss=False,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        eval_accumulator.update(
            batch_logits=batch_logits, batch_loss=0, batch=batch, batch_metadata=batch_metadata,
        )
    return {
        "preds": evaluation_scheme.get_preds_from_accumulator(
            task=task, accumulator=eval_accumulator,
        ),
        "accumulator": eval_accumulator,
    }
Example #5
def iter_chunk_tokenize_and_featurize(examples: list,
                                      tokenizer,
                                      feat_spec: FeaturizationSpec,
                                      phase,
                                      verbose=False):
    """Generator of DataRows containing tokenized and featurized examples.

    Args:
        examples (list[Example]): list of task Examples.
        tokenizer: TODO  (Issue #44)
        feat_spec (FeaturizationSpec): Tokenization-related metadata.
        phase (str): string identifying the data subset (e.g., train, val or test).
        verbose: If True, display progress bar.

    Yields:
        DataRow containing a tokenized and featurized example.

    """
    for example in maybe_tqdm(examples, desc="Tokenizing", verbose=verbose):
        # TODO: Better solution  (Issue #48)
        if isinstance(example, squad_style.Example):
            # TODO: Expose parameters  (Issue #49)
            yield from example.to_feature_list(
                tokenizer=tokenizer,
                max_seq_length=feat_spec.max_seq_length,
                doc_stride=128,
                max_query_length=64,
                set_type=phase,
            )
        else:
            yield example.tokenize(tokenizer).featurize(tokenizer, feat_spec)
Example #6
    def read_examples(self, path, set_type):
        input_data = read_json(path, encoding="utf-8")["data"]

        is_training = set_type == PHASE.TRAIN
        examples = []
        data = take_one(input_data)
        for paragraph in maybe_tqdm(data["paragraphs"]):
            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                # Because answers can also come from questions, we're going to abuse notation
                #   slightly and put the entire background+situation+question into the "context"
                #   and leave nothing for the "question"
                question_text = " "
                if self.include_background:
                    context_segments = [
                        paragraph["background"],
                        paragraph["situation"],
                        qa["question"],
                    ]
                else:
                    context_segments = [paragraph["situation"], qa["question"]]
                full_context = " ".join(segment.strip()
                                        for segment in context_segments)

                if is_training:
                    answer = qa["answers"][0]
                    start_position_character = full_context.find(
                        answer["text"])
                    answer_text = answer["text"]
                    answers = []
                else:
                    start_position_character = None
                    answer_text = None
                    answers = [{
                        "text":
                        answer["text"],
                        "answer_start":
                        full_context.find(answer["text"])
                    } for answer in qa["answers"]]

                example = Example(
                    qas_id=qas_id,
                    question_text=question_text,
                    context_text=full_context,
                    answer_text=answer_text,
                    start_position_character=start_position_character,
                    title="",
                    is_impossible=False,
                    answers=answers,
                    background_text=paragraph["background"],
                    situation_text=paragraph["situation"],
                )
                examples.append(example)
        return examples
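For reference, a minimal made-up record containing only the fields this reader accesses (the real dataset may contain additional fields); read_json returns the full document, and take_one then picks a single entry from its "data" list:

sample_document = {
    "data": [{
        "paragraphs": [{
            "background": "Plants release oxygen during photosynthesis.",
            "situation": "Ann keeps her plant in the dark, while Bob's plant gets sunlight.",
            "qas": [{
                "id": "q1",
                "question": "Whose plant releases more oxygen?",
                "answers": [{"text": "Bob"}],
            }],
        }],
    }],
}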
Example #7
def tokenize_and_featurize(task,
                           examples: list,
                           tokenizer,
                           feat_spec: FeaturizationSpec,
                           phase,
                           verbose=False):
    """Create list of DataRows containing tokenized and featurized examples.

    Args:
        task (Task): Task object
        examples (list[Example]): list of task Examples.
        tokenizer: TODO  (issue #1188)
        feat_spec (FeaturizationSpec): Tokenization-related metadata.
        phase (str): string identifying the data subset (e.g., train, val or test).
        verbose: If True, display progress bar.

    Returns:
        List of DataRows containing tokenized and featurized examples.

    """
    # TODO: Better solution  (issue #1184)
    if task.TASK_TYPE == TaskTypes.SQUAD_STYLE_QA:
        data_rows = []
        for step, example in enumerate(
                maybe_tqdm(examples, desc="Tokenizing", verbose=verbose)):
            regular_log(logger, step)
            data_rows += example.to_feature_list(
                tokenizer=tokenizer,
                max_seq_length=feat_spec.max_seq_length,
                doc_stride=task.doc_stride,
                max_query_length=task.max_query_length,
                set_type=phase,
            )
    else:
        data_rows = []
        for step, example in enumerate(
                maybe_tqdm(examples, desc="Tokenizing", verbose=verbose)):
            regular_log(logger, step)
            data_rows.append(
                example.tokenize(tokenizer).featurize(tokenizer, feat_spec))
    return data_rows
Example #8
    def run_train_context(self, verbose=True):
        train_dataloader_dict = self.get_train_dataloader_dict()
        train_state = TrainState.from_task_name_list(
            self.jiant_task_container.task_run_config.train_task_list)
        for _ in maybe_tqdm(
                range(self.jiant_task_container.global_train_config.max_steps),
                desc="Training",
                verbose=verbose,
        ):
            self.run_train_step(train_dataloader_dict=train_dataloader_dict,
                                train_state=train_state)
            yield train_state
Example #9
def generic_read_squad_examples(path: str,
                                set_type: str,
                                example_class: type = dict,
                                read_title: bool = True):

    with open(path, "r", encoding="utf-8") as reader:
        input_data = json.load(reader)["data"]

    is_training = set_type == PHASE.TRAIN
    examples = []
    for step, entry in enumerate(
            maybe_tqdm(input_data, desc="Reading SQuAD Entries")):
        regular_log(logger, step)

        if read_title:
            title = entry["title"]
        else:
            title = "-"
        for paragraph in entry["paragraphs"]:
            context_text = paragraph["context"]
            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position_character = None
                answer_text = None
                answers = []

                if "is_impossible" in qa:
                    is_impossible = qa["is_impossible"]
                else:
                    is_impossible = False

                if not is_impossible:
                    if is_training:
                        answer = qa["answers"][0]
                        answer_text = answer["text"]
                        start_position_character = answer["answer_start"]
                    else:
                        answers = qa["answers"]

                example = example_class(
                    qas_id=qas_id,
                    question_text=question_text,
                    context_text=context_text,
                    answer_text=answer_text,
                    start_position_character=start_position_character,
                    title=title,
                    is_impossible=is_impossible,
                    answers=answers,
                )
                examples.append(example)
    return examples
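The file read here follows the standard SQuAD layout; a minimal invented record with just the fields the reader touches:

sample_squad_document = {
    "data": [{
        "title": "Example article",
        "paragraphs": [{
            "context": "The Eiffel Tower was completed in 1889.",
            "qas": [{
                "id": "q1",
                "question": "When was the Eiffel Tower completed?",
                "is_impossible": False,
                "answers": [{"text": "1889", "answer_start": 34}],
            }],
        }],
    }],
}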
Example #10
def tokenize_and_featurize(examples: list,
                           tokenizer,
                           feat_spec: FeaturizationSpec,
                           phase,
                           verbose=False):
    """Create list of DataRows containing tokenized and featurized examples.

    Args:
        examples (list[Example]): list of task Examples.
        tokenizer: TODO  (Issue #44)
        feat_spec (FeaturizationSpec): Tokenization-related metadata.
        phase (str): string identifying the data subset (e.g., train, val or test).
        verbose: If True, display progress bar.

    Returns:
        List of DataRows containing tokenized and featurized examples.

    """
    # TODO: Better solution  (Issue #48)
    if isinstance(examples[0], squad_style.Example):
        data_rows = []
        for example in maybe_tqdm(examples, desc="Tokenizing",
                                  verbose=verbose):
            # TODO: Expose parameters  (Issue #49)
            data_rows += example.to_feature_list(
                tokenizer=tokenizer,
                feat_spec=feat_spec,
                max_seq_length=feat_spec.max_seq_length,
                doc_stride=128,
                max_query_length=64,
                set_type=phase,
            )
    else:
        data_rows = [
            example.tokenize(tokenizer).featurize(tokenizer, feat_spec) for
            example in maybe_tqdm(examples, desc="Tokenizing", verbose=verbose)
        ]
    return data_rows
Example #11
    def resume_train_context(self, train_state, verbose=True):
        train_dataloader_dict = self.get_train_dataloader_dict()
        start_position = train_state.global_steps
        for _ in maybe_tqdm(
            range(start_position, self.jiant_task_container.global_train_config.max_steps),
            desc="Training",
            initial=start_position,
            total=self.jiant_task_container.global_train_config.max_steps,
            verbose=verbose,
        ):
            self.run_train_step(
                train_dataloader_dict=train_dataloader_dict, train_state=train_state
            )
            yield train_state
Example #12
def get_preds_for_single_tagging_task(task,
                                      test_dataloader,
                                      runner: jiant_runner.JiantRunner,
                                      verbose: bool = True):
    """Generate predictions for a single tagging task"""
    jiant_model, device = runner.model, runner.device
    jiant_model.eval()
    test_examples = task.get_test_examples()
    preds_list = []
    example_i = 0
    for step, (batch, batch_metadata) in enumerate(
            maybe_tqdm(test_dataloader,
                       desc=f"Eval ({task.name}, Test)",
                       verbose=verbose)):
        batch = batch.to(device)

        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=False,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        label_mask_arr = batch.label_mask.cpu().bool().numpy()
        preds_arr = np.argmax(batch_logits, axis=-1)
        for i in range(len(batch)):
            # noinspection PyUnresolvedReferences
            labels = [
                task.ID_TO_LABEL[class_i]
                for class_i in preds_arr[i][label_mask_arr[i]]
            ]
            if len(labels) == len(test_examples[example_i].tokens):
                this_preds = list(zip(test_examples[example_i].tokens, labels))
            elif len(labels) < len(test_examples[example_i].tokens):
                # Fewer predicted labels than tokens (e.g. due to truncation):
                # pad the remaining tokens with the last label in task.LABELS,
                # keeping the (token, label) order used above.
                this_preds = list(zip(test_examples[example_i].tokens, labels))
                this_preds += [
                    (token, task.LABELS[-1])
                    for token in test_examples[example_i].tokens[len(labels):]
                ]
            else:
                raise RuntimeError

            preds_list.append(this_preds)
            example_i += 1
    return preds_list
Example #13
def smart_truncate(dataset: torch_utils.ListDataset,
                   max_seq_length: int,
                   verbose: bool = False):
    """Truncate data to the length of the longest example in the dataset.

    Args:
        dataset (torch_utils.ListDataset): ListDataset to truncate if possible.
        max_seq_length (int): The maximum total input sequence length.
        verbose (bool): If True, display progress bar tracking truncation progress.

    Returns:
        Tuple[torch_utils.ListDataset, int]: truncated dataset, and length of the longest sequence.

    """
    if "input_mask" not in dataset.data[0]["data_row"].get_fields():
        raise RuntimeError("Smart truncate not supported")
    valid_length_ls = []
    range_idx = np.arange(max_seq_length)
    for datum in dataset.data:
        # TODO: document why reshape and max happen here (for cola this isn't necessary).
        #       (issue #1185)
        indexer = datum["data_row"].input_mask.reshape(-1,
                                                       max_seq_length).max(-2)
        valid_length_ls.append(range_idx[indexer.astype(bool)].max() + 1)
    max_valid_length = max(valid_length_ls)

    if max_valid_length == max_seq_length:
        return dataset, max_seq_length

    new_datum_ls = []
    for step, datum in enumerate(
            maybe_tqdm(dataset.data,
                       desc="Smart truncate data",
                       verbose=verbose)):
        regular_log(logger, step)
        new_datum_ls.append(
            smart_truncate_datum(
                datum=datum,
                max_seq_length=max_seq_length,
                max_valid_length=max_valid_length,
            ))
    new_dataset = torch_utils.ListDataset(new_datum_ls)
    return new_dataset, max_valid_length
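smart_truncate delegates the per-example slicing to smart_truncate_datum, which is not shown. A sketch of what it presumably does, assuming the DataRow exposes a to_dict() serialization and can be rebuilt from keyword arguments (both assumptions):

import numpy as np


def smart_truncate_datum(datum, max_seq_length, max_valid_length):
    # Sketch: cut every per-token array field from max_seq_length down to
    # max_valid_length, leaving all other fields untouched.
    row_dict = datum["data_row"].to_dict()  # assumed DataRow helper
    new_row_dict = {}
    for key, value in row_dict.items():
        if isinstance(value, np.ndarray) and value.shape[-1] == max_seq_length:
            new_row_dict[key] = value[..., :max_valid_length]
        else:
            new_row_dict[key] = value
    return {
        "data_row": datum["data_row"].__class__(**new_row_dict),
        "metadata": datum.get("metadata"),
    }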
Example #14
    def read_squad_examples(cls, path, set_type):
        with open(path, "r", encoding="utf-8") as reader:
            input_data = json.load(reader)["data"]

        is_training = set_type == PHASE.TRAIN
        examples = []
        for entry in maybe_tqdm(input_data, desc="Reading SQuAD Entries"):
            title = entry["title"]
            for paragraph in entry["paragraphs"]:
                context_text = paragraph["context"]
                for qa in paragraph["qas"]:
                    qas_id = qa["id"]
                    question_text = qa["question"]
                    start_position_character = None
                    answer_text = None
                    answers = []

                    if "is_impossible" in qa:
                        is_impossible = qa["is_impossible"]
                    else:
                        is_impossible = False

                    if not is_impossible:
                        if is_training:
                            answer = qa["answers"][0]
                            answer_text = answer["text"]
                            start_position_character = answer["answer_start"]
                        else:
                            answers = qa["answers"]

                    example = cls.Example(
                        qas_id=qas_id,
                        question_text=question_text,
                        context_text=context_text,
                        answer_text=answer_text,
                        start_position_character=start_position_character,
                        title=title,
                        is_impossible=is_impossible,
                        answers=answers,
                    )
                    examples.append(example)
        return examples
Example #15
def smart_truncate_cache(
    cache: shared_caching.ChunkedFilesDataCache,
    max_seq_length: int,
    max_valid_length: int,
    verbose: bool = False,
):
    for chunk_i in maybe_trange(cache.num_chunks,
                                desc="Smart truncate chunks",
                                verbose=verbose):
        chunk = torch.load(cache.get_chunk_path(chunk_i))
        new_chunk = []
        for datum in maybe_tqdm(chunk,
                                desc="Smart truncate chunk-datum",
                                verbose=verbose):
            new_chunk.append(
                smart_truncate_datum(
                    datum=datum,
                    max_seq_length=max_seq_length,
                    max_valid_length=max_valid_length,
                ))
        torch.save(new_chunk, cache.get_chunk_path(chunk_i))
Example #16
def compute_predictions_logits_v2(
    partial_examples: List[PartialExample],
    all_results,
    n_best_size,
    max_answer_length,
    do_lower_case,
    version_2_with_negative,
    null_score_diff_threshold,
    tokenizer,
    skip_get_final_text=False,
    verbose=True,
):
    """Write final predictions to the json file and log-odds of null if needed."""
    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"],
    )

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for example in maybe_tqdm(partial_examples, verbose=verbose):
        features = example.partial_features

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of irrelevant
            if version_2_with_negative:
                feature_null_score = result.start_logits[0] + result.end_logits[0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = result.start_logits[0]
                    null_end_logit = result.end_logits[0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index],
                        )
                    )
        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
                    feature_index=min_null_feature_index,
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit,
                )
            )
        prelim_predictions = sorted(
            prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True
        )

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"]
        )

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

                # tok_text = " ".join(tok_tokens)
                #
                # # De-tokenize WordPieces that have been split off.
                # tok_text = tok_text.replace(" ##", "")
                # tok_text = tok_text.replace("##", "")

                # Clean whitespace
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)

                if not skip_get_final_text:
                    final_text = get_final_text(tok_text, orig_text, do_lower_case)
                else:
                    final_text = tok_text
                if final_text in seen_predictions:
                    continue

                seen_predictions[final_text] = True
            else:
                final_text = ""
                seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit
                )
            )
        # if we didn't include the empty option in the n-best, include it
        if version_2_with_negative:
            if "" not in seen_predictions:
                nbest.append(
                    _NbestPrediction(
                        text="", start_logit=null_start_logit, end_logit=null_end_logit
                    )
                )

            # In very rare edge cases we could only have single null prediction.
            # So we just create a nonce prediction in this case to avoid failure.
            if len(nbest) == 1:
                nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if not best_non_null_entry:
                if entry.text:
                    best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)

        assert len(nbest_json) >= 1

        if not version_2_with_negative:
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        else:
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = (
                score_null - best_non_null_entry.start_logit - best_non_null_entry.end_logit
            )
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text
        all_nbest_json[example.qas_id] = nbest_json

    return all_predictions
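compute_predictions_logits_v2 relies on two small helpers that are not shown here. Sketches along the lines of the standard SQuAD post-processing utilities:

import math


def _get_best_indexes(logits, n_best_size):
    # Indexes of the n_best_size largest logits, best first.
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
    return [index for index, _ in index_and_score[:n_best_size]]


def _compute_softmax(scores):
    # Softmax over a plain list of raw scores; subtract the max score
    # for numerical stability and guard against empty input.
    if not scores:
        return []
    max_score = max(scores)
    exp_scores = [math.exp(score - max_score) for score in scores]
    total = sum(exp_scores)
    return [score / total for score in exp_scores]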
Example #17
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    tf_writer: SummaryWriter,
    global_step: Optional[int] = None,
    phase=None,
    return_preds=False,
    return_logits=True,
    return_encoder_output: bool = False,
    verbose=True,
    split='valid',
):
    # Reminder:
    #   val_dataloader contains mostly PyTorch-relevant info
    #   val_labels might contain more detailed information needed for full evaluation
    has_labels = True  # TODO: Automatically determine whether the dataset has labels.

    if not local_rank == -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()
    total_eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0

    encoder_outputs = []
    for step, (batch, batch_metadata) in enumerate(
            maybe_tqdm(val_dataloader,
                       desc=f"Eval ({task.name}, {str(phase)})",
                       verbose=verbose)):
        regular_log(logger, step, interval=10, tag=split)

        batch = batch.to(device)

        with torch.no_grad():
            model_outputs = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=has_labels,
                get_encoder_output=return_encoder_output,
            )
            if return_encoder_output:
                model_output, encoder_output = model_outputs
                encoder_outputs.append(encoder_output)
            else:
                model_output = model_outputs
        batch_logits = model_output.logits.detach().cpu().numpy()
        if has_labels:
            batch_loss = model_output.loss.mean().item()
        else:
            batch_loss = 0
        total_eval_loss += batch_loss
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )

        nb_eval_examples += len(batch)
        nb_eval_steps += 1

    eval_loss = total_eval_loss / nb_eval_steps
    output = {
        "accumulator": eval_accumulator,
    }

    if has_labels:
        tokenizer = (jiant_model.tokenizer
                     if not torch_utils.is_data_parallel(jiant_model) else
                     jiant_model.module.tokenizer)
        metrics = evaluation_scheme.compute_metrics_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
            labels=val_labels,
            tokenizer=tokenizer,
        )

        output.update({
            "loss": eval_loss,
            "metrics": metrics,
        })

        if global_step is not None:
            for metric_name, metric_value in metrics.minor.items():
                tf_writer.add_scalar(f'{split}/{metric_name}',
                                     metric_value,
                                     global_step=global_step)

    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
        if isinstance(eval_accumulator,
                      evaluate.ConcatenateLogitsAccumulator) and return_logits:
            output["logits"] = eval_accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [encoder_output.pooled for encoder_output in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [encoder_output.unpooled for encoder_output in encoder_outputs])
    if global_step is not None:
        tf_writer.add_scalar(f'{split}/loss',
                             eval_loss,
                             global_step=global_step)

    tf_writer.flush()
    return output
Example #18
def run_test(
    test_dataloader,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    verbose=True,
    return_preds=True,
    return_logits=True,
    return_encoder_output: bool = False,
):
    if not local_rank == -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()

    encoder_outputs = []
    for step, (batch, batch_metadata) in enumerate(
            maybe_tqdm(test_dataloader,
                       desc=f"Eval ({task.name}, Test)",
                       verbose=verbose)):
        regular_log(logger, step, interval=10, tag='test')

        batch = batch.to(device)

        with torch.no_grad():
            model_outputs = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=False,
                get_encoder_output=return_encoder_output,
            )
            if return_encoder_output:
                model_output, encoder_output = model_outputs
                encoder_outputs.append(encoder_output)
            else:
                model_output = model_outputs
        batch_logits = model_output.logits.detach().cpu().numpy()
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=0,
            batch=batch,
            batch_metadata=batch_metadata,
        )
    output = {
        "accumulator": eval_accumulator,
    }
    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
        if isinstance(eval_accumulator,
                      evaluate.ConcatenateLogitsAccumulator) and return_logits:
            output["logits"] = eval_accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [encoder_output.pooled for encoder_output in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [encoder_output.unpooled for encoder_output in encoder_outputs])
    return output