Example #1
    def forward(self, context, question, context_char=None, question_char=None, context_f=None, question_f=None):
        """
        context_char, question_char, context_f and question_f are not used by this model.
        """

        # get embedding: (seq_len, batch, embedding_size)
        context_vec, context_mask = self.embedding.forward(context)
        question_vec, question_mask = self.embedding.forward(question)

        # encode: (seq_len, batch, hidden_size)
        context_encode, _ = self.encoder.forward(context_vec, context_mask)
        question_encode, _ = self.encoder.forward(question_vec, question_mask)

        # match lstm: (seq_len, batch, hidden_size)
        qt_aware_ct, qt_aware_last_hidden, match_para = self.match_rnn.forward(context_encode, context_mask,
                                                                               question_encode, question_mask)
        vis_param = {'match': match_para}

        # pointer net: (answer_len, batch, context_len)
        ans_range_prop = self.pointer_net.forward(qt_aware_ct, context_mask)
        ans_range_prop = ans_range_prop.transpose(0, 1)

        # answer range
        if not self.training and self.enable_search:
            ans_range = answer_search(ans_range_prop, context_mask)
        else:
            _, ans_range = torch.max(ans_range_prop, dim=2)

        return ans_range_prop, ans_range, vis_param
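The decode step either takes a per-position argmax or, at eval time, calls `answer_search`. A constrained search matters because independent argmaxes over the start and end distributions can return an end index before the start. The helper below is a minimal sketch of such a search, not the repo's implementation: it assumes `ans_range_prop` has shape (batch, 2, context_len) after the transpose, and omits masking and maximum-answer-length capping.

    import torch

    def answer_search_sketch(ans_range_prop):
        # ans_range_prop: (batch, 2, context_len) start/end probabilities
        batch, _, c_len = ans_range_prop.shape
        start_p = ans_range_prop[:, 0, :]                  # (batch, c_len)
        end_p = ans_range_prop[:, 1, :]                    # (batch, c_len)
        # joint score of every (start, end) pair
        score = start_p.unsqueeze(2) * end_p.unsqueeze(1)  # (batch, c_len, c_len)
        score = torch.triu(score)                          # keep start <= end
        flat = score.view(batch, -1).argmax(dim=1)         # best pair per example
        start = flat // c_len
        end = flat % c_len
        return torch.stack([start, end], dim=1)            # (batch, 2)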
Example #2
    def forward(self, context, question, context_char=None, question_char=None, context_f=None, question_f=None):
        assert context_char is not None and question_char is not None and context_f is not None \
               and question_f is not None

        vis_param = {}

        # (seq_len, batch, additional_feature_size)
        context_f = context_f.transpose(0, 1)
        question_f = question_f.transpose(0, 1)

        # word-level embedding: (seq_len, batch, word_embedding_size)
        context_vec, context_mask = self.embedding.forward(context)
        question_vec, question_mask = self.embedding.forward(question)

        # char-level embedding: (seq_len, batch, char_embedding_size)
        context_emb_char, context_char_mask = self.char_embedding.forward(context_char)
        question_emb_char, question_char_mask = self.char_embedding.forward(question_char)

        context_vec_char = self.char_encoder.forward(context_emb_char, context_char_mask, context_mask)
        question_vec_char = self.char_encoder.forward(question_emb_char, question_char_mask, question_mask)

        # mix embedding: (seq_len, batch, embedding_size)
        context_vec = torch.cat((context_vec, context_vec_char, context_f), dim=-1)
        question_vec = torch.cat((question_vec, question_vec_char, question_f), dim=-1)

        # encode: (seq_len, batch, hidden_size*2)
        context_encode, _ = self.encoder.forward(context_vec, context_mask)
        question_encode, zs = self.encoder.forward(question_vec, question_mask)

        align_ct = context_encode
        for i in range(self.num_align_hops):
            # align: (seq_len, batch, hidden_size*2)
            qt_align_ct, alpha = self.aligner[i](align_ct, question_encode, question_mask)
            bar_ct = self.aligner_sfu[i](align_ct, torch.cat([qt_align_ct,
                                                              align_ct * qt_align_ct,
                                                              align_ct - qt_align_ct], dim=-1))
            vis_param['match'] = alpha

            # self-align: (seq_len, batch, hidden_size*2)
            ct_align_ct, self_alpha = self.self_aligner[i](bar_ct, context_mask)
            hat_ct = self.self_aligner_sfu[i](bar_ct, torch.cat([ct_align_ct,
                                                                 bar_ct * ct_align_ct,
                                                                 bar_ct - ct_align_ct], dim=-1))
            vis_param['self-match'] = self_alpha

            # aggregation: (seq_len, batch, hidden_size*2)
            align_ct, _ = self.aggregation[i](hat_ct, context_mask)

        # pointer net: (answer_len, batch, context_len)
        for i in range(self.num_ptr_hops):
            ans_range_prop, zs = self.ptr_net[i](align_ct, context_mask, zs)

        # answer range
        ans_range_prop = ans_range_prop.transpose(0, 1)
        if not self.training and self.enable_search:
            ans_range = answer_search(ans_range_prop, context_mask)
        else:
            _, ans_range = torch.max(ans_range_prop, dim=2)

        return ans_range_prop, ans_range, vis_param
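`aligner_sfu` and `self_aligner_sfu` fuse the aligned vectors back into the running context representation. Below is a hedged sketch of a Semantic Fusion Unit in the Mnemonic Reader style; the class and attribute names are illustrative, and the repo's actual module may differ in details. Note that in the calls above the fusion input is the concatenation [aligned, input * aligned, input - aligned], so `fusion_size` is three times the input size.

    import torch
    import torch.nn as nn

    class SFUSketch(nn.Module):
        # Gated fusion of an input x with auxiliary fusion features.
        def __init__(self, input_size, fusion_size):
            super().__init__()
            self.linear_r = nn.Linear(input_size + fusion_size, input_size)
            self.linear_g = nn.Linear(input_size + fusion_size, input_size)

        def forward(self, x, fusions):
            rf = torch.cat([x, fusions], dim=-1)
            r = torch.tanh(self.linear_r(rf))     # candidate state
            g = torch.sigmoid(self.linear_g(rf))  # fusion gate
            return g * r + (1 - g) * x            # gated mix of new and old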
Example #3
    def forward(self,
                context,
                question,
                context_char=None,
                question_char=None,
                context_f=None,
                question_f=None):
        """
        context_char, question_char, context_f and question_f are not used by this model.
        """

        # get embedding: (seq_len, batch, embedding_size)
        context_vec, context_mask = self.embedding.forward(context)
        question_vec, question_mask = self.embedding.forward(question)

        # encode: (seq_len, batch, hidden_size)
        context_encode, _ = self.encoder.forward(context_vec, context_mask)
        question_encode, _ = self.encoder.forward(question_vec, question_mask)

        # match lstm: (seq_len, batch, hidden_size)
        qt_aware_ct, qt_aware_last_hidden, match_para = self.match_rnn.forward(
            context_encode, context_mask, question_encode, question_mask)
        vis_param = {'match': match_para}

        # pointer net: (answer_len, batch, context_len)
        ans_range_prop = self.pointer_net.forward(qt_aware_ct, context_mask)
        ans_range_prop = ans_range_prop.transpose(0, 1)

        # answer range
        if not self.training and self.enable_search:
            ans_range = answer_search(ans_range_prop, context_mask)
        else:
            _, ans_range = torch.max(ans_range_prop, dim=2)

        return ans_range_prop, ans_range, vis_param
Example #4
    def forward(self,
                context,
                question,
                context_char=None,
                question_char=None):
        assert context_char is not None and question_char is not None

        # word-level embedding: (seq_len, batch, embedding_size)
        context_vec, context_mask = self.embedding.forward(context)
        question_vec, question_mask = self.embedding.forward(question)

        # char-level embedding: (seq_len, batch, char_embedding_size)
        context_emb_char, context_char_mask = self.char_embedding.forward(
            context_char)
        question_emb_char, question_char_mask = self.char_embedding.forward(
            question_char)

        # word-level encode: (seq_len, batch, hidden_size)
        context_encode, _ = self.encoder.forward(context_vec, context_mask)
        question_encode, _ = self.encoder.forward(question_vec, question_mask)

        # char-level encode: (seq_len, batch, hidden_size)
        context_vec_char = self.char_encoder.forward(context_emb_char,
                                                     context_char_mask,
                                                     context_mask)
        question_vec_char = self.char_encoder.forward(question_emb_char,
                                                      question_char_mask,
                                                      question_mask)

        context_encode = torch.cat((context_encode, context_vec_char), dim=-1)
        question_encode = torch.cat((question_encode, question_vec_char),
                                    dim=-1)

        # match lstm: (seq_len, batch, hidden_size)
        qt_aware_ct, qt_aware_last_hidden, match_para = self.match_rnn.forward(
            context_encode, context_mask, question_encode, question_mask)
        vis_param = {'match': match_para}

        # birnn after self match: (seq_len, batch, hidden_size)
        qt_aware_ct_ag, _ = self.birnn_after_self.forward(
            qt_aware_ct, context_mask)

        # pointer net init hidden: (batch, hidden_size)
        ptr_net_hidden = torch.tanh(
            self.init_ptr_hidden.forward(qt_aware_last_hidden))

        # pointer net: (answer_len, batch, context_len)
        ans_range_prop = self.pointer_net.forward(qt_aware_ct_ag, context_mask,
                                                  ptr_net_hidden)
        ans_range_prop = ans_range_prop.transpose(0, 1)

        # answer range
        if not self.training and self.enable_search:
            ans_range = answer_search(ans_range_prop, context_mask)
        else:
            _, ans_range = torch.max(ans_range_prop, dim=2)

        return ans_range_prop, ans_range, vis_param
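`char_encoder` collapses each word's character sequence into a single vector before it is concatenated with the word-level encoding. The sketch below shows one common way to do this, assuming the character embeddings arrive as (seq_len, batch, word_len, char_embedding_size); the masks the real module receives are ignored here for brevity.

    import torch
    import torch.nn as nn

    class CharEncoderSketch(nn.Module):
        # Encode each word's characters with a BiGRU; keep the final states.
        def __init__(self, char_embedding_size, hidden_size):
            super().__init__()
            self.rnn = nn.GRU(char_embedding_size, hidden_size,
                              bidirectional=True)

        def forward(self, char_emb):
            seq_len, batch, word_len, emb_size = char_emb.shape
            # fold words into the batch dim; characters become the time dim
            x = char_emb.reshape(seq_len * batch, word_len, emb_size)
            x = x.transpose(0, 1)                    # (word_len, S*B, emb)
            _, h_n = self.rnn(x)                     # (2, S*B, hidden)
            h = torch.cat([h_n[0], h_n[1]], dim=-1)  # concat both directions
            return h.view(seq_len, batch, -1)        # (seq_len, batch, 2*hidden)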
Example #5
    def forward(self, context, question, context_char=None, question_char=None, context_f=None, question_f=None):
        assert context_char is not None and question_char is not None and context_f is not None \
               and question_f is not None

        # (seq_len, batch, additional_feature_size)
        context_f = context_f.transpose(0, 1)
        question_f = question_f.transpose(0, 1)

        # word-level embedding: (seq_len, batch, embedding_size)
        context_vec, context_mask = self.embedding.forward(context)
        question_vec, question_mask = self.embedding.forward(question)

        # char-level embedding: (seq_len, batch, char_embedding_size)
        context_emb_char, context_char_mask = self.char_embedding.forward(context_char)
        question_emb_char, question_char_mask = self.char_embedding.forward(question_char)

        # word-level encode: (seq_len, batch, hidden_size)
        context_vec = torch.cat([context_vec, context_f], dim=-1)
        question_vec = torch.cat([question_vec, question_f], dim=-1)
        context_encode, _ = self.encoder.forward(context_vec, context_mask)
        question_encode, _ = self.encoder.forward(question_vec, question_mask)

        # char-level encode: (seq_len, batch, hidden_size)
        context_vec_char = self.char_encoder.forward(context_emb_char, context_char_mask, context_mask)
        question_vec_char = self.char_encoder.forward(question_emb_char, question_char_mask, question_mask)

        context_encode = torch.cat((context_encode, context_vec_char), dim=-1)
        question_encode = torch.cat((question_encode, question_vec_char), dim=-1)

        # match lstm: (seq_len, batch, hidden_size)
        qt_aware_ct, qt_aware_last_hidden, match_para = self.match_rnn.forward(context_encode, context_mask,
                                                                               question_encode, question_mask)
        vis_param = {'match': match_para}

        # birnn after self match: (seq_len, batch, hidden_size)
        qt_aware_ct_ag, _ = self.birnn_after_self.forward(qt_aware_ct, context_mask)

        # pointer net init hidden: (batch, hidden_size)
        ptr_net_hidden = torch.tanh(self.init_ptr_hidden.forward(qt_aware_last_hidden))

        # pointer net: (answer_len, batch, context_len)
        ans_range_prop = self.pointer_net.forward(qt_aware_ct_ag, context_mask, ptr_net_hidden)
        ans_range_prop = ans_range_prop.transpose(0, 1)

        # answer range
        if not self.training and self.enable_search:
            ans_range = answer_search(ans_range_prop, context_mask)
        else:
            _, ans_range = torch.max(ans_range_prop, dim=2)

        return ans_range_prop, ans_range, vis_param
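For reference, the input contract this variant implies, inferred from the transposes and concatenations above (batch-first inputs are an assumption, not a documented API):

    # context:       (batch, c_len)           word ids
    # question:      (batch, q_len)           word ids
    # context_char:  (batch, c_len, w_len)    char ids per word
    # question_char: (batch, q_len, w_len)
    # context_f:     (batch, c_len, n_feat)   extra per-token features;
    #   transpose(0, 1) turns them into (c_len, batch, n_feat) so they can be
    #   concatenated with the (seq_len, batch, embedding_size) word vectors.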
Example #6
    def forward(self,
                context,
                question,
                context_char=None,
                question_char=None,
                context_f=None,
                question_f=None):
        assert context_char is not None and question_char is not None and context_f is not None \
               and question_f is not None

        vis_param = {}

        # (seq_len, batch, additional_feature_size)
        context_f = context_f.transpose(0, 1)
        question_f = question_f.transpose(0, 1)

        # word-level embedding: (seq_len, batch, word_embedding_size)
        context_vec, context_mask = self.embedding.forward(context)
        question_vec, question_mask = self.embedding.forward(question)

        # char-level embedding: (seq_len, batch, char_embedding_size)
        context_emb_char, context_char_mask = self.char_embedding.forward(
            context_char)
        question_emb_char, question_char_mask = self.char_embedding.forward(
            question_char)

        context_vec_char = self.char_encoder.forward(context_emb_char,
                                                     context_char_mask,
                                                     context_mask)
        question_vec_char = self.char_encoder.forward(question_emb_char,
                                                      question_char_mask,
                                                      question_mask)

        # mix embedding: (seq_len, batch, embedding_size)
        context_vec = torch.cat((context_vec, context_vec_char, context_f),
                                dim=-1)
        question_vec = torch.cat((question_vec, question_vec_char, question_f),
                                 dim=-1)

        # encode: (seq_len, batch, hidden_size*2)
        context_encode, _ = self.encoder.forward(context_vec, context_mask)
        question_encode, zs = self.encoder.forward(question_vec, question_mask)

        align_ct = context_encode
        for i in range(self.num_align_hops):
            # align: (seq_len, batch, hidden_size*2)
            qt_align_ct, alpha = self.aligner[i](align_ct, question_encode,
                                                 question_mask)
            bar_ct = self.aligner_sfu[i](
                align_ct,
                torch.cat([
                    qt_align_ct, align_ct * qt_align_ct, align_ct - qt_align_ct
                ],
                          dim=-1))
            vis_param['match'] = alpha

            # self-align: (seq_len, batch, hidden_size*2)
            ct_align_ct, self_alpha = self.self_aligner[i](bar_ct,
                                                           context_mask)
            hat_ct = self.self_aligner_sfu[i](
                bar_ct,
                torch.cat(
                    [ct_align_ct, bar_ct * ct_align_ct, bar_ct - ct_align_ct],
                    dim=-1))
            vis_param['self-match'] = self_alpha

            # aggregation: (seq_len, batch, hidden_size*2)
            align_ct, _ = self.aggregation[i](hat_ct, context_mask)

        # pointer net: (answer_len, batch, context_len)
        for i in range(self.num_ptr_hops):
            ans_range_prop, zs = self.ptr_net[i](align_ct, context_mask, zs)

        # answer range
        ans_range_prop = ans_range_prop.transpose(0, 1)
        if not self.training and self.enable_search:
            ans_range = answer_search(ans_range_prop, context_mask)
        else:
            _, ans_range = torch.max(ans_range_prop, dim=2)

        return ans_range_prop, ans_range, vis_param
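`self_aligner` (used in Examples #2 and #6) lets each context position attend over the full context; in the Mnemonic Reader formulation the diagonal is blocked so a word cannot align with itself. A minimal sketch under those assumptions (the repo's module may differ):

    import torch
    import torch.nn.functional as F

    def self_align_sketch(h, mask):
        # h: (seq_len, batch, hidden); mask: (batch, seq_len), 1 = real token
        hb = h.transpose(0, 1)                          # (batch, L, hidden)
        scores = torch.bmm(hb, hb.transpose(1, 2))      # (batch, L, L)
        L = scores.size(1)
        eye = torch.eye(L, dtype=torch.bool, device=h.device).unsqueeze(0)
        scores = scores.masked_fill(eye, -1e30)         # block self-matches
        scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e30)  # padding
        alpha = F.softmax(scores, dim=2)                # attention weights
        aligned = torch.bmm(alpha, hb).transpose(0, 1)  # (L, batch, hidden)
        return aligned, alpha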
Example #7
    def forward(self,
                context,
                question,
                context_char=None,
                question_char=None,
                context_f=None,
                question_f=None):
        if self.enable_char:
            assert context_char is not None and question_char is not None

        if self.enable_features:
            assert context_f is not None and question_f is not None

        # get embedding: (seq_len, batch, embedding_size)
        context_vec, context_mask = self.embedding.forward(context)
        question_vec, question_mask = self.embedding.forward(question)

        if self.enable_features:
            # (seq_len, batch, additional_feature_size)
            context_f = context_f.transpose(0, 1)
            question_f = question_f.transpose(0, 1)

            context_vec = torch.cat([context_vec, context_f], dim=-1)
            question_vec = torch.cat([question_vec, question_f], dim=-1)

        # char-level embedding: (seq_len, batch, char_embedding_size)
        if self.enable_char:
            context_emb_char, context_char_mask = self.char_embedding.forward(
                context_char)
            question_emb_char, question_char_mask = self.char_embedding.forward(
                question_char)

            context_vec_char = self.char_encoder.forward(
                context_emb_char, context_char_mask, context_mask)
            question_vec_char = self.char_encoder.forward(
                question_emb_char, question_char_mask, question_mask)

            if self.mix_encode:
                context_vec = torch.cat((context_vec, context_vec_char),
                                        dim=-1)
                question_vec = torch.cat((question_vec, question_vec_char),
                                         dim=-1)

        # encode: (seq_len, batch, hidden_size)
        context_encode, _ = self.encoder.forward(context_vec, context_mask)
        question_encode, _ = self.encoder.forward(question_vec, question_mask)

        # char-level encode: (seq_len, batch, hidden_size)
        if self.enable_char and not self.mix_encode:
            context_encode = torch.cat((context_encode, context_vec_char),
                                       dim=-1)
            question_encode = torch.cat((question_encode, question_vec_char),
                                        dim=-1)

        # question match-lstm
        match_rnn_in_question = question_encode
        if self.enable_question_match:
            ct_aware_qt, _, _ = self.question_match_rnn.forward(
                question_encode, question_mask, context_encode, context_mask)
            match_rnn_in_question = ct_aware_qt

        # match lstm: (seq_len, batch, hidden_size)
        qt_aware_ct, qt_aware_last_hidden, match_para = self.match_rnn.forward(
            context_encode, context_mask, match_rnn_in_question, question_mask)
        vis_param = {'match': match_para}

        # self match lstm: (seq_len, batch, hidden_size)
        if self.enable_self_match:
            qt_aware_ct, qt_aware_last_hidden, self_para = self.self_match_rnn.forward(
                qt_aware_ct, context_mask, qt_aware_ct, context_mask)
            vis_param['self'] = self_para

        # birnn after self match: (seq_len, batch, hidden_size)
        if self.enable_birnn_after_self:
            qt_aware_ct, _ = self.birnn_after_self.forward(
                qt_aware_ct, context_mask)

        # self gated
        if self.enable_self_gated:
            qt_aware_ct = self.self_gated(qt_aware_ct)

        # pointer net init hidden: (batch, hidden_size)
        ptr_net_hidden = None
        if self.init_ptr_hidden_mode == 'pooling':
            ptr_net_hidden = self.init_ptr_hidden.forward(
                question_encode, question_mask)
        elif self.init_ptr_hidden_mode == 'linear':
            ptr_net_hidden = self.init_ptr_hidden.forward(qt_aware_last_hidden)
            ptr_net_hidden = torch.tanh(ptr_net_hidden)

        # pointer net: (answer_len, batch, context_len)
        # ans_range_prop = self.pointer_net.forward(qt_aware_ct, context_mask, ptr_net_hidden)
        # ans_range_prop = ans_range_prop.transpose(0, 1)

        ans_range_prop = multi_scale_ptr(self.pointer_net, ptr_net_hidden,
                                         qt_aware_ct, context_mask,
                                         self.scales)

        # answer range
        if not self.training and self.enable_search:
            ans_range = answer_search(ans_range_prop, context_mask)
        else:
            _, ans_range = torch.max(ans_range_prop, dim=2)

        return ans_range_prop, ans_range, vis_param
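The 'pooling' branch of `init_ptr_hidden` summarizes the question encoding into a single vector per batch element. One standard choice is learned attention pooling; the sketch below is an assumption about how such a module could look, not the repo's code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class AttentionPoolingSketch(nn.Module):
        # Score each question position, softmax over time, weighted sum.
        def __init__(self, input_size, hidden_size):
            super().__init__()
            self.score = nn.Sequential(nn.Linear(input_size, hidden_size),
                                       nn.Tanh(),
                                       nn.Linear(hidden_size, 1))

        def forward(self, question_encode, question_mask):
            # question_encode: (seq_len, batch, input); mask: (batch, seq_len)
            s = self.score(question_encode).squeeze(-1)       # (seq_len, batch)
            s = s.masked_fill(question_mask.t() == 0, -1e30)  # hide padding
            alpha = F.softmax(s, dim=0)                       # over seq_len
            return (alpha.unsqueeze(-1) * question_encode).sum(0)  # (batch, input)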
Example #8
    def forward(self, context, question, context_char=None, question_char=None, context_f=None, question_f=None):
        if self.enable_char:
            assert context_char is not None and question_char is not None

        if self.enable_features:
            assert context_f is not None and question_f is not None

        # get embedding: (seq_len, batch, embedding_size)
        context_vec, context_mask = self.embedding.forward(context)
        question_vec, question_mask = self.embedding.forward(question)

        if self.enable_features:
            # (seq_len, batch, additional_feature_size)
            context_f = context_f.transpose(0, 1)
            question_f = question_f.transpose(0, 1)

            context_vec = torch.cat([context_vec, context_f], dim=-1)
            question_vec = torch.cat([question_vec, question_f], dim=-1)

        # char-level embedding: (seq_len, batch, char_embedding_size)
        if self.enable_char:
            context_emb_char, context_char_mask = self.char_embedding.forward(context_char)
            question_emb_char, question_char_mask = self.char_embedding.forward(question_char)

            context_vec_char = self.char_encoder.forward(context_emb_char, context_char_mask, context_mask)
            question_vec_char = self.char_encoder.forward(question_emb_char, question_char_mask, question_mask)

            if self.mix_encode:
                context_vec = torch.cat((context_vec, context_vec_char), dim=-1)
                question_vec = torch.cat((question_vec, question_vec_char), dim=-1)

        # encode: (seq_len, batch, hidden_size)
        context_encode, _ = self.encoder.forward(context_vec, context_mask)
        question_encode, _ = self.encoder.forward(question_vec, question_mask)

        # char-level encode: (seq_len, batch, hidden_size)
        if self.enable_char and not self.mix_encode:
            context_encode = torch.cat((context_encode, context_vec_char), dim=-1)
            question_encode = torch.cat((question_encode, question_vec_char), dim=-1)

        # question match-lstm
        match_rnn_in_question = question_encode
        if self.enable_question_match:
            ct_aware_qt, _, _ = self.question_match_rnn.forward(question_encode, question_mask,
                                                                context_encode, context_mask)
            match_rnn_in_question = ct_aware_qt

        # match lstm: (seq_len, batch, hidden_size)
        qt_aware_ct, qt_aware_last_hidden, match_para = self.match_rnn.forward(context_encode, context_mask,
                                                                               match_rnn_in_question, question_mask)
        vis_param = {'match': match_para}

        # self match lstm: (seq_len, batch, hidden_size)
        if self.enable_self_match:
            qt_aware_ct, qt_aware_last_hidden, self_para = self.self_match_rnn.forward(qt_aware_ct, context_mask,
                                                                                       qt_aware_ct, context_mask)
            vis_param['self'] = self_para

        # birnn after self match: (seq_len, batch, hidden_size)
        if self.enable_birnn_after_self:
            qt_aware_ct, _ = self.birnn_after_self.forward(qt_aware_ct, context_mask)

        # self gated
        if self.enable_self_gated:
            qt_aware_ct = self.self_gated(qt_aware_ct)

        # pointer net init hidden: (batch, hidden_size)
        ptr_net_hidden = None
        if self.init_ptr_hidden_mode == 'pooling':
            ptr_net_hidden = self.init_ptr_hidden.forward(question_encode, question_mask)
        elif self.init_ptr_hidden_mode == 'linear':
            ptr_net_hidden = self.init_ptr_hidden.forward(qt_aware_last_hidden)
            ptr_net_hidden = torch.tanh(ptr_net_hidden)

        # pointer net: (answer_len, batch, context_len)
        # ans_range_prop = self.pointer_net.forward(qt_aware_ct, context_mask, ptr_net_hidden)
        # ans_range_prop = ans_range_prop.transpose(0, 1)

        ans_range_prop = multi_scale_ptr(self.pointer_net, ptr_net_hidden, qt_aware_ct, context_mask, self.scales)

        # answer range
        if not self.training and self.enable_search:
            ans_range = answer_search(ans_range_prop, context_mask)
        else:
            _, ans_range = torch.max(ans_range_prop, dim=2)

        return ans_range_prop, ans_range, vis_param
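`self_gated` in Examples #7 and #8 presumably rescales each timestep's features with a learned gate computed from the features themselves. A minimal sketch of such a gate (the attribute name `gate` is illustrative):

    import torch
    import torch.nn as nn

    class SelfGatedSketch(nn.Module):
        # Elementwise sigmoid gate computed from the input itself.
        def __init__(self, input_size):
            super().__init__()
            self.gate = nn.Linear(input_size, input_size)

        def forward(self, x):
            return torch.sigmoid(self.gate(x)) * x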
Example #9
    def forward(self,
                context,
                question,
                context_char=None,
                question_char=None):
        if self.enable_char:
            assert context_char is not None and question_char is not None

        # get embedding: (seq_len, batch, embedding_size)
        context_vec, context_mask = self.embedding.forward(context)
        question_vec, question_mask = self.embedding.forward(question)

        # char-level embedding: (seq_len, batch, char_embedding_size)
        if self.enable_char:
            context_emb_char, context_char_mask = self.char_embedding.forward(
                context_char)
            question_emb_char, question_char_mask = self.char_embedding.forward(
                question_char)

            context_vec_char = self.char_encoder.forward(
                context_emb_char, context_char_mask, context_mask)
            question_vec_char = self.char_encoder.forward(
                question_emb_char, question_char_mask, question_mask)

            if self.mix_encode:
                context_vec = torch.cat((context_vec, context_vec_char),
                                        dim=-1)
                question_vec = torch.cat((question_vec, question_vec_char),
                                         dim=-1)

        # encode: (seq_len, batch, hidden_size)
        context_encode, _ = self.encoder.forward(context_vec, context_mask)
        question_encode, _ = self.encoder.forward(question_vec, question_mask)

        # char-level encode: (seq_len, batch, hidden_size)
        if self.enable_char and not self.mix_encode:
            context_encode = torch.cat((context_encode, context_vec_char),
                                       dim=-1)
            question_encode = torch.cat((question_encode, question_vec_char),
                                        dim=-1)

        # match lstm: (seq_len, batch, hidden_size)
        qt_aware_ct, qt_aware_last_hidden, match_alpha = self.match_rnn.forward(
            context_encode, context_mask, question_encode, question_mask)
        vis_param = {'match': match_alpha}

        # self match lstm: (seq_len, batch, hidden_size)
        if self.enable_self_match:
            qt_aware_ct, qt_aware_last_hidden, self_alpha = self.self_match_rnn.forward(
                qt_aware_ct, context_mask, qt_aware_ct, context_mask)
            vis_param['self'] = self_alpha

        # birnn after self match: (seq_len, batch, hidden_size)
        if self.enable_birnn_after_self:
            qt_aware_ct, _ = self.birnn_after_self.forward(
                qt_aware_ct, context_mask)

        # pointer net init hidden: (batch, hidden_size)
        ptr_net_hidden = None
        if self.init_ptr_hidden_mode == 'pooling':
            ptr_net_hidden = self.init_ptr_hidden.forward(
                question_encode, question_mask)
        elif self.init_ptr_hidden_mode == 'linear':
            ptr_net_hidden = self.init_ptr_hidden.forward(qt_aware_last_hidden)
            ptr_net_hidden = F.tanh(ptr_net_hidden)

        # pointer net: (answer_len, batch, context_len)
        ans_range_prop = self.pointer_net.forward(qt_aware_ct, context_mask,
                                                  ptr_net_hidden)
        ans_range_prop = ans_range_prop.transpose(0, 1)

        # answer range
        if self.enable_search:
            ans_range = answer_search(ans_range_prop, context_mask)
        else:
            ans_range = torch.max(ans_range_prop, 2)[1]

        return ans_range_prop, ans_range, vis_param
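Note the decoding difference in this last variant: `torch.max(ans_range_prop, 2)[1]` picks start and end independently, which can yield an end index before the start. A quick illustration with a fabricated probability tensor:

    import torch

    # Fabricated (batch=1, 2, context_len=3) probabilities, for illustration.
    prop = torch.tensor([[[0.1, 0.7, 0.2],    # start distribution
                          [0.8, 0.1, 0.1]]])  # end distribution
    greedy = torch.max(prop, 2)[1]            # tensor([[1, 0]]): end < start
    # A constrained search over pairs with start <= end would pick (0, 0)
    # here instead, since 0.1 * 0.8 is the best valid pair score.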