Example #1
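This snippet, apparently from a jack-the-reader XQA input module, defines a nested batch_generator that shuffles the remaining training examples, fills preallocated embedding buffers, and yields padded, numpified batches keyed by TensorPorts.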
        def batch_generator():
            # Work queue of example indices; during training, excluded
            # indices (not_allowed) are skipped and the order is shuffled.
            todo = [i for i in range(len(q_ids)) if is_eval or i not in not_allowed]
            if not is_eval:
                self._rng.shuffle(todo)
            while todo:
                support_lengths = []
                question_lengths = []
                wiq = []
                spans = []
                span2question = []
                offsets = []
                at_spans = []

                # Map each question/support token position to an index into this
                # batch's list of unique words (with per-word character ids).
                unique_words, unique_word_lengths, question2unique, support2unique = \
                    unique_words_with_chars(q_tokenized, s_tokenized, self.char_vocab, todo[:self.batch_size])

                # we have to create batches here and cannot precompute them because of the batch-specific wiq feature
                for i, j in enumerate(todo[:self.batch_size]):
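                    # Copy this example's word embeddings into row i of the
                    # preallocated emb_supports / emb_questions buffers.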
                    support = s_ids[j]
                    for k in range(len(support)):
                        emb_supports[i, k] = self._get_emb(support[k])
                    question = q_ids[j]
                    for k in range(len(question)):
                        emb_questions[i, k] = self._get_emb(question[k])
                    support_lengths.append(s_lengths[j])
                    question_lengths.append(q_lengths[j])
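                    # During training, keep only answer spans that fit within
                    # _max_span_size tokens; at eval time keep all spans.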
                    aps = [s for s in answer_spans[j] if s[1] - s[0] <= _max_span_size or is_eval]
                    spans.extend(aps)
                    span2question.extend(i for _ in aps)
                    wiq.append(word_in_question[j])
                    offsets.append(token_offsets[j])
                    at_spans.append(answertype_spans[j])

                # The final batch may hold fewer than self.batch_size examples.
                batch_size = len(question_lengths)
                output = {
                    XQAPorts.unique_word_chars: unique_words,
                    XQAPorts.unique_word_char_length: unique_word_lengths,
                    XQAPorts.question_words2unique: question2unique,
                    XQAPorts.support_words2unique: support2unique,
                    XQAPorts.emb_support: emb_supports[:batch_size, :max(support_lengths), :],
                    XQAPorts.support_length: support_lengths,
                    XQAPorts.emb_question: emb_questions[:batch_size, :max(question_lengths), :],
                    XQAPorts.question_length: question_lengths,
                    XQAPorts.word_in_question: wiq,
                    XQAPorts.answer_span: spans,
                    XQAPorts.correct_start_training: [] if is_eval else [s[0] for s in spans],
                    XQAPorts.answer2question: span2question,
                    XQAPorts.answer2question_training: [] if is_eval else span2question,
                    XQAPorts.keep_prob: 1.0 if is_eval else 1 - self.dropout,
                    XQAPorts.is_eval: is_eval,
                    XQAPorts.token_char_offsets: offsets,
                    CBOWXqaPorts.answer_type_span: at_spans
                }

                # we can only numpify here: bucketing makes the padded arrays
                # batch-specific, so they cannot be precomputed
                batch = numpify(output, keys=[XQAPorts.unique_word_chars,
                                              XQAPorts.question_words2unique, XQAPorts.support_words2unique,
                                              XQAPorts.word_in_question, XQAPorts.token_char_offsets])
                todo = todo[self.batch_size:]
                yield batch
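
For context, each yielded batch is a dict mapping TensorPorts to batch tensors. A minimal consumption sketch follows; model.train_step is a hypothetical stand-in for whatever consumes the feed dict and is not part of the original code:

    # Hypothetical driver loop; model.train_step is assumed for illustration.
    for batch in batch_generator():
        model.train_step(batch)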
Example #2
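This __call__ implementation builds a single padded batch for a list of QASettings at prediction time (with_answers=False), mirroring the batching logic of the previous example and additionally emitting a slot_list feature.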
    def __call__(
            self,
            qa_settings: List[QASetting]) -> Mapping[TensorPort, np.ndarray]:
        q_tokenized, q_ids, q_lengths, s_tokenized, s_ids, s_lengths, \
        word_in_question, token_offsets, answer_spans, slot = prepare_data(qa_settings, self.vocab,
                                                                           self.config.get("lowercase", False),
                                                                           with_answers=False)

        unique_words, unique_word_lengths, question2unique, support2unique = \
            unique_words_with_chars(q_tokenized, s_tokenized, self.char_vocab)

        batch_size = len(qa_settings)
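        # Preallocate zero-padded embedding buffers sized to the longest
        # support/question in the batch.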
        emb_supports = np.zeros(
            [batch_size, max(s_lengths), self.emb_matrix.shape[1]])
        emb_questions = np.zeros(
            [batch_size, max(q_lengths), self.emb_matrix.shape[1]])

        # Look up the word embedding for every support and question token.
        for i, q in enumerate(q_ids):
            for k, v in enumerate(s_ids[i]):
                emb_supports[i, k] = self._get_emb(v)
            for k, v in enumerate(q):
                emb_questions[i, k] = self._get_emb(v)

        output = {
            XQAPorts.unique_word_chars: unique_words,
            XQAPorts.unique_word_char_length: unique_word_lengths,
            XQAPorts.question_words2unique: question2unique,
            XQAPorts.support_words2unique: support2unique,
            XQAPorts.emb_support: emb_supports,
            XQAPorts.support_length: s_lengths,
            XQAPorts.emb_question: emb_questions,
            XQAPorts.question_length: q_lengths,
            XQAPorts.slot_list: slot,
            XQAPorts.word_in_question: word_in_question,
            XQAPorts.token_char_offsets: token_offsets,
            Ports.Input.question: q_ids
        }

        output = numpify(output,
                         keys=[
                             XQAPorts.unique_word_chars,
                             XQAPorts.question_words2unique,
                             XQAPorts.support_words2unique,
                             XQAPorts.word_in_question,
                             XQAPorts.token_char_offsets, XQAPorts.slot_list,
                             Ports.Input.question
                         ])

        return output
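
A minimal usage sketch, assuming an instantiated input module named input_module; the import path and QASetting constructor arguments are assumptions based on jack-the-reader's public API, not taken from the snippet:

    # Hypothetical usage: build the tensor dictionary for one question.
    from jack.core import QASetting  # import path is an assumption

    settings = [QASetting(question="Who wrote Hamlet?",
                          support=["Hamlet is a play by William Shakespeare."])]
    tensors = input_module(settings)  # dict mapping TensorPorts to numpy arrays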