Example #1
    def prepro(self, context, question):
        context = context.replace("''", '" ').replace("``", '" ')
        context_tokens = word_tokenize(context)
        context_chars = [list(token) for token in context_tokens]
        spans = convert_idx(context, context_tokens)
        ques = question.replace("''", '" ').replace("``", '" ')
        ques_tokens = word_tokenize(ques)
        ques_chars = [list(token) for token in ques_tokens]

        # Index arrays for a single (batch-of-one) example; char_limit is
        # assumed to be defined in the enclosing scope (e.g. module-level config).
        context_idxs = np.zeros([1, len(context_tokens)], dtype=np.int32)
        context_char_idxs = np.zeros([1, len(context_tokens), char_limit],
                                     dtype=np.int32)
        ques_idxs = np.zeros([1, len(ques_tokens)], dtype=np.int32)
        ques_char_idxs = np.zeros([1, len(ques_tokens), char_limit],
                                  dtype=np.int32)

        def _get_word(word):
            # Try several casings of the word; fall back to 1 (the OOV index).
            for each in (word, word.lower(), word.capitalize(), word.upper()):
                if each in self.word2idx_dict:
                    return self.word2idx_dict[each]
            return 1

        def _get_char(char):
            if char in self.char2idx_dict:
                return self.char2idx_dict[char]
            return 1

        for i, token in enumerate(context_tokens):
            context_idxs[0, i] = _get_word(token)

        for i, token in enumerate(ques_tokens):
            ques_idxs[0, i] = _get_word(token)

        for i, token in enumerate(context_chars):
            for j, char in enumerate(token):
                if j == char_limit:
                    break
                context_char_idxs[0, i, j] = _get_char(char)

        for i, token in enumerate(ques_chars):
            for j, char in enumerate(token):
                if j == char_limit:
                    break
                ques_char_idxs[0, i, j] = _get_char(char)

        print('ques_idxs:', ques_idxs)
        print('ques_char_idxs:', ques_char_idxs)

        return spans, context_idxs, ques_idxs, context_char_idxs, ques_char_idxs
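
The `convert_idx` helper used above is not defined on this page. From the way `spans` is consumed later (mapping answer character offsets to token indices), it is expected to return the character span of every token inside the original context string. A minimal sketch under that assumption, not necessarily the original implementation:

def convert_idx(text, tokens):
    """Locate each token left-to-right in `text` and return its
    (char_start, char_end) span."""
    spans = []
    current = 0
    for token in tokens:
        current = text.find(token, current)
        if current < 0:
            raise ValueError("Token {!r} not found in text".format(token))
        spans.append((current, current + len(token)))
        current += len(token)
    return spans

A helper like this only works when every token still appears verbatim in the text, which is why the quote normalisation is applied to the context before tokenizing.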
Example #2

    def prepro(self, context, question):
        context = context.replace("''", '" ').replace("``", '" ')
        context = pre_proc(context)
        context_tokens, context_tags, context_ents, context_lemmas = word_tokenize(
            context)  # returns per-token (text, POS tag, NER entity, lemma)
        context_lower_tokens = [w.lower() for w in context_tokens]
        context_chars = [list(token) for token in context_tokens]
        spans = convert_idx(context, context_tokens)

        # Term frequency of each passage token (lower-cased), normalised by
        # passage length.
        counter_ = Counter(context_lower_tokens)
        tf_total = len(context_lower_tokens)
        context_tf = [
            float(counter_[w]) / float(tf_total) for w in context_lower_tokens
        ]

        ques = question.replace("''", '" ').replace("``", '" ')
        ques = pre_proc(ques)
        ques_tokens, ques_tags, ques_ents, ques_lemmas = word_tokenize(ques)
        ques_lower_tokens = [w.lower() for w in ques_tokens]
        ques_chars = [list(token) for token in ques_tokens]

        ques_lemma = {
            lemma if lemma != '-PRON-' else lower
            for lemma, lower in zip(ques_lemmas, ques_lower_tokens)
        }

        ques_tokens_set = set(ques_tokens)
        ques_lower_tokens_set = set(ques_lower_tokens)
        match_origin = [w in ques_tokens_set for w in context_tokens]
        match_lower = [
            w in ques_lower_tokens_set for w in context_lower_tokens
        ]
        match_lemma = [
            (c_lemma if c_lemma != '-PRON-' else c_lower) in ques_lemma
            for (c_lemma, c_lower) in zip(context_lemmas, context_lower_tokens)
        ]

        example = {
            "context_tokens": context_tokens,
            "context_chars": context_chars,
            "match_origin": match_origin,
            "match_lower": match_lower,
            "match_lemma": match_lemma,
            "context_pos": context_tags,
            "context_ner": context_ents,
            "context_tf": context_tf,
            "ques_tokens": ques_tokens,
            "ques_pos": ques_tags,
            "ques_ner": ques_ents,
            "ques_chars": ques_chars
        }

        # Fixed-size, zero-padded index/feature arrays (para_limit / ques_limit
        # positions); id 0 doubles as the padding value.
        context_idxs = np.zeros([self.para_limit], dtype=np.int32)
        context_elmo_tokens = example['context_tokens']

        match_origin = np.zeros([self.para_limit], dtype=np.int32)
        match_lower = np.zeros([self.para_limit], dtype=np.int32)
        match_lemma = np.zeros([self.para_limit], dtype=np.int32)
        context_tf = np.zeros([self.para_limit], dtype=np.float32)
        context_pos_idxs = np.zeros([self.para_limit], dtype=np.int32)
        context_ner_idxs = np.zeros([self.para_limit], dtype=np.int32)
        context_char_idxs = np.zeros([self.para_limit, self.char_limit],
                                     dtype=np.int32)

        ques_idxs = np.zeros([self.ques_limit], dtype=np.int32)
        ques_elmo_tokens = example['ques_tokens']
        ques_pos_idxs = np.zeros([self.ques_limit], dtype=np.int32)
        ques_ner_idxs = np.zeros([self.ques_limit], dtype=np.int32)
        ques_char_idxs = np.zeros([self.ques_limit, self.char_limit],
                                  dtype=np.int32)

        def _get_word(word):
            for each in (word, word.lower(), word.capitalize(), word.upper()):
                if each in self.word2idx_dict:
                    return self.word2idx_dict[each]
            return 1

        def _get_pos(pos):
            if pos in self.pos2idx_dict:
                return self.pos2idx_dict[pos]
            return 1

        def _get_ner(ner):
            if ner in self.ner2idx_dict:
                return self.ner2idx_dict[ner]
            return 1

        def _get_char(char):
            if char in self.char2idx_dict:
                return self.char2idx_dict[char]
            return 1

        for i, token in enumerate(example["context_tokens"]):
            context_idxs[i] = _get_word(token)
        for i, match in enumerate(example["match_origin"]):
            match_origin[i] = 1 if match else 0
        for i, match in enumerate(example["match_lower"]):
            match_lower[i] = 1 if match else 0
        for i, match in enumerate(example["match_lemma"]):
            match_lemma[i] = 1 if match else 0

        for i, tf in enumerate(example['context_tf']):
            context_tf[i] = tf

        for i, pos in enumerate(example['context_pos']):
            context_pos_idxs[i] = _get_pos(pos)
        for i, ner in enumerate(example['context_ner']):
            context_ner_idxs[i] = _get_ner(ner)

        for i, token in enumerate(example["ques_tokens"]):
            ques_idxs[i] = _get_word(token)

        for i, pos in enumerate(example['ques_pos']):
            ques_pos_idxs[i] = _get_pos(pos)
        for i, ner in enumerate(example['ques_ner']):
            ques_ner_idxs[i] = _get_ner(ner)

        for i, token in enumerate(example["context_chars"]):
            for j, char in enumerate(token):
                if j == self.char_limit:
                    break
                context_char_idxs[i, j] = _get_char(char)

        for i, token in enumerate(example["ques_chars"]):
            for j, char in enumerate(token):
                if j == self.char_limit:
                    break
                ques_char_idxs[i, j] = _get_char(char)

        # Add a leading batch dimension of size 1 and move all tensors to the
        # target device.
        passage_ids = torch.LongTensor([context_idxs.tolist()]).to(self.device)
        passage_char_ids = torch.LongTensor(
            [context_char_idxs.tolist()]).to(self.device)
        passage_pos_ids = torch.LongTensor(
            [context_pos_idxs.tolist()]).to(self.device)
        passage_ner_ids = torch.LongTensor(
            [context_ner_idxs.tolist()]).to(self.device)
        passage_match_origin = torch.FloatTensor(
            [match_origin.tolist()]).to(self.device)
        passage_match_lower = torch.FloatTensor(
            [match_lower.tolist()]).to(self.device)
        passage_match_lemma = torch.FloatTensor(
            [match_lemma.tolist()]).to(self.device)
        passage_tf = torch.FloatTensor([context_tf.tolist()]).to(self.device)

        ques_ids = torch.LongTensor([ques_idxs.tolist()]).to(self.device)
        ques_char_ids = torch.LongTensor(
            [ques_char_idxs.tolist()]).to(self.device)
        ques_pos_ids = torch.LongTensor(
            [ques_pos_idxs.tolist()]).to(self.device)
        ques_ner_ids = torch.LongTensor(
            [ques_ner_idxs.tolist()]).to(self.device)

        passage_elmo_ids = batch_to_ids([context_elmo_tokens]).to(self.device)
        question_elmo_ids = batch_to_ids([ques_elmo_tokens]).to(self.device)

        # Non-padding lengths of the passage and question (padding id is 0),
        # used below to trim every tensor to its true sequence length.
        p_lengths = passage_ids.ne(0).long().sum(1)
        q_lengths = ques_ids.ne(0).long().sum(1)

        passage_maxlen = int(torch.max(p_lengths, 0)[0])
        ques_maxlen = int(torch.max(q_lengths, 0)[0])

        passage_ids = passage_ids[:, :passage_maxlen]
        passage_char_ids = passage_char_ids[:, :passage_maxlen]
        passage_pos_ids = passage_pos_ids[:, :passage_maxlen]
        passage_ner_ids = passage_ner_ids[:, :passage_maxlen]
        passage_match_origin = passage_match_origin[:, :passage_maxlen]
        passage_match_lower = passage_match_lower[:, :passage_maxlen]
        passage_match_lemma = passage_match_lemma[:, :passage_maxlen]
        passage_tf = passage_tf[:, :passage_maxlen]
        ques_ids = ques_ids[:, :ques_maxlen]
        ques_char_ids = ques_char_ids[:, :ques_maxlen]
        ques_pos_ids = ques_pos_ids[:, :ques_maxlen]
        ques_ner_ids = ques_ner_ids[:, :ques_maxlen]

        p_mask = self.compute_mask(passage_ids)
        q_mask = self.compute_mask(ques_ids)

        return (passage_ids, passage_char_ids, passage_pos_ids,
                passage_ner_ids, passage_match_origin.unsqueeze(2).float(),
                passage_match_lower.unsqueeze(2).float(),
                passage_match_lemma.unsqueeze(2).float(),
                passage_tf.unsqueeze(2), p_mask, ques_ids, ques_char_ids,
                ques_pos_ids, ques_ner_ids, q_mask, passage_elmo_ids,
                question_elmo_ids), spans
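
Two helpers in this second variant are also external: `batch_to_ids` is AllenNLP's ELMo utility that turns a batch of token lists into character id tensors, and `self.compute_mask` is not shown. Given that padding ids are 0 (see the `ne(0)` length computation above), `compute_mask` is presumably a simple padding mask over the id tensor. A hedged sketch, assuming the model wants 1 for real tokens and 0 for padding (the opposite convention, marking padded positions, is equally plausible):

    def compute_mask(self, ids):
        # 1 where a real token sits, 0 at padded positions (padding id is 0)
        return ids.ne(0)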
Example #3

def process_file(config,
                 squad_data,
                 data_type,
                 word_counter=None,
                 char_counter=None,
                 bpe_counter=None,
                 pos_counter=None,
                 remove_unicode=True,
                 bpe_model=None,
                 pos_model=None,
                 is_test=False):
    print("Generating {} examples...".format(data_type))
    para_limit = config.test_para_limit if is_test else config.para_limit
    ques_limit = config.test_ques_limit if is_test else config.ques_limit
    examples = []
    eval_examples = {}
    total = 0

    source = squad_data
    for article in source["data"]:
        for para in tqdm(article["paragraphs"]):
            context_raw = para['context']
            context, r2p, p2r = preprocess_string(
                para["context"],
                unicode_mapping=True,
                remove_unicode=remove_unicode)
            context_tokens = word_tokenize(context)[:para_limit]
            context_chars = [list(token) for token in context_tokens]
            context_bpe = []
            context_pos = []
            if bpe_model is not None:
                context_bpe = [
                    bpe_model.segment(token).split(' ')
                    for token in context_tokens
                ]

            if pos_model is not None:
                context_pos = [
                    get_pos(token, pos_model) for token in context_tokens
                ]

            spans = convert_idx(context, context_tokens)
            if word_counter is not None:
                for token in context_tokens:
                    word_counter[token] += len(para["qas"])
                    if char_counter is not None:
                        for char in token:
                            char_counter[char] += len(para["qas"])

            if bpe_counter is not None:
                for token in context_bpe:
                    for bpe in token:
                        bpe_counter[bpe] += len(para["qas"])

            if pos_counter is not None:
                for pos in context_pos:
                    pos_counter[pos] += len(para["qas"])

            for qa in para["qas"]:
                total += 1
                ques = preprocess_string(qa["question"],
                                         remove_unicode=remove_unicode)
                ques_tokens = word_tokenize(ques)[:ques_limit]
                ques_chars = [list(token) for token in ques_tokens]
                ques_bpe = []
                ques_pos = []
                if bpe_model is not None:
                    ques_bpe = [
                        bpe_model.segment(token).split(' ')
                        for token in ques_tokens
                    ]

                if pos_model is not None:
                    ques_pos = [
                        get_pos(token, pos_model) for token in ques_tokens
                    ]

                if word_counter is not None:
                    for token in ques_tokens:
                        word_counter[token] += 1
                        if char_counter is not None:
                            for char in token:
                                char_counter[char] += 1

                if bpe_counter is not None:
                    for token in ques_bpe:
                        for bpe in token:
                            bpe_counter[bpe] += 1

                if pos_counter is not None:
                    for pos in ques_pos:
                        pos_counter[pos] += 1

                # Map each gold answer's character offsets onto start/end token
                # indices using the token spans computed above.
                y1s, y2s = [], []
                answer_texts = []
                for answer in qa["answers"]:
                    answer_text = preprocess_string(
                        answer["text"], remove_unicode=remove_unicode)
                    # convert answer start index to index in preprocessed context
                    answer_start = r2p[answer['answer_start']]
                    answer_end = answer_start + len(answer_text)
                    answer_texts.append(answer_text)
                    answer_span = []
                    for idx, span in enumerate(spans):
                        if not (answer_end <= span[0]
                                or answer_start >= span[1]):
                            answer_span.append(idx)
                    if len(answer_span) == 0:
                        # no answer span falls inside the (possibly truncated)
                        # context_tokens, e.g. because of para_limit
                        continue
                    y1, y2 = answer_span[0], answer_span[-1]
                    y1s.append(y1)
                    y2s.append(y2)
                if len(y1s) == 0 and len(qa["answers"]) != 0:
                    # every answer span lies beyond the truncated context,
                    # so skip such QAs
                    continue
                example = {
                    "context_tokens": context_tokens,
                    "context_chars": context_chars,
                    "context_bpe": context_bpe,
                    "context_pos": context_pos,
                    "ques_tokens": ques_tokens,
                    "ques_chars": ques_chars,
                    "ques_bpe": ques_bpe,
                    "ques_pos": ques_pos,
                    "y1s": y1s,
                    "y2s": y2s,
                    "id": total
                }
                examples.append(example)
                eval_examples[str(total)] = {
                    "context": context,
                    "spans": spans,
                    "answers": answer_texts,
                    "uuid": qa["id"],
                    "context_raw": context_raw,
                    "raw2prepro": r2p,
                    "prepro2raw": p2r,
                }
    random.shuffle(examples)
    print("{} questions in total".format(len(examples)))
    return examples, eval_examples
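
A hedged sketch of how `process_file` might be driven; the file path and the fields expected on `config` (`para_limit`, `ques_limit`, `test_para_limit`, `test_ques_limit`) are assumptions, and only the call signature comes from the code above:

import json
from collections import Counter

with open("train-v1.1.json") as f:  # path is an assumption
    squad_data = json.load(f)

word_counter, char_counter = Counter(), Counter()
examples, eval_examples = process_file(config,  # config carrying the *_limit fields (assumed)
                                        squad_data,
                                        "train",
                                        word_counter=word_counter,
                                        char_counter=char_counter)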