Example 1

from typing import List

from transformers import BertTokenizerFast


def inference_answer(question: str, context: str, input_ids: List[int],
                     token_type_ids: List[int], start_pos: int, end_pos: int,
                     tokenizer: BertTokenizerFast) -> str:
    """ Inference fucntion for the answer.

    Because the tokenizer lower-cases capital letters and splits punctuation marks,
    you may get wrong answer words if you detokenize the predicted tokens directly.
    For example, if you encode "$5.00 USD" and decode it, you get a string different from the original:

    "$5.00 USD" --(Tokenize)--> ["$", "5", ".", "00", "usd"] --(Detokenize)--> "$ 5. 00 usd"

    Thus, you should find the original words in the context by the start and end token positions of the answer.
    Implement the function that infers the answer from the context and the answer token positions.

    Note 1: We have already implemented direct decoding, so you can skip this problem if you want.

    Note 2: When we implemented squad_feature, we arbitrarily split tokens if the answer is a subword,
            so it is very tricky to extract the original word by start_pos and end_pos.
            However, since None is passed as the answer when evaluating,
            you can assume the word tokens follow the general tokenizing rules in this problem.
            In fact, the most appropriate solution is to store the character indices when tokenizing.

    Hint: You can find a simple solution if you carefully search the documentation of the transformers library.
    Library Link: https://huggingface.co/transformers/index.html

    Arguments:
    question -- Question string
    context -- Context string

    input_ids -- Input ids
    token_type_ids -- Token type ids
    start_pos -- Predicted start token position of the answer
    end_pos -- Predicted end token position of the answer

    tokenizer -- Tokenizer to encode and decode the string

    Return:
    answer -- Answer string
    """
    ### YOUR CODE HERE (~4 lines)
    # Direct decoding of the predicted token span (the pre-implemented fallback
    # from Note 1); it loses the original casing and spacing.
    answer = input_ids[start_pos:end_pos + 1]
    answer: str = tokenizer.decode(answer)
    # Re-encode the (question, context) pair with the fast tokenizer to get each
    # token's character span (offset) in the original strings.
    encoded_context = tokenizer.encode_plus(question,
                                            context,
                                            return_offsets_mapping=True)
    # Take the offsets of the answer tokens and slice the original context from
    # the first token's start character to the last token's end character.
    answer_char_pos = encoded_context['offset_mapping'][start_pos:end_pos + 1]
    answer = context[answer_char_pos[0][0]:answer_char_pos[-1][1]]
    ### END YOUR CODE
    ### END YOUR CODE

    return answer
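
A minimal usage sketch for inference_answer, assuming the transformers library can load the pretrained "bert-base-uncased" fast tokenizer; the question/context pair is made up, and deriving start_pos/end_pos via char_to_token is only an illustrative stand-in for a model's predictions.

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
question = "How much does the ticket cost?"
context = "The ticket costs $5.00 USD at the door."

encoded = tokenizer.encode_plus(question, context)
input_ids = encoded["input_ids"]
token_type_ids = encoded["token_type_ids"]

# Derive the token span of "$5.00 USD" from its character span in the context
# (in practice start_pos and end_pos would come from the QA model's output).
answer_start_char = context.index("$5.00 USD")
answer_end_char = answer_start_char + len("$5.00 USD") - 1
start_pos = encoded.char_to_token(answer_start_char, sequence_index=1)
end_pos = encoded.char_to_token(answer_end_char, sequence_index=1)

# Direct decoding loses the original casing and spacing ...
print(tokenizer.decode(input_ids[start_pos:end_pos + 1]))
# ... while slicing the context by character offsets recovers "$5.00 USD".
print(inference_answer(question, context, input_ids, token_type_ids,
                       start_pos, end_pos, tokenizer))
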
Example 2

import json
import os
import pickle
from itertools import chain

import torch
from jamo import h2j, j2hcj          # Hangul <-> jamo conversion helpers
from konlpy.tag import Mecab         # Korean morphological analyzer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import BertTokenizerFast


class KorquadDataset(Dataset):
    def __init__(self, train=True):
        if train:
            path = "/data/KorQuAD_v1.0_train.json"
            db_name = "korquad_train.qas"
        else:
            path = "/data/KorQuAD_v1.0_dev.json"
            db_name = "korquad_dev.qas"
        self.tokenizer = BertTokenizerFast("wiki-vocab.txt")
        # Mecab is needed both when building the cache below and later in
        # extract_features / tokenize, so keep it on the instance.
        self.mecab = Mecab()

        data = json.load(open(path, encoding="utf-8"))["data"]

        self.qas = []
        if not os.path.exists(db_name):
            with open(db_name, "wb") as f:
                ignored_cnt = 0  # over-length training examples skipped below
                for paragraphs in tqdm(data):
                    paragraphs = paragraphs["paragraphs"]
                    for paragraph in paragraphs:
                        _context = paragraph["context"]
                        for qa in paragraph["qas"]:
                            question = qa["question"]
                            answer = qa["answers"][0]["text"]
                            (
                                input_ids,
                                token_type_ids,
                                start_token_pos,
                                end_token_pos,
                            ) = self.extract_features(
                                _context,
                                question,
                                answer,
                                qa["answers"][0]["answer_start"],
                            )
                            # Keep every dev example; for training, skip examples
                            # longer than BERT's 512-token limit.
                            if train and len(input_ids) > 512:
                                ignored_cnt += 1
                                continue
                            pickle.dump(
                                (
                                    input_ids,
                                    token_type_ids,
                                    start_token_pos,
                                    end_token_pos,
                                ),
                                f,
                            )

        # Read back every cached (input_ids, token_type_ids, start, end) tuple.
        with open(db_name, "rb") as f:
            while True:
                try:
                    data = pickle.load(f)
                    self.qas.append(data)
                except EOFError:
                    break
            print(len(self.qas))

    @property
    def token_num(self):
        return self.tokenizer.vocab_size

    def __len__(self):
        return len(self.qas)

    def encode(self, line):
        # Tokenize each word separately, strip the per-word [CLS]/[SEP] ids, then
        # wrap the flattened sequence with a single [CLS] (id 2) and [SEP] (id 3)
        # from wiki-vocab.txt.
        converted_results = map(
            lambda x: x[1:-1],
            self.tokenizer.batch_encode_plus(line)["input_ids"])
        return [2, *chain.from_iterable(converted_results), 3]

    def decode(self, token_ids):
        decode_str = self.tokenizer.decode(token_ids, skip_special_tokens=True)
        return decode_str

    def __getitem__(self, idx):
        return self.qas[idx]

    def extract_features(self, context, question, answer, start_char_pos):
        if answer is None:
            # Evaluation: no gold answer, so just concatenate the tokenized
            # question and context as [CLS] q [SEP] c [SEP] without labels.
            tokenized_q = self.tokenize(question)
            tokenized_c = self.tokenize(context)
            input_ids = [*tokenized_q, *tokenized_c[1:]]
            token_type_ids = [
                *[0 for _ in tokenized_q], *[1 for _ in tokenized_c[1:]]
            ]
            start_token_pos = None
            end_token_pos = None
        else:
            # Split sentences using len(answer) and start_char_pos
            context_front = context[:start_char_pos]
            context_back = context[start_char_pos + len(answer):]
            q_ids = self.tokenize(question)
            f_ids = self.tokenize(context_front)
            a_ids = self.tokenize(answer)
            b_ids = self.tokenize(context_back)

            # Subword handling: if the answer starts or ends in the middle of a
            # word, keep the pieces exactly as tokenized above (these branches
            # leave a_ids and b_ids unchanged).
            if context_front != "" and context_front[-1] != " ":
                a_ids = [a_ids[0], a_ids[1], *a_ids[2:]]
            if context_back != "" and context_back[0] != " ":
                b_ids = [b_ids[0], b_ids[1], *b_ids[2:]]

            # Manually build input_ids and token_type_ids, carefully dropping the
            # extra [CLS]/[SEP] ids of the individually tokenized pieces:
            # [CLS] question [SEP] | front | answer | back [SEP]
            input_ids = [*q_ids, *f_ids[1:-1], *a_ids[1:-1], *b_ids[1:]]
            token_type_ids = [
                *[0 for _ in q_ids],
                *[1 for _ in f_ids[1:-1]],
                *[1 for _ in a_ids[1:-1]],
                *[1 for _ in b_ids[1:]],
            ]
            # The answer starts right after the question (len(q_ids) ids) and the
            # stripped front context (len(f_ids) - 2 ids).
            start_token_pos = len(q_ids) + (len(f_ids) - 2)
            end_token_pos = start_token_pos + (len(a_ids) - 2) - 1

        return input_ids, token_type_ids, start_token_pos, end_token_pos
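
    # Worked example for the position arithmetic above (hypothetical lengths):
    # if q_ids has 8 ids ([CLS] + 6 question tokens + [SEP]), f_ids has 12 ids
    # and a_ids has 5 ids, then f_ids[1:-1] contributes 10 tokens, so the first
    # answer token lands at index 8 + (12 - 2) = 18 of input_ids and the last
    # at 18 + (5 - 2) - 1 = 20.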

    def tokenize(self, sentence):
        # An empty string still gets the [CLS]/[SEP] wrapper (ids 2 and 3).
        if len(sentence) == 0:
            return [2, 3]
        # Split into morphemes with Mecab, then decompose Hangul syllables into
        # compatibility jamo before running the WordPiece tokenizer.
        return self.encode(
            [j2hcj(h2j(word)) for word in self.mecab.morphs(sentence)])

    def collate_fn(self, samples):
        input_ids, token_type_ids, start_pos, end_pos = zip(*samples)
        # Real tokens get attention 1; the positions padded below stay 0.
        attention_mask = [[1] * len(input_id) for input_id in input_ids]

        input_ids = pad_sequence(
            [torch.Tensor(input_id).to(torch.long) for input_id in input_ids],
            padding_value=0,
            batch_first=True,
        )
        token_type_ids = pad_sequence(
            [
                torch.Tensor(token_type_id).to(torch.long)
                for token_type_id in token_type_ids
            ],
            padding_value=1,  # padded positions are marked as context (segment 1)
            batch_first=True,
        )
        attention_mask = pad_sequence(
            [torch.Tensor(mask).to(torch.long) for mask in attention_mask],
            padding_value=0,
            batch_first=True,
        )

        start_pos = torch.Tensor(start_pos).to(torch.long)
        end_pos = torch.Tensor(end_pos).to(torch.long)

        return input_ids, attention_mask, token_type_ids, start_pos, end_pos
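
A minimal usage sketch, assuming the /data/KorQuAD_v1.0_*.json files, wiki-vocab.txt, and the Mecab/jamo dependencies used above are available; the batch size is arbitrary.

from torch.utils.data import DataLoader

train_set = KorquadDataset(train=True)
train_loader = DataLoader(
    train_set,
    batch_size=16,            # arbitrary, for illustration
    shuffle=True,
    collate_fn=train_set.collate_fn,
)

for input_ids, attention_mask, token_type_ids, start_pos, end_pos in train_loader:
    # Each tensor is (batch_size, max_seq_len_in_batch); feed these into a
    # BERT-style span-prediction QA model.
    print(input_ids.shape, start_pos.shape)
    break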