Example #1
    def forward(self,
                unique_word_chars,
                unique_word_lengths,
                sequences_as_uniqs=None):
        # use CUDA tensors whenever a GPU is visible (legacy device selection)
        long_tensor = torch.cuda.LongTensor if torch.cuda.device_count() > 0 else torch.LongTensor
        embedded_chars = self._embeddings(unique_word_chars.type(long_tensor))
        # [N, S, L]
        conv_out = self._conv(embedded_chars.transpose(1, 2))
        # [N, L]
        # mask positions beyond each word's length so padding cannot win the max-pool
        conv_mask = misc.mask_for_lengths(unique_word_lengths)
        conv_out = conv_out + conv_mask.unsqueeze(1)
        # max-pool over the character dimension: one embedding per unique word
        embedded_words = conv_out.max(2)[0]

        if sequences_as_uniqs is None:
            return embedded_words
        else:
            if not isinstance(sequences_as_uniqs, list):
                sequences_as_uniqs = [sequences_as_uniqs]

            all_embedded = []
            for word_idxs in sequences_as_uniqs:
                # look up each token's embedding from the unique-word table
                all_embedded.append(
                    functional.embedding(word_idxs.type(long_tensor),
                                         embedded_words))
            return all_embedded
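
A minimal usage sketch of the character-convolution embedder above. The instance name `char_emb` and the tiny hand-made index tensors are assumptions for illustration only; the third argument maps token positions back to rows of the unique-word table.

import torch

# Hypothetical usage; `char_emb` stands in for an instance of the module whose
# forward() is shown above (its class and constructor are not part of the source).
unique_word_chars = torch.LongTensor([[3, 7, 2, 0],     # character ids, zero-padded
                                      [5, 1, 0, 0]])    # [N=2 unique words, L=4 chars]
unique_word_lengths = torch.LongTensor([3, 2])          # true character lengths, [N=2]
words2unique = torch.LongTensor([[0, 1, 1],             # token -> unique-word index,
                                 [1, 0, 0]])            # [B=2, T=3]

embedded_words = char_emb(unique_word_chars, unique_word_lengths)            # [N, F]
per_token, = char_emb(unique_word_chars, unique_word_lengths, words2unique)  # [B, T, F]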
Example #2
    def forward(self, emb_question, question_length, emb_support,
                support_length, unique_word_chars, unique_word_char_length,
                question_words2unique, support_words2unique, word_in_question,
                correct_start, answer2support, is_eval):
        """fast_qa model

        Args:
            emb_question: [Q, L_q, N]
            question_length: [Q]
            emb_support: [Q, L_s, N]
            support_length: [Q]
            unique_word_chars
            unique_word_char_length
            question_words2unique
            support_words2unique
            word_in_question: [Q, L_s]
            correct_start: [A], only during training, i.e., is_eval=False
            answer2support: [A], only during training, i.e., is_eval=False
            is_eval: []

        Returns:
            start_scores [B, L_s], end_scores [B, L_s], span [B, 3]
        """
        # Some helpers
        float_tensor = torch.cuda.FloatTensor if emb_question.is_cuda else torch.FloatTensor
        long_tensor = torch.cuda.LongTensor if emb_question.is_cuda else torch.LongTensor
        batch_size = question_length.data.shape[0]
        max_question_length = question_length.max().data[0]
        support_mask = misc.mask_for_lengths(support_length)
        question_binary_mask = misc.mask_for_lengths(question_length,
                                                     mask_right=False,
                                                     value=1.0)

        if self._with_char_embeddings:
            # compute combined embeddings
            [char_emb_question, char_emb_support] = self._conv_char_embedding(
                unique_word_chars, unique_word_char_length,
                [question_words2unique, support_words2unique])

            emb_question = torch.cat([emb_question, char_emb_question], 2)
            emb_support = torch.cat([emb_support, char_emb_support], 2)

        # compute encoder features
        question_features = torch.autograd.Variable(
            torch.ones(batch_size, max_question_length, 2, out=float_tensor()))
        question_features = question_features.type_as(emb_question)

        v_wiqw = self._v_wiq_w
        # [B, L_q, L_s]
        wiq_w = torch.matmul(emb_question * v_wiqw,
                             emb_support.transpose(1, 2))
        # [B, L_q, L_s]
        wiq_w = wiq_w + support_mask.unsqueeze(1)
        wiq_w = F.softmax(wiq_w.view(batch_size * max_question_length, -1),
                          dim=1).view(batch_size, max_question_length, -1)
        # [B, L_s]
        wiq_w = torch.matmul(question_binary_mask.unsqueeze(1),
                             wiq_w).squeeze(1)

        # [B, L_s, 2]
        support_features = torch.stack([word_in_question, wiq_w], dim=2)

        if self._with_char_embeddings:
            # highway layer to allow for interaction between concatenated embeddings
            emb_question = self._embedding_projection(emb_question)
            emb_support = self._embedding_projection(emb_support)
            emb_question = self._embedding_highway(emb_question)
            emb_support = self._embedding_highway(emb_support)

        # dropout
        dropout = self._shared_resources.config.get("dropout", 0.0)
        emb_question = F.dropout(emb_question, dropout, training=not is_eval)
        emb_support = F.dropout(emb_support, dropout, training=not is_eval)

        # extend embeddings with features
        emb_question_ext = torch.cat([emb_question, question_features], 2)
        emb_support_ext = torch.cat([emb_support, support_features], 2)

        # encode question and support
        # [B, L, 2 * size]
        encoded_question = self._bilstm(emb_question_ext)[0]
        encoded_support = self._bilstm(emb_support_ext)[0]

        # [B, L, size]
        encoded_support = F.tanh(
            F.linear(encoded_support, self._support_projection))
        encoded_question = F.tanh(
            F.linear(encoded_question, self._question_projection))

        start_scores, end_scores, predicted_start_pointer, predicted_end_pointer = \
            self._answer_layer(encoded_question, question_length, encoded_support, support_length,
                               correct_start, answer2support, is_eval)

        # no multi-paragraph support yet, so the document index is always 0
        doc_idx = torch.autograd.Variable(
            torch.zeros(predicted_start_pointer.data.shape[0],
                        out=long_tensor()))
        span = torch.stack(
            [doc_idx, predicted_start_pointer, predicted_end_pointer], 1)

        return start_scores, end_scores, span
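
The most involved step above is the soft word-in-question feature wiq_w: every question word is softly matched against every support word, the similarities are normalised over support positions, and the resulting distributions are summed over the non-padded question words. The standalone sketch below reproduces that computation with random tensors and current PyTorch; all shapes, the stand-in for the learned vector self._v_wiq_w, and the mask conventions (large negative values for padded support positions, 0/1 for the question mask) are assumptions made for illustration.

import torch
import torch.nn.functional as F

B, L_q, L_s, N = 2, 4, 6, 8
emb_question = torch.randn(B, L_q, N)
emb_support = torch.randn(B, L_s, N)
v_wiqw = torch.randn(1, 1, N)              # stands in for the learned self._v_wiq_w
support_mask = torch.zeros(B, L_s)         # 0 for real tokens, very negative for padding
question_binary_mask = torch.ones(B, L_q)  # 1 for real question tokens, 0 for padding

# similarity of every question word with every support word: [B, L_q, L_s]
wiq_w = torch.matmul(emb_question * v_wiqw, emb_support.transpose(1, 2))
wiq_w = wiq_w + support_mask.unsqueeze(1)
# normalise over support positions per question word
wiq_w = F.softmax(wiq_w.view(B * L_q, -1), dim=1).view(B, L_q, -1)
# sum over non-padded question words: [B, L_s]
wiq_w = torch.matmul(question_binary_mask.unsqueeze(1), wiq_w).squeeze(1)
print(wiq_w.shape)  # torch.Size([2, 6])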
Example #3
    def forward(self, encoded_question, question_length, encoded_support,
                support_length, correct_start, answer2question, is_eval):
        # casting
        long_tensor = torch.cuda.LongTensor if encoded_question.is_cuda else torch.LongTensor
        answer2question = answer2question.type(long_tensor)

        # compute a single attention-weighted summary of the question
        attention_scores = self._linear_question_attention(encoded_question)
        q_mask = misc.mask_for_lengths(question_length)
        attention_scores = attention_scores.squeeze(2) + q_mask
        question_attention_weights = F.softmax(attention_scores, dim=1)
        question_state = torch.matmul(question_attention_weights.unsqueeze(1),
                                      encoded_question).squeeze(1)

        # Prediction
        # start
        start_input = torch.cat(
            [question_state.unsqueeze(1) * encoded_support, encoded_support],
            2)

        q_start_state = self._linear_q_start(
            start_input) + self._linear_q_start_q(question_state).unsqueeze(1)
        start_scores = self._linear_start_scores(
            F.relu(q_start_state)).squeeze(2)

        support_mask = misc.mask_for_lengths(support_length)
        start_scores = start_scores + support_mask
        _, predicted_start_pointer = start_scores.max(1)

        def align(t):
            # expand question-level tensors to one row per answer
            return torch.index_select(t, 0, answer2question)

        if is_eval:
            start_pointer = predicted_start_pointer
        else:
            # use correct start during training, because p(end|start) should be optimized
            start_pointer = correct_start.type(long_tensor)
            predicted_start_pointer = align(predicted_start_pointer)
            start_scores = align(start_scores)
            start_input = align(start_input)
            encoded_support = align(encoded_support)
            question_state = align(question_state)
            support_mask = align(support_mask)

        # end
        # gather the encoded support state at each chosen start position
        u_s = []
        for b, p in enumerate(start_pointer):
            u_s.append(encoded_support[b, p.data[0]])
        u_s = torch.stack(u_s)

        end_input = torch.cat(
            [encoded_support * u_s.unsqueeze(1), start_input], 2)

        q_end_state = self._linear_q_end(end_input) + self._linear_q_end_q(
            question_state).unsqueeze(1)
        end_scores = self._linear_end_scores(F.relu(q_end_state)).squeeze(2)

        end_scores = end_scores + support_mask

        max_support = support_length.max().data[0]

        if is_eval:
            # at eval time, suppress end positions left of the predicted start
            end_scores += misc.mask_for_lengths(start_pointer,
                                                max_support,
                                                mask_right=False)

        _, predicted_end_pointer = end_scores.max(1)

        return start_scores, end_scores, predicted_start_pointer, predicted_end_pointer
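
At evaluation time the answer layer above picks the span greedily: the start is the argmax of the start scores, and end positions left of that start are suppressed (the effect of misc.mask_for_lengths with mask_right=False) before taking the end argmax, so that end >= start. A small self-contained sketch of that selection with made-up scores and plain tensor ops:

import torch

start_scores = torch.tensor([[0.1, 2.0, 0.3, 0.0],
                             [1.5, 0.2, 0.1, 0.4]])
end_scores = torch.tensor([[3.0, 0.1, 1.2, 0.2],
                           [0.3, 0.2, 2.5, 0.1]])

start = start_scores.argmax(dim=1)                         # predicted start per example
positions = torch.arange(end_scores.size(1)).unsqueeze(0)  # [1, L_s]
# suppress end positions before the chosen start so that end >= start
end = end_scores.masked_fill(positions < start.unsqueeze(1), float("-inf")).argmax(dim=1)
print(start.tolist(), end.tolist())  # [1, 0] [2, 2]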