Example 1
# Imports needed to run this snippet on its own (module paths follow the
# facebookresearch/DPR layout):
import random as rnd

import torch
from torch import Tensor as T

from dpr.utils.data_utils import Tensorizer


def _select_span_with_token(
    text: str, tensorizer: Tensorizer, token_str: str = "[START_ENT]"
) -> T:
    token_id = tensorizer.get_token_id(token_str)
    query_tensor = tensorizer.text_to_tensor(text)

    if token_id not in query_tensor:
        query_tensor_full = tensorizer.text_to_tensor(text, apply_max_len=False)
        token_indexes = (query_tensor_full == token_id).nonzero()
        if token_indexes.size(0) > 0:
            start_pos = token_indexes[0, 0].item()
            # Randomize the left boundary a little so the token does not
            # always land at the same position in the window
            left_shift = int(tensorizer.max_length / 2)
            rnd_shift = int((rnd.random() - 0.5) * left_shift / 2)
            left_shift += rnd_shift

            # Guard against running off the left edge of the sequence
            query_tensor = query_tensor_full[max(0, start_pos - left_shift) :]
            cls_id = tensorizer.tokenizer.cls_token_id
            if query_tensor[0] != cls_id:
                query_tensor = torch.cat([torch.tensor([cls_id]), query_tensor], dim=0)

            from dpr.models.reader import _pad_to_len

            query_tensor = _pad_to_len(
                query_tensor, tensorizer.get_pad_id(), tensorizer.max_length
            )
            query_tensor[-1] = tensorizer.tokenizer.sep_token_id
            # logger.info('aligned query_tensor %s', query_tensor)

            assert token_id in query_tensor, "query_tensor={}".format(query_tensor)
            return query_tensor
        else:
            raise RuntimeError(
                "{} token not found for Entity Linking sample query={}".format(
                    token_str, text
                )
            )
    else:
        return query_tensor
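
For context, here is a minimal usage sketch. It assumes the facebookresearch/DPR code base: get_bert_tokenizer and BertTensorizer are taken from dpr.models.hf_models, and the [START_ENT] marker is registered as a special token so it maps to a single ID. Import paths and signatures may differ across DPR versions; treat this as an illustration, not the repo's own test.

# Hedged usage sketch (paths and signatures assumed from facebookresearch/DPR):
from dpr.models.hf_models import BertTensorizer, get_bert_tokenizer

tokenizer = get_bert_tokenizer("bert-base-uncased")
# Register the entity marker so it tokenizes to a single, known ID
tokenizer.add_special_tokens({"additional_special_tokens": ["[START_ENT]"]})
tensorizer = BertTensorizer(tokenizer, max_length=256)

text = "... a long context ... [START_ENT] Lisbon [END_ENT] is mentioned here ..."
query_tensor = _select_span_with_token(text, tensorizer)  # window around the marker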
Example 2
# Relies on module-level helpers (_get_positive_idx, _get_answer_spans) and the
# imports of the surrounding module (numpy as np, torch, typing.List, and the
# DPR types Tensorizer, TokenizedWikipediaPassages, ReaderPassage, T).
def _create_question_passages_tensors(
    wiki_data: TokenizedWikipediaPassages,
    question_token_ids: np.ndarray,
    tensorizer: Tensorizer,
    positives: List[ReaderPassage],
    negatives: List[ReaderPassage],
    total_size: int,
    empty_ids: T,
    max_n_answers: int,
    is_train: bool,
    is_random: bool = True
):
    max_len = empty_ids.size(0)
    pad_token_id = tensorizer.get_pad_id()

    if is_train:
        # select just one positive
        positive_idx = _get_positive_idx(positives, max_len, is_random)
        if positive_idx is None:
            return None

        positive = positives[positive_idx]

        if getattr(positive, "sequence_ids", None) is None:
            # Load in passage tokens and title tokens
            positive.load_tokens(
                question_token_ids=question_token_ids,
                **wiki_data.get_tokenized_data(int(positive.id))
            )
            sequence_ids, passage_offset = tensorizer.concatenate_inputs({
                "question": positive.question_token_ids,
                "passage_title": positive.title_token_ids,
                "passage": positive.passage_token_ids,
            }, get_passage_offset=True)

            positive.sequence_ids = sequence_ids
            positive.passage_offset = passage_offset
            positive.answers_spans = [
                (start + passage_offset, end + passage_offset)
                for start, end in positive.answers_spans
            ]

        positive_a_spans = _get_answer_spans(positive_idx, positives, max_len)[:max_n_answers]

        answer_starts = [span[0] for span in positive_a_spans]
        answer_ends = [span[1] for span in positive_a_spans]

        assert all(s < max_len for s in answer_starts)
        assert all(e < max_len for e in answer_ends)

        positive_input_ids = tensorizer.to_max_length(positive.sequence_ids.numpy(), apply_max_len=True)
        positive_input_ids = torch.from_numpy(positive_input_ids)

        answer_starts_tensor = torch.zeros((total_size, max_n_answers), dtype=torch.long)
        # Only the first passage (the selected positive) contains the answer
        answer_starts_tensor[0, : len(answer_starts)] = torch.tensor(answer_starts)

        answer_ends_tensor = torch.zeros((total_size, max_n_answers), dtype=torch.long)
        answer_ends_tensor[0, : len(answer_ends)] = torch.tensor(answer_ends)

        answer_mask = torch.zeros((total_size, max_n_answers), dtype=torch.long)
        answer_mask[0, : len(answer_starts)] = 1

        positives_IDs: List[int] = [positive.id]
        positives_selected = [positive_input_ids]

    else:
        positives_IDs: List[int] = []
        positives_selected = []
        answer_starts_tensor = None
        answer_ends_tensor = None
        answer_mask = None

    positives_num = len(positives_selected)
    if is_random:
        negative_idxs = np.random.permutation(len(negatives))
    else:
        negative_idxs = range(len(negatives) - positives_num)

    negative_idxs = negative_idxs[: total_size - positives_num]

    negatives_IDs: List[int] = []
    negatives_selected = []

    for negative_idx in negative_idxs:
        negative = negatives[negative_idx]

        if getattr(negative, "sequence_ids", None) is None:
            # Load in passage tokens and title tokens
            negative.load_tokens(
                question_token_ids=question_token_ids,
                **wiki_data.get_tokenized_data(int(negative.id))
            )
            # Concatenate input tokens
            sequence_ids, passage_offset = tensorizer.concatenate_inputs({
                "question": negative.question_token_ids,
                "passage_title": negative.title_token_ids,
                "passage": negative.passage_token_ids,
            }, get_passage_offset=True)
            negative.sequence_ids = sequence_ids
            negative.passage_offset = passage_offset

        negatives_IDs.append(negative.id)

        negative_input_ids = tensorizer.to_max_length(negative.sequence_ids.numpy(), apply_max_len=True)
        negatives_selected.append(torch.from_numpy(negative_input_ids))

    # Pad with empty (all-PAD) passages so each question contributes exactly
    # total_size contexts; -1 marks a placeholder passage ID
    while len(negatives_selected) < total_size - positives_num:
        negatives_IDs.append(-1)
        negatives_selected.append(empty_ids.clone())

    context_IDs = torch.tensor(positives_IDs + negatives_IDs, dtype=torch.int64)
    input_ids = torch.stack(positives_selected + negatives_selected, dim=0)
    assert len(context_IDs) == len(input_ids)

    return context_IDs, input_ids, answer_starts_tensor, answer_ends_tensor, answer_mask
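
The most subtle step above is the span-offset adjustment: answer spans arrive relative to the passage text, but the model scores positions in the concatenated [question | title | passage] sequence, so every span must be shifted by the passage's start offset in that sequence. A self-contained toy illustration (all values invented):

# Toy illustration of the answer-span shift; numbers are made up.
passage_offset = 12                # passage starts at position 12 of the sequence
answers_spans = [(3, 5), (8, 8)]   # (start, end) relative to the passage
shifted = [(start + passage_offset, end + passage_offset) for start, end in answers_spans]
print(shifted)                     # [(15, 17), (20, 20)]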
Example 3
# Relies on the surrounding module's logger and the DPR types ReaderSample and
# ReaderBatch.
def create_reader_input(
    wiki_data: TokenizedWikipediaPassages,
    tensorizer: Tensorizer,
    samples: List[ReaderSample],
    passages_per_question: int,
    max_length: int,
    max_n_answers: int,
    is_train: bool,
    shuffle: bool,
) -> ReaderBatch:
    """
    Creates a reader batch instance out of a list of ReaderSample-s. This is compatible with `GeneralDataset`.
    :param wiki_data: all tokenized wikipedia passages
    :param tensorizer: initialized tensorizer (which contains the tokenizer)
    :param samples: list of samples to create the batch for
    :param passages_per_question: amount of passages for every question in a batch
    :param max_length: max model input sequence length
    :param max_n_answers: max num of answers per single question
    :param is_train: if the samples are for a train set
    :param shuffle: should passages selection be randomized
    :return: ReaderBatch instance
    """
    context_IDs = []
    input_ids = []
    start_positions = []
    end_positions = []
    answers_masks = []

    empty_sequence = torch.full((max_length,), tensorizer.get_pad_id(), dtype=torch.long)

    for sample in samples:
        if is_train:
            positive_ctxs = sample.positive_passages
            negative_ctxs = sample.negative_passages
        else:
            positive_ctxs = []
            negative_ctxs = sample.positive_passages + sample.negative_passages
            # At evaluation time, rank all passages by their retrieval score
            negative_ctxs = sorted(negative_ctxs, key=lambda x: x.score, reverse=True)
        question_token_ids = sample.question_token_ids

        sample_tensors = _create_question_passages_tensors(
            wiki_data,
            question_token_ids,
            tensorizer,
            positive_ctxs,
            negative_ctxs,
            passages_per_question,
            empty_sequence,
            max_n_answers,
            is_train,
            is_random=shuffle
        )

        if sample_tensors is None:
            logger.warning("No valid passage combination for question=%s", sample.question)
            continue
        context_ID, sample_input_ids, starts_tensor, ends_tensor, answer_mask = sample_tensors

        context_IDs.append(context_ID)
        input_ids.append(sample_input_ids)

        if is_train:
            start_positions.append(starts_tensor)
            end_positions.append(ends_tensor)
            answers_masks.append(answer_mask)

    context_IDs = torch.cat([IDs.unsqueeze(0) for IDs in context_IDs], dim=0)  # (N, M)
    input_ids = torch.cat([ids.unsqueeze(0) for ids in input_ids], dim=0)  # (N, M, L)

    if is_train:
        start_positions = torch.stack(start_positions, dim=0)
        end_positions = torch.stack(end_positions, dim=0)
        answers_masks = torch.stack(answers_masks, dim=0)

    return ReaderBatch(context_IDs, input_ids, start_positions, end_positions, answers_masks)
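
The resulting ReaderBatch stacks one (M, L) input matrix per question, where M = passages_per_question and L = max_length, giving (N, M, L) for N questions; the answer tensors are (N, M, max_n_answers) when is_train is True. A quick, self-contained shape check (sizes are arbitrary):

import torch

# Assumed toy sizes: N questions, M passages per question, L tokens each
N, M, L = 2, 3, 6
per_sample = [torch.zeros((M, L), dtype=torch.long) for _ in range(N)]
input_ids = torch.cat([ids.unsqueeze(0) for ids in per_sample], dim=0)
print(input_ids.shape)  # torch.Size([2, 3, 6])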