Example no. 1
def iter_chunk_tokenize_and_featurize(task,
                                      examples: list,
                                      tokenizer,
                                      feat_spec: FeaturizationSpec,
                                      phase,
                                      verbose=False):
    """Generator of DataRows containing tokenized and featurized examples.

    Args:
        task (Task): Task object
        examples (list[Example]): list of task Examples.
        tokenizer: TODO  (issue #1188)
        feat_spec (FeaturizationSpec): Tokenization-related metadata.
        phase (str): string identifying the data subset (e.g., train, val or test).
        verbose: If True, display progress bar.

    Yields:
        DataRow containing a tokenized and featurized example.

    """
    for step, example in enumerate(
            maybe_tqdm(examples, desc="Tokenizing", verbose=verbose)):
        regular_log(logger, step)
        # TODO: Better solution  (issue #1184)
        if task.TASK_TYPE == TaskTypes.SQUAD_STYLE_QA:
            yield from example.to_feature_list(
                tokenizer=tokenizer,
                max_seq_length=feat_spec.max_seq_length,
                doc_stride=task.doc_stride,
                max_query_length=task.max_query_length,
                set_type=phase,
            )
        else:
            yield example.tokenize(tokenizer).featurize(tokenizer, feat_spec)
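
Every example on this page leans on two helpers, maybe_tqdm and regular_log, that are imported from elsewhere. Below is a minimal sketch of plausible implementations, assuming maybe_tqdm mirrors jiant's display utility and regular_log is a simple heartbeat logger; the real versions may differ.

import logging

from tqdm import auto as tqdm_lib


def maybe_tqdm(iterable, desc=None, verbose=True):
    # Wrap the iterable in a tqdm progress bar only when verbose is True.
    return tqdm_lib.tqdm(iterable, desc=desc) if verbose else iterable


def maybe_trange(n, desc=None, verbose=True):
    # Range-based variant used in Example no. 8.
    return maybe_tqdm(range(n), desc=desc, verbose=verbose)


def regular_log(logger: logging.Logger, step: int, interval: int = 1000, tag: str = "progress"):
    # Emit a heartbeat log line every `interval` steps.
    if step % interval == 0:
        logger.info("[%s] step: %d", tag, step)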
Example no. 2
    def run_train_context(self, verbose=True):
        train_dataloader_dict = self.get_train_dataloader_dict()
        loss_weights_dict = {}
        train_state = TrainState.from_task_name_list(
            self.jiant_task_container.task_run_config.train_task_list)
        global_train_config = self.jiant_task_container.global_train_config

        losses = []
        for step in maybe_tqdm(
                range(global_train_config.max_steps),
                desc="Training",
                verbose=verbose,
        ):
            regular_log(logger, step, interval=10, tag='train')

            if step == global_train_config.weighted_sampling_start_step:
                train_dataloader_dict = self.get_train_dataloader_dict(
                    do_weighted_sampling=True)
            if step == global_train_config.weighted_loss_start_step:
                loss_weights_dict = self.get_loss_weights_dict()

            loss_per_step = self.run_train_step(
                train_dataloader_dict=train_dataloader_dict,
                train_state=train_state,
                loss_weights_dict=loss_weights_dict,
            )
            losses.append(loss_per_step)

            if step % 100 == 0:
                logger.info('[train] loss: %f', np.mean(losses))
                self.tf_writer.flush()

            yield train_state
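
Because run_train_context yields after every step, the caller can interleave checkpointing, validation, or early stopping without the training loop knowing about any of them. A toy, self-contained sketch of that consumption pattern (the names below are illustrative, not jiant API):

from dataclasses import dataclass


@dataclass
class ToyTrainState:
    global_steps: int = 0


def toy_train_context(max_steps):
    state = ToyTrainState()
    for _ in range(max_steps):
        state.global_steps += 1  # stand-in for run_train_step()
        yield state


for state in toy_train_context(max_steps=1000):
    if state.global_steps % 500 == 0:
        print(f"checkpoint at step {state.global_steps}")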
Example no. 3
def generic_read_squad_examples(path: str,
                                set_type: str,
                                example_class: type = dict,
                                read_title: bool = True):

    with open(path, "r", encoding="utf-8") as reader:
        input_data = json.load(reader)["data"]

    is_training = set_type == PHASE.TRAIN
    examples = []
    for step, entry in enumerate(
            maybe_tqdm(input_data, desc="Reading SQuAD Entries")):
        regular_log(logger, step)

        if read_title:
            title = entry["title"]
        else:
            title = "-"
        for paragraph in entry["paragraphs"]:
            context_text = paragraph["context"]
            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position_character = None
                answer_text = None
                answers = []

                if "is_impossible" in qa:
                    is_impossible = qa["is_impossible"]
                else:
                    is_impossible = False

                if not is_impossible:
                    if is_training:
                        answer = qa["answers"][0]
                        answer_text = answer["text"]
                        start_position_character = answer["answer_start"]
                    else:
                        answers = qa["answers"]

                example = example_class(
                    qas_id=qas_id,
                    question_text=question_text,
                    context_text=context_text,
                    answer_text=answer_text,
                    start_position_character=start_position_character,
                    title=title,
                    is_impossible=is_impossible,
                    answers=answers,
                )
                examples.append(example)
    return examples
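
For reference, this is the nesting the reader walks: data -> entry -> paragraphs -> qas, with character-offset answers. A minimal SQuAD-style record the loop above would accept (contents are purely illustrative):

squad_blob = {
    "data": [{
        "title": "Amazon",
        "paragraphs": [{
            "context": "The Amazon rainforest covers much of the Amazon basin.",
            "qas": [{
                "id": "q1",
                "question": "What does the Amazon rainforest cover?",
                "is_impossible": False,
                "answers": [{"text": "much of the Amazon basin", "answer_start": 29}],
            }],
        }],
    }]
}

for entry in squad_blob["data"]:
    for paragraph in entry["paragraphs"]:
        for qa in paragraph["qas"]:
            answer = qa["answers"][0]
            # The answer_start offset indexes into the paragraph context.
            assert paragraph["context"][answer["answer_start"]:].startswith(answer["text"])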
Example no. 4
    def load_combined(path):
        result = pd.read_csv(
            path,
            encoding="utf-8",
            dtype=dict(is_answer_absent=float),
            na_values=dict(question=[], story_text=[], validated_answers=[]),
            keep_default_na=False,
        )

        if "story_text" in result.keys():
            for step, row_ in enumerate(display.tqdm(
                result.itertuples(), total=len(result), desc="Adjusting story texts"
            )):
                regular_log(logger, step)

                story_text_ = row_.story_text.replace("\r\n", "\n")
                result.at[row_.Index, "story_text"] = story_text_

        return result
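
The na_values/keep_default_na combination matters here: it turns off pandas' default NA tokens for those columns, so empty questions or answer lists survive as empty strings instead of becoming NaN. A small standalone demonstration of the difference (the inline data is illustrative):

import io

import pandas as pd

csv_text = "question,score\n,1\nNA,2\n"
default = pd.read_csv(io.StringIO(csv_text))
strict = pd.read_csv(
    io.StringIO(csv_text),
    na_values=dict(question=[]),
    keep_default_na=False,
)
print(default["question"].isna().tolist())  # [True, True]
print(strict["question"].tolist())          # ['', 'NA']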
Example no. 5
def smart_truncate(dataset: torch_utils.ListDataset,
                   max_seq_length: int,
                   verbose: bool = False):
    """Truncate data to the length of the longest example in the dataset.

    Args:
        dataset (torch_utils.ListDataset): ListDataset to truncate if possible.
        max_seq_length (int): The maximum total input sequence length.
        verbose (bool): If True, display progress bar tracking truncation progress.

    Returns:
        Tuple[torch_utils.ListDataset, int]: truncated dataset, and length of the longest sequence.

    """
    if "input_mask" not in dataset.data[0]["data_row"].get_fields():
        raise RuntimeError("Smart truncate not supported")
    valid_length_ls = []
    range_idx = np.arange(max_seq_length)
    for datum in dataset.data:
        # TODO: document why reshape and max happen here (for cola this isn't necessary).
        #       (issue #1185)
        indexer = datum["data_row"].input_mask.reshape(-1,
                                                       max_seq_length).max(-2)
        valid_length_ls.append(range_idx[indexer.astype(bool)].max() + 1)
    max_valid_length = max(valid_length_ls)

    if max_valid_length == max_seq_length:
        return dataset, max_seq_length

    new_datum_ls = []
    for step, datum in enumerate(
            maybe_tqdm(dataset.data,
                       desc="Smart truncate data",
                       verbose=verbose)):
        regular_log(logger, step)
        new_datum_ls.append(
            smart_truncate_datum(
                datum=datum,
                max_seq_length=max_seq_length,
                max_valid_length=max_valid_length,
            ))
    new_dataset = torch_utils.ListDataset(new_datum_ls)
    return new_dataset, max_valid_length
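
smart_truncate_datum is assumed above. A plausible sketch, assuming the data_row object exposes a to_dict() of numpy fields and can be rebuilt from keyword arguments; the real implementation may special-case non-sequence fields:

import numpy as np


def smart_truncate_datum(datum, max_seq_length, max_valid_length):
    row = datum["data_row"]
    new_fields = {}
    for field_name, value in row.to_dict().items():
        if isinstance(value, np.ndarray) and value.shape[-1] == max_seq_length:
            # Drop the positions beyond the longest observed valid length.
            new_fields[field_name] = value[..., :max_valid_length]
        else:
            new_fields[field_name] = value
    return dict(datum, data_row=row.__class__(**new_fields))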
Example no. 6
def preprocess_all_glue_data(input_base_path,
                             output_base_path,
                             task_name_ls=None):
    if task_name_ls is None:
        task_name_ls = GLUE_CONVERSION.keys()
    os.makedirs(output_base_path, exist_ok=True)
    os.makedirs(os.path.join(output_base_path, "data"), exist_ok=True)
    os.makedirs(os.path.join(output_base_path, "configs"), exist_ok=True)
    for step, task_name in enumerate(tqdm.tqdm(task_name_ls)):
        regular_log(logger, step)

        task_data_path = os.path.join(output_base_path, "data", task_name)
        paths_dict = convert_glue_data(
            input_base_path=input_base_path,
            task_data_path=task_data_path,
            task_name=task_name,
        )
        config = {"task": task_name, "paths": paths_dict, "name": task_name}
        py_io.write_json(data=config,
                         path=os.path.join(output_base_path, "configs",
                                           f"{task_name}.json"))
Example no. 7
def get_preds_for_single_tagging_task(
    task, test_dataloader, runner: jiant_runner.JiantRunner, verbose: bool = True
):
    """Generate predictions for a single tagging task"""
    jiant_model, device = runner.model, runner.device
    jiant_model.eval()
    test_examples = task.get_test_examples()
    preds_list = []
    example_i = 0
    for step, (batch, batch_metadata) in enumerate(
        maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)
    ):
        regular_log(logger, step, interval=10)

        batch = batch.to(device)

        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model, batch=batch, task=task, compute_loss=False,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        label_mask_arr = batch.label_mask.cpu().bool().numpy()
        preds_arr = np.argmax(batch_logits, axis=-1)
        for i in range(len(batch)):
            # noinspection PyUnresolvedReferences
            labels = [task.ID_TO_LABEL[class_i] for class_i in preds_arr[i][label_mask_arr[i]]]
            if len(labels) == len(test_examples[example_i].tokens):
                this_preds = list(zip(test_examples[example_i].tokens, labels))
            elif len(labels) < len(test_examples[example_i].tokens):
                this_preds = list(zip(test_examples[example_i].tokens, labels))
                this_preds += [
                    (token, task.LABELS[-1])
                    for token in test_examples[example_i].tokens[len(labels) :]
                ]
            else:
                raise RuntimeError("Predicted more labels than the example has tokens")

            preds_list.append(this_preds)
            example_i += 1
    return preds_list
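
The length-mismatch branch pads the tail of the example with the last label in task.LABELS, keeping the (token, label) pair order used everywhere else. A toy check of that behavior (names are illustrative):

tokens = ["New", "York", "City", "."]
labels = ["B-LOC", "I-LOC"]  # model emitted fewer labels than tokens
pad_label = "O"              # stand-in for task.LABELS[-1]

preds = list(zip(tokens, labels))
preds += [(token, pad_label) for token in tokens[len(labels):]]
assert preds == [("New", "B-LOC"), ("York", "I-LOC"), ("City", "O"), (".", "O")]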
Example no. 8
def smart_truncate_cache(
    cache: shared_caching.ChunkedFilesDataCache,
    max_seq_length: int,
    max_valid_length: int,
    verbose: bool = False,
):
    for chunk_i in maybe_trange(cache.num_chunks,
                                desc="Smart truncate chunks",
                                verbose=verbose):
        chunk = torch.load(cache.get_chunk_path(chunk_i))
        new_chunk = []
        for step, datum in enumerate(
                maybe_tqdm(chunk,
                           desc="Smart truncate chunk-datum",
                           verbose=verbose)):
            regular_log(logger, step)
            new_chunk.append(
                smart_truncate_datum(
                    datum=datum,
                    max_seq_length=max_seq_length,
                    max_valid_length=max_valid_length,
                ))
        torch.save(new_chunk, cache.get_chunk_path(chunk_i))
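
The chunk loop is a plain load/transform/save-in-place pattern over torch-serialized files. A standalone illustration with a throwaway chunk (the path and the transform are illustrative):

import torch

chunk_path = "/tmp/chunk_0.p"
torch.save([{"x": torch.ones(4)}, {"x": torch.ones(4)}], chunk_path)

chunk = torch.load(chunk_path)
new_chunk = [{"x": datum["x"][:2]} for datum in chunk]  # stand-in transform
torch.save(new_chunk, chunk_path)  # overwrite the chunk in place
assert torch.load(chunk_path)[0]["x"].shape == (2,)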
Example no. 9
def download_newsqa_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    def get_consensus_answer(row_):
        answer_char_start, answer_char_end = None, None
        if row_.validated_answers:
            validated_answers_ = json.loads(row_.validated_answers)
            answer_, max_count = max(validated_answers_.items(), key=itemgetter(1))
            total_count = sum(validated_answers_.values())
            if max_count >= total_count / 2.0:
                if answer_ != "none" and answer_ != "bad_question":
                    answer_char_start, answer_char_end = map(int, answer_.split(":"))
                else:
                    # No valid answer.
                    pass
        else:
            # Check row_.answer_char_ranges for most common answer.
            # No validation was done so there must be an answer with consensus.
            answers = Counter()
            for user_answer in row_.answer_char_ranges.split("|"):
                for ans in user_answer.split(","):
                    answers[ans] += 1
            top_answer = answers.most_common(1)
            if top_answer:
                top_answer, _ = top_answer[0]
                if ":" in top_answer:
                    answer_char_start, answer_char_end = map(int, top_answer.split(":"))

        return answer_char_start, answer_char_end

    def load_combined(path):
        result = pd.read_csv(
            path,
            encoding="utf-8",
            dtype=dict(is_answer_absent=float),
            na_values=dict(question=[], story_text=[], validated_answers=[]),
            keep_default_na=False,
        )

        if "story_text" in result.keys():
            for step, row_ in enumerate(display.tqdm(
                result.itertuples(), total=len(result), desc="Adjusting story texts"
            )):
                regular_log(logger, step)

                story_text_ = row_.story_text.replace("\r\n", "\n")
                result.at[row_.Index, "story_text"] = story_text_

        return result

    def _map_answers(answers):
        result = []
        for a in answers.split("|"):
            user_answers = []
            result.append(dict(sourcerAnswers=user_answers))
            for r in a.split(","):
                if r == "None":
                    user_answers.append(dict(noAnswer=True))
                else:
                    start_, end_ = map(int, r.split(":"))
                    user_answers.append(dict(s=start_, e=end_))
        return result

    def strip_empty_strings(strings):
        while strings and strings[-1] == "":
            del strings[-1]
        return strings

    # Require: cnn_stories.tgz
    cnn_stories_path = os.path.join(task_data_path, "cnn_stories.tgz")
    assert os.path.exists(cnn_stories_path), (
        "Download CNN Stories from https://cs.nyu.edu/~kcho/DMQA/ and save to " + cnn_stories_path
    )
    # Require: newsqa-data-v1/newsqa-data-v1.csv
    dataset_path = os.path.join(task_data_path, "newsqa-data-v1", "newsqa-data-v1.csv")
    if os.path.exists(dataset_path):
        pass
    elif os.path.exists(os.path.join(task_data_path, "newsqa-data-v1.zip")):
        download_utils.unzip_file(
            zip_path=os.path.join(task_data_path, "newsqa-data-v1.zip"),
            extract_location=task_data_path,
            delete=False,
        )
    else:
        raise AssertionError(
            "Download https://www.microsoft.com/en-us/research/project/newsqa-dataset/#!download"
            " and save to " + os.path.join(task_data_path, "newsqa-data-v1.zip")
        )

    # Download auxiliary data
    os.makedirs(task_data_path, exist_ok=True)
    file_name_list = [
        "train_story_ids.csv",
        "dev_story_ids.csv",
        "test_story_ids.csv",
        "stories_requiring_extra_newline.csv",
        "stories_requiring_two_extra_newlines.csv",
        "stories_to_decode_specially.csv",
    ]
    for file_name in file_name_list:
        download_utils.download_file(
            f"https://raw.githubusercontent.com/Maluuba/newsqa/master/maluuba/newsqa/{file_name}",
            os.path.join(task_data_path, file_name),
        )

    dataset = load_combined(dataset_path)
    remaining_story_ids = set(dataset["story_id"])
    with open(
        os.path.join(task_data_path, "stories_requiring_extra_newline.csv"), "r", encoding="utf-8"
    ) as f:
        stories_requiring_extra_newline = set(f.read().split("\n"))

    with open(
        os.path.join(task_data_path, "stories_requiring_two_extra_newlines.csv"),
        "r",
        encoding="utf-8",
    ) as f:
        stories_requiring_two_extra_newlines = set(f.read().split("\n"))

    with open(
        os.path.join(task_data_path, "stories_to_decode_specially.csv"), "r", encoding="utf-8"
    ) as f:
        stories_to_decode_specially = set(f.read().split("\n"))

    # Start combining data files
    story_id_to_text = {}
    with tarfile.open(cnn_stories_path, mode="r:gz", encoding="utf-8") as t:
        highlight_indicator = "@highlight"

        copyright_line_pattern = re.compile(
            "^(Copyright|Entire contents of this article copyright, )"
        )
        with display.tqdm(total=len(remaining_story_ids), desc="Getting story texts") as pbar:
            for step, member in enumerate(t.getmembers()):
                regular_log(logger, step)

                story_id = member.name
                if story_id in remaining_story_ids:
                    remaining_story_ids.remove(story_id)
                    story_file = t.extractfile(member)

                    # Correct discrepancies in stories.
                    # Problems are caused by using several programming languages and libraries.
                    # When ingesting the stories, we started with Python 2.
                    # After dealing with unicode issues, we tried switching to Python 3.
                    # That caused inconsistency problems so we switched back to Python 2.
                    # Furthermore, when crowdsourcing, JavaScript and HTML templating perturbed
                    # the stories.
                    # So here we map the text to be compatible with the indices.
                    lines = map(lambda s_: s_.strip().decode("utf-8"), story_file.readlines())

                    story_file.close()
                    lines = list(lines)
                    highlights_start = lines.index(highlight_indicator)
                    story_lines = lines[:highlights_start]
                    story_lines = strip_empty_strings(story_lines)
                    while len(story_lines) > 1 and copyright_line_pattern.search(story_lines[-1]):
                        story_lines = strip_empty_strings(story_lines[:-2])
                    if story_id in stories_requiring_two_extra_newlines:
                        story_text = "\n\n\n".join(story_lines)
                    elif story_id in stories_requiring_extra_newline:
                        story_text = "\n\n".join(story_lines)
                    else:
                        story_text = "\n".join(story_lines)

                    story_text = story_text.replace("\xe2\x80\xa2", "\xe2\u20ac\xa2")
                    story_text = story_text.replace("\xe2\x82\xac", "\xe2\u201a\xac")
                    story_text = story_text.replace("\r", "\n")
                    if story_id in stories_to_decode_specially:
                        story_text = story_text.replace("\xe9", "\xc3\xa9")
                    story_id_to_text[story_id] = story_text

                    pbar.update()

                    if len(remaining_story_ids) == 0:
                        break

    for step, row in enumerate(display.tqdm(dataset.itertuples(), total=len(dataset), desc="Setting story texts")):
        regular_log(logger, step)

        # Set story_text since we cannot include it in the dataset.
        story_text = story_id_to_text[row.story_id]
        dataset.at[row.Index, "story_text"] = story_text

        # Handle endings that are too large.
        answer_char_ranges = row.answer_char_ranges.split("|")
        updated_answer_char_ranges = []
        ranges_updated = False
        for user_answer_char_ranges in answer_char_ranges:
            updated_user_answer_char_ranges = []
            for char_range in user_answer_char_ranges.split(","):
                if char_range != "None":
                    start, end = map(int, char_range.split(":"))
                    if end > len(story_text):
                        ranges_updated = True
                        end = len(story_text)
                    if start < end:
                        updated_user_answer_char_ranges.append("%d:%d" % (start, end))
                    else:
                        # It's unclear why, but sometimes the end is not after
                        # the start. We'll filter these out.
                        ranges_updated = True
                else:
                    updated_user_answer_char_ranges.append(char_range)
            if updated_user_answer_char_ranges:
                updated_user_answer_char_ranges = ",".join(updated_user_answer_char_ranges)
                updated_answer_char_ranges.append(updated_user_answer_char_ranges)
        if ranges_updated:
            updated_answer_char_ranges = "|".join(updated_answer_char_ranges)
            dataset.at[row.Index, "answer_char_ranges"] = updated_answer_char_ranges

        if row.validated_answers and not pd.isnull(row.validated_answers):
            updated_validated_answers = {}
            validated_answers = json.loads(row.validated_answers)
            for char_range, count in validated_answers.items():
                if ":" in char_range:
                    start, end = map(int, char_range.split(":"))
                    if end > len(story_text):
                        ranges_updated = True
                        end = len(story_text)
                    if start < end:
                        char_range = "{}:{}".format(start, end)
                        updated_validated_answers[char_range] = count
                    else:
                        # It's unclear why, but sometimes the end is not after
                        # the start. We'll filter these out.
                        ranges_updated = True
                else:
                    updated_validated_answers[char_range] = count
            if ranges_updated:
                updated_validated_answers = json.dumps(
                    updated_validated_answers, ensure_ascii=False, separators=(",", ":")
                )
                dataset.at[row.Index, "validated_answers"] = updated_validated_answers

    # Process Splits
    data = []
    cache = dict()

    train_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "train_story_ids.csv"))["story_id"].values
    )
    dev_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "dev_story_ids.csv"))["story_id"].values
    )
    test_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "test_story_ids.csv"))["story_id"].values
    )

    def _get_data_type(story_id_):
        if story_id_ in train_story_ids:
            return "train"
        elif story_id_ in dev_story_ids:
            return "dev"
        elif story_id_ in test_story_ids:
            return "test"
        else:
            raise ValueError("{} not found in any story ID set.".format(story_id_))

    for step, row in enumerate(display.tqdm(dataset.itertuples(), total=len(dataset), desc="Building json")):
        regular_log(logger, step)

        questions = cache.get(row.story_id)
        if questions is None:
            questions = []
            datum = dict(
                storyId=row.story_id,
                type=_get_data_type(row.story_id),
                text=row.story_text,
                questions=questions,
            )
            cache[row.story_id] = questions
            data.append(datum)
        q = dict(
            q=row.question,
            answers=_map_answers(row.answer_char_ranges),
            isAnswerAbsent=row.is_answer_absent,
        )
        if row.is_question_bad != "?":
            q["isQuestionBad"] = float(row.is_question_bad)
        if row.validated_answers and not pd.isnull(row.validated_answers):
            validated_answers = json.loads(row.validated_answers)
            q["validatedAnswers"] = []
            for answer, count in validated_answers.items():
                answer_item = dict(count=count)
                if answer == "none":
                    answer_item["noAnswer"] = True
                elif answer == "bad_question":
                    answer_item["badQuestion"] = True
                else:
                    s, e = map(int, answer.split(":"))
                    answer_item["s"] = s
                    answer_item["e"] = e
                q["validatedAnswers"].append(answer_item)
        consensus_start, consensus_end = get_consensus_answer(row)
        if consensus_start is None and consensus_end is None:
            if q.get("isQuestionBad", 0) >= 0.5:
                q["consensus"] = dict(badQuestion=True)
            else:
                q["consensus"] = dict(noAnswer=True)
        else:
            q["consensus"] = dict(s=consensus_start, e=consensus_end)
        questions.append(q)

    phase_dict = {
        "train": [],
        "val": [],
        "test": [],
    }
    phase_map = {"train": "train", "dev": "val", "test": "test"}
    for entry in data:
        phase = phase_map[entry["type"]]
        output_entry = {"text": entry["text"], "storyId": entry["storyId"], "qas": []}
        for qn in entry["questions"]:
            if "badQuestion" in qn["consensus"] or "noAnswer" in qn["consensus"]:
                continue
            output_entry["qas"].append({"question": qn["q"], "answer": qn["consensus"]})
        phase_dict[phase].append(output_entry)
    for phase, phase_data in phase_dict.items():
        py_io.write_jsonl(phase_data,
                          os.path.join(task_data_path, f"{phase}.jsonl"),
                          skip_if_exists=True)
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "val": os.path.join(task_data_path, "val.jsonl"),
                "test": os.path.join(task_data_path, "val.jsonl"),
            },
            "name": task_name,
        },
        path=task_config_path,
        skip_if_exists=True,
    )
    for file_name in file_name_list:
        os.remove(os.path.join(task_data_path, file_name))
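
A hypothetical invocation, assuming cnn_stories.tgz and newsqa-data-v1.zip were already placed in task_data_path as the asserts above require (the paths are illustrative):

download_newsqa_data_and_write_config(
    task_name="newsqa",
    task_data_path="/content/tasks/data/newsqa",
    task_config_path="/content/tasks/configs/newsqa_config.json",
)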
Example no. 10
def compute_predictions_logits_v2(
    partial_examples: List[PartialExample],
    all_results,
    n_best_size,
    max_answer_length,
    do_lower_case,
    version_2_with_negative,
    null_score_diff_threshold,
    tokenizer,
    skip_get_final_text=False,
    verbose=True,
):
    """Write final predictions to the json file and log-odds of null if needed."""
    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"],
    )

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for step, example in enumerate(maybe_tqdm(partial_examples, verbose=verbose)):
        regular_log(logger, step, interval=10)

        features = example.partial_features

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of irrelevant
            if version_2_with_negative:
                feature_null_score = result.start_logits[0] + result.end_logits[0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = result.start_logits[0]
                    null_end_logit = result.end_logits[0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index],
                        )
                    )
        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
                    feature_index=min_null_feature_index,
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit,
                )
            )
        prelim_predictions = sorted(
            prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True
        )

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"]
        )

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

                # tok_text = " ".join(tok_tokens)
                #
                # # De-tokenize WordPieces that have been split off.
                # tok_text = tok_text.replace(" ##", "")
                # tok_text = tok_text.replace("##", "")

                # Clean whitespace
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)

                if not skip_get_final_text:
                    final_text = get_final_text(tok_text, orig_text, do_lower_case)
                else:
                    final_text = tok_text
                if final_text in seen_predictions:
                    continue

                seen_predictions[final_text] = True
            else:
                final_text = ""
                seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit
                )
            )
        # if we didn't include the empty option in the n-best, include it
        if version_2_with_negative:
            if "" not in seen_predictions:
                nbest.append(
                    _NbestPrediction(
                        text="", start_logit=null_start_logit, end_logit=null_end_logit
                    )
                )

            # In very rare edge cases we could have only a single null prediction.
            # So we just create a nonce prediction in this case to avoid failure.
            if len(nbest) == 1:
                nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if not best_non_null_entry:
                if entry.text:
                    best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)

        assert len(nbest_json) >= 1

        if not version_2_with_negative:
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        else:
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = (
                score_null - best_non_null_entry.start_logit - best_non_null_entry.end_logit
            )
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text
        all_nbest_json[example.qas_id] = nbest_json

    return all_predictions
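
_get_best_indexes and _compute_softmax are assumed above. Sketches consistent with the standard BERT SQuAD postprocessing helpers they appear to come from; treat them as plausible stand-ins, not the exact source:

import math


def _get_best_indexes(logits, n_best_size):
    # Indexes of the n_best_size largest logits, best first.
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
    return [index for index, _ in index_and_score[:n_best_size]]


def _compute_softmax(scores):
    # Numerically stable softmax over a plain list of floats.
    if not scores:
        return []
    max_score = max(scores)
    exp_scores = [math.exp(score - max_score) for score in scores]
    total = sum(exp_scores)
    return [s / total for s in exp_scores]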
Example no. 11
def run_test(
    test_dataloader,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    verbose=True,
    return_preds=True,
    return_logits=True,
    return_encoder_output: bool = False,
):
    if local_rank != -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()

    encoder_outputs = []
    for step, (batch, batch_metadata) in enumerate(
            maybe_tqdm(test_dataloader,
                       desc=f"Eval ({task.name}, Test)",
                       verbose=verbose)):
        regular_log(logger, step, interval=10, tag='test')

        batch = batch.to(device)

        with torch.no_grad():
            model_outputs = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=False,
                get_encoder_output=return_encoder_output,
            )
            if return_encoder_output:
                model_output, encoder_output = model_outputs
                encoder_outputs.append(encoder_output)
            else:
                model_output = model_outputs
        batch_logits = model_output.logits.detach().cpu().numpy()
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=0,
            batch=batch,
            batch_metadata=batch_metadata,
        )
    output = {
        "accumulator": eval_accumulator,
    }
    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
        if isinstance(eval_accumulator,
                      evaluate.ConcatenateLogitsAccumulator) and return_logits:
            output["logits"] = eval_accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [encoder_output.pooled for encoder_output in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [encoder_output.unpooled for encoder_output in encoder_outputs])
    return output
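
A hypothetical call site; the dataloader, model, task, and device come from the usual jiant runner setup and are assumed here:

test_results = run_test(
    test_dataloader=test_dataloader,
    jiant_model=jiant_model,
    task=task,
    device=device,
    local_rank=-1,  # any other rank returns None before evaluating
    return_preds=True,
)
test_preds = test_results["preds"]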
Example no. 12
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    tf_writer: SummaryWriter,
    global_step: Optional[int] = None,
    phase=None,
    return_preds=False,
    return_logits=True,
    return_encoder_output: bool = False,
    verbose=True,
    split='valid',
):
    # Reminder:
    #   val_dataloader contains mostly PyTorch-relevant info
    #   val_labels may contain more detailed information needed for full evaluation
    has_labels = True  # TODO: automatically detect whether the dataset has labels.

    if local_rank != -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()
    total_eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0

    encoder_outputs = []
    for step, (batch, batch_metadata) in enumerate(
            maybe_tqdm(val_dataloader,
                       desc=f"Eval ({task.name}, {str(phase)})",
                       verbose=verbose)):
        regular_log(logger, step, interval=10, tag=split)

        batch = batch.to(device)

        with torch.no_grad():
            model_outputs = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=has_labels,
                get_encoder_output=return_encoder_output,
            )
            if return_encoder_output:
                model_output, encoder_output = model_outputs
                encoder_outputs.append(encoder_output)
            else:
                model_output = model_outputs
        batch_logits = model_output.logits.detach().cpu().numpy()
        if has_labels:
            batch_loss = model_output.loss.mean().item()
        else:
            batch_loss = 0
        total_eval_loss += batch_loss
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )

        nb_eval_examples += len(batch)
        nb_eval_steps += 1

    eval_loss = total_eval_loss / nb_eval_steps
    output = {
        "accumulator": eval_accumulator,
    }

    if has_labels:
        tokenizer = (jiant_model.tokenizer
                     if not torch_utils.is_data_parallel(jiant_model) else
                     jiant_model.module.tokenizer)
        metrics = evaluation_scheme.compute_metrics_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
            labels=val_labels,
            tokenizer=tokenizer,
        )

        output.update({
            "loss": eval_loss,
            "metrics": metrics,
        })

        if global_step is not None:
            for metric_name, metric_value in metrics.minor.items():
                tf_writer.add_scalar(f'{split}/{metric_name}',
                                     metric_value,
                                     global_step=global_step)

    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
        if isinstance(eval_accumulator,
                      evaluate.ConcatenateLogitsAccumulator) and return_logits:
            output["logits"] = eval_accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [encoder_output.pooled for encoder_output in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [encoder_output.unpooled for encoder_output in encoder_outputs])
    if global_step is not None:
        tf_writer.add_scalar(f'{split}/loss',
                             eval_loss,
                             global_step=global_step)

    tf_writer.flush()
    return output
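
The TensorBoard calls above reduce to add_scalar plus flush on a SummaryWriter. A standalone illustration with made-up values (the log directory is illustrative):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="/tmp/tb_example")
writer.add_scalar("valid/loss", 0.42, global_step=100)
writer.add_scalar("valid/accuracy", 0.87, global_step=100)
writer.flush()
writer.close()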