Example #1
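All five snippets are variants of the `forward` method of the same tri-matching reading-comprehension model, where `p` is the passage, `q` the question, and `c` a candidate answer. They are shown without their enclosing module, so the imports below are a sketch of what they assume; `layers` is a guess at a project-local helper module, and the stated contract of the attention helpers is inferred from how their results are unpacked below, not from a documented API.

import torch
import torch.nn as nn
import torch.nn.functional as F
import layers  # assumed project-local helpers: ave_pooling, max_pooling, ...

# Inferred contract of the attention helpers used throughout:
#   weights, weighted = self.hidden_match(x, y, y_mask)
#       weights  : (batch, len_x, len_y) attention of each x position over y
#       weighted : (batch, len_x, hidden) attention-weighted sum of y
#   weights, pooled = self.q_self_attn(x, x_mask)
#       pooled   : (batch, hidden) self-attentive summary of x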
    def forward(self, p, p_pos, p_ner, p_mask, q, q_pos, q_ner, q_mask,
                c, c_pos, c_ner, c_mask, p_f_tensor, q_f_tensor, c_f_tensor,
                p_q_relation, p_c_relation, q_p_relation, q_c_relation,
                c_p_relation, c_q_relation, is_paint=0):

        p_rnn_input, q_rnn_input, c_rnn_input = self.add_embeddings(
            p, p_pos, p_ner, q, q_pos, q_ner, c, c_pos, c_ner,
            p_f_tensor, q_f_tensor, c_f_tensor, p_q_relation, p_c_relation,
            q_p_relation, q_c_relation, c_p_relation, c_q_relation)

        p_hiddens = self.context_rnn(p_rnn_input, p_mask)
        q_hiddens = self.context_rnn(q_rnn_input, q_mask)
        c_hiddens = self.context_rnn(c_rnn_input, c_mask)
        if self.args.dropout_rnn_output > 0:
            rnn_drop = lambda x: nn.functional.dropout(
                x, p=self.args.dropout_rnn_output, training=self.training)
            p_hiddens, q_hiddens, c_hiddens = map(
                rnn_drop, (p_hiddens, q_hiddens, c_hiddens))

        ####################################################
        self.mfunction(p_hiddens, q_hiddens, c_hiddens, p_mask, q_mask, c_mask)
        # self.mfunction is bound at construction time and dispatches to
        # NA_TriMatching, CA_TriMatching, or NA_CA_TriMatching according to
        # args.tri_input; it stores matched_{p,q,c} and the {p,q,c}_infer_emb
        # tensors on self for the output layer below.
        #------output-layer--------------
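        # Pool each matched sequence into fixed-size vectors: a self-attentive
        # summary per sequence plus average/max pooling of the per-token
        # inference embeddings, all concatenated into one feature vector.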

        _, matched_q_self = self.q_self_attn(self.matched_q, q_mask)
        _, matched_p_self = self.q_self_attn(self.matched_p, p_mask)
        _, matched_c_self = self.q_self_attn(self.matched_c, c_mask)

        p_infer_hidden_ave = layers.ave_pooling(self.p_infer_emb, p_mask)
        p_infer_hidden_max = layers.max_pooling(self.p_infer_emb)

        q_infer_hidden_ave = layers.ave_pooling(self.q_infer_emb, q_mask)
        q_infer_hidden_max = layers.max_pooling(self.q_infer_emb)

        c_infer_hidden_ave = layers.ave_pooling(self.c_infer_emb, c_mask)
        c_infer_hidden_max = layers.max_pooling(self.c_infer_emb)

        infer_linear = self.c_infer_linear(torch.cat([
            p_infer_hidden_ave, p_infer_hidden_max,
            q_infer_hidden_ave, q_infer_hidden_max,
            c_infer_hidden_ave, c_infer_hidden_max,
            matched_q_self, matched_p_self, matched_c_self
        ], -1))

        logits = self.logits_linear(infer_linear)
        proba = torch.sigmoid(logits.squeeze(1))

        return proba
Example #2
    def forward(self, p, p_pos, p_ner, p_mask, q, q_pos, q_ner, q_mask,
                c, c_pos, c_ner, c_mask, p_f_tensor, q_f_tensor, c_f_tensor,
                p_q_relation, p_c_relation, q_p_relation, q_c_relation,
                c_p_relation, c_q_relation, is_paint=0):

        p_emb = self.embedding(p)
        q_emb = self.embedding(q)
        c_emb = self.embedding(c)
        p_pos_emb = self.pos_embedding(p_pos)
        q_pos_emb = self.pos_embedding(q_pos)
        c_pos_emb = self.pos_embedding(c_pos)
        p_ner_emb = self.ner_embedding(p_ner)
        q_ner_emb = self.ner_embedding(q_ner)
        c_ner_emb = self.ner_embedding(c_ner)
        p_q_rel_emb = self.rel_embedding(p_q_relation)
        p_c_rel_emb = self.rel_embedding(p_c_relation)
        q_p_rel_emb = self.rel_embedding(q_p_relation)
        q_c_rel_emb = self.rel_embedding(q_c_relation)
        c_p_rel_emb = self.rel_embedding(c_p_relation)
        c_q_rel_emb = self.rel_embedding(c_q_relation)

        # Dropout on embeddings
        if self.args.dropout_emb > 0:
            emb_drop = lambda x: nn.functional.dropout(
                x, p=self.args.dropout_emb, training=self.training)
            p_emb, q_emb, c_emb = map(emb_drop, (p_emb, q_emb, c_emb))
            p_pos_emb, q_pos_emb, c_pos_emb = map(
                emb_drop, (p_pos_emb, q_pos_emb, c_pos_emb))
            p_ner_emb, q_ner_emb, c_ner_emb = map(
                emb_drop, (p_ner_emb, q_ner_emb, c_ner_emb))
        p_rnn_input = torch.cat([
            p_emb, p_pos_emb, p_ner_emb, p_f_tensor, p_q_rel_emb, p_c_rel_emb
        ], 2)
        q_rnn_input = torch.cat([
            q_emb, q_pos_emb, q_ner_emb, q_f_tensor, q_p_rel_emb, q_c_rel_emb
        ], 2)
        c_rnn_input = torch.cat([
            c_emb, c_pos_emb, c_ner_emb, c_f_tensor, c_p_rel_emb, c_q_rel_emb
        ], 2)

        p_hiddens = self.context_rnn(p_rnn_input, p_mask)
        q_hiddens = self.context_rnn(q_rnn_input, q_mask)
        c_hiddens = self.context_rnn(c_rnn_input, c_mask)

        ####################################################
        #------q_p--------------
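        # The question attends to the passage; the concatenated [q; q<-p]
        # features are then self-matched, and those weights re-pool the raw
        # question hiddens into matched_q.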
        _, q_p_weighted_hiddens = self.hidden_match(q_hiddens, p_hiddens,
                                                    p_mask)
        q_p_cat = torch.cat([q_hiddens, q_p_weighted_hiddens], 2)
        q_p_cat_weight, q_p_cat_weighted_hiddens = self.hidden_match(
            q_p_cat, q_p_cat, q_mask)
        if self.args.dropout_att_score > 0:
            q_p_cat_weight = nn.functional.dropout(
                q_p_cat_weight,
                p=self.args.dropout_att_score,
                training=self.training)
        matched_q = q_p_cat_weight.bmm(q_hiddens)

        #------p_c_q--------------
        _, p_c_weighted_hiddens = self.hidden_match(p_hiddens, c_hiddens,
                                                    c_mask)
        _, p_q_weighted_hiddens = self.hidden_match(p_hiddens, q_hiddens,
                                                    q_mask)

        p_cq_cat = torch.cat(
            [p_hiddens, p_c_weighted_hiddens, p_q_weighted_hiddens], 2)
        p_cq_cat_weight, p_cq_cat_weighted_hiddens = self.hidden_match(
            p_cq_cat, p_cq_cat, p_mask)
        if self.args.dropout_att_score > 0:
            p_cq_cat_weight = nn.functional.dropout(
                p_cq_cat_weight,
                p=self.args.dropout_att_score,
                training=self.training)
        matched_p = p_cq_cat_weight.bmm(p_hiddens)

        #------c_p_q--------------
        _, c_p_weighted_hiddens = self.hidden_match(c_hiddens, p_hiddens,
                                                    p_mask)
        _, c_q_weighted_hiddens = self.hidden_match(c_hiddens, q_hiddens,
                                                    q_mask)
        concat_feature = torch.cat(
            [c_hiddens, c_q_weighted_hiddens, c_p_weighted_hiddens], 2)
        sub_feature = (c_hiddens - c_q_weighted_hiddens) * (
            c_hiddens - c_p_weighted_hiddens)
        mul_feature = c_hiddens * c_q_weighted_hiddens * c_p_weighted_hiddens

        c_mfeature = {"c": concat_feature, "s": sub_feature, "m": mul_feature}
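        # "c"/"s"/"m" select the concatenation, difference-style, and product
        # matching features; mtinfer consumes them in the order given by
        # args.matching_order.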
        init_mem = torch.zeros_like(c_hiddens)  # zero-initialized memory
        c_infer_emb, self.mem_list, self.mem_gate_list = self.mtinfer(
            c_mfeature,
            c_mask,
            init_mem=init_mem,
            x_order=self.args.matching_order)

        _, matched_q_self = self.q_self_attn(matched_q, q_mask)
        _, matched_p_self = self.q_self_attn(matched_p, p_mask)
        c_infer_hidden_ave = layers.ave_pooling(c_infer_emb, c_mask)
        c_infer_hidden_max = layers.max_pooling(c_infer_emb)

        infer_linear = self.c_infer_linear(
            torch.cat([
                c_infer_hidden_ave, c_infer_hidden_max, matched_q_self,
                matched_p_self
            ], -1))

        logits = self.logits_linear(infer_linear)
        proba = torch.sigmoid(logits.squeeze(1))

        return proba
Example #3
    def forward(self, p, p_pos, p_ner, p_mask, q, q_pos, q_ner, q_mask,
                c, c_pos, c_ner, c_mask, p_f_tensor, q_f_tensor, c_f_tensor,
                p_q_relation, p_c_relation, q_p_relation, q_c_relation,
                c_p_relation, c_q_relation, is_paint=0):

        p_rnn_input, q_rnn_input, c_rnn_input = self.add_embeddings(
            p, p_pos, p_ner, q, q_pos, q_ner, c, c_pos, c_ner,
            p_f_tensor, q_f_tensor, c_f_tensor, p_q_relation, p_c_relation,
            q_p_relation, q_c_relation, c_p_relation, c_q_relation)

        p_hiddens = self.context_rnn(p_rnn_input, p_mask)
        q_hiddens = self.context_rnn(q_rnn_input, q_mask)
        c_hiddens = self.context_rnn(c_rnn_input, c_mask)
        if self.args.dropout_rnn_output > 0:
            rnn_drop = lambda x: nn.functional.dropout(
                x, p=self.args.dropout_rnn_output, training=self.training)
            p_hiddens, q_hiddens, c_hiddens = map(
                rnn_drop, (p_hiddens, q_hiddens, c_hiddens))

        ####################################################
        self.mfunction(p_hiddens, c_hiddens, p_mask, c_mask)

        #------output-layer--------------

        # The question summary is disabled in this variant; only the matched
        # passage and choice summaries feed the output layer.
        _, matched_p_self = self.q_self_attn(self.matched_p, p_mask)
        _, matched_c_self = self.q_self_attn(self.matched_c, c_mask)
        outputs = [matched_p_self, matched_c_self]

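        # Pooled inference embeddings are appended per optional channel,
        # gated by the corresponding args flags.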
        if self.args.p_channel:
            p_infer_hidden_ave = layers.ave_pooling(self.p_infer_emb, p_mask)
            p_infer_hidden_max = layers.max_pooling(self.p_infer_emb)
            outputs.append(p_infer_hidden_ave)
            outputs.append(p_infer_hidden_max)
        # (A corresponding q_channel branch exists in the source but is
        # disabled here.)
        if self.args.c_channel:
            c_infer_hidden_ave = layers.ave_pooling(self.c_infer_emb, c_mask)
            c_infer_hidden_max = layers.max_pooling(self.c_infer_emb)
            outputs.append(c_infer_hidden_ave)
            outputs.append(c_infer_hidden_max)

        infer_linear = self.c_infer_linear(torch.cat(outputs, -1))
        logits = self.logits_linear(infer_linear)
        proba = torch.sigmoid(logits.squeeze(1))
        # observed shapes: infer_linear (44, 250), logits (44, 1), proba (44,)

        return proba
Example #4
    def forward(self, p, p_pos, p_ner, p_mask, q, q_pos, q_ner, q_mask,
                c, c_pos, c_ner, c_mask, p_f_tensor, q_f_tensor, c_f_tensor,
                p_q_relation, p_c_relation, q_p_relation, q_c_relation,
                c_p_relation, c_q_relation, is_paint=0):

        self.p, self.q, self.c = p, q, c  # retained for inspection (paint_data)
        p_emb = self.embedding(p)
        q_emb = self.embedding(q)
        c_emb = self.embedding(c)
        p_pos_emb = self.pos_embedding(p_pos)
        q_pos_emb = self.pos_embedding(q_pos)
        c_pos_emb = self.pos_embedding(c_pos)
        p_ner_emb = self.ner_embedding(p_ner)
        q_ner_emb = self.ner_embedding(q_ner)
        c_ner_emb = self.ner_embedding(c_ner)
        p_q_rel_emb = self.rel_embedding(p_q_relation)
        p_c_rel_emb = self.rel_embedding(p_c_relation)
        q_p_rel_emb = self.rel_embedding(q_p_relation)
        q_c_rel_emb = self.rel_embedding(q_c_relation)
        c_p_rel_emb = self.rel_embedding(c_p_relation)
        c_q_rel_emb = self.rel_embedding(c_q_relation)

        # Dropout on embeddings
        if self.args.dropout_emb > 0:
            emb_drop = lambda x: nn.functional.dropout(
                x, p=self.args.dropout_emb, training=self.training)
            p_emb, q_emb, c_emb = map(emb_drop, (p_emb, q_emb, c_emb))
            p_pos_emb, q_pos_emb, c_pos_emb = map(
                emb_drop, (p_pos_emb, q_pos_emb, c_pos_emb))
            p_ner_emb, q_ner_emb, c_ner_emb = map(
                emb_drop, (p_ner_emb, q_ner_emb, c_ner_emb))

        p_rnn_input = torch.cat([
            p_emb, p_pos_emb, p_ner_emb, p_f_tensor, p_q_rel_emb, p_c_rel_emb
        ], 2)
        q_rnn_input = torch.cat([
            q_emb, q_pos_emb, q_ner_emb, q_f_tensor, q_p_rel_emb, q_c_rel_emb
        ], 2)
        c_rnn_input = torch.cat([
            c_emb, c_pos_emb, c_ner_emb, c_f_tensor, c_p_rel_emb, c_q_rel_emb
        ], 2)

        p_hiddens = self.context_rnn(p_rnn_input, p_mask)
        q_hiddens = self.context_rnn(q_rnn_input, q_mask)
        c_hiddens = self.context_rnn(c_rnn_input, c_mask)

        ####################################################
        _, c_q_weighted_hiddens = self.hidden_match(c_hiddens, q_hiddens,
                                                    q_mask)
        #------q_p--------------
        _, q_p_weighted_hiddens = self.hidden_match(q_hiddens, p_hiddens,
                                                    p_mask)
        q_p_cat = torch.cat([q_hiddens, q_p_weighted_hiddens], 2)
        q_p_cat_weight, q_p_cat_weighted_hiddens = self.hidden_match(
            q_p_cat, q_p_cat, q_mask)
        matched_q = q_p_cat_weight.bmm(q_hiddens)

        #------p_c_q--------------
        _, p_c_weighted_hiddens = self.hidden_match(p_hiddens, c_hiddens,
                                                    c_mask)
        _, p_q_weighted_hiddens = self.hidden_match(p_hiddens, q_hiddens,
                                                    q_mask)

        # Product of differences between the passage and its two attended
        # views (a plain [p; p<-c; p<-q] concatenation was an earlier variant);
        # the original wrapped this single tensor in a redundant torch.cat.
        p_cq_cat = ((p_hiddens - p_c_weighted_hiddens) *
                    (p_hiddens - p_q_weighted_hiddens))

        p_cq_cat_weight, p_cq_cat_weighted_hiddens = self.hidden_match(
            p_cq_cat, p_cq_cat, p_mask)

        matched_p = p_cq_cat_weight.bmm(p_hiddens)

        self.matched_p_self_weight, matched_p_self = self.q_self_attn(
            matched_p, p_mask)
        self.c_weighted_matched_p_weight, c_weighted_matched_p = self.hidden_match(
            c_hiddens, matched_p, p_mask)
        #------c_q--------------
        concat_feature = torch.cat([c_hiddens, c_q_weighted_hiddens], 2)
        sub_feature = c_hiddens - c_q_weighted_hiddens
        # args.beta scales the product feature; unscaled and additive variants
        # were tried in the source and abandoned.
        mul_feature = self.args.beta * c_hiddens * c_q_weighted_hiddens

        c_mfeature = {"c": concat_feature, "s": sub_feature, "m": mul_feature}
        # Initial memory for mtinfer: the choice-side attention over the
        # matched passage performed best among the variants tried (zero
        # memory and an expanded matched_p_self were the alternatives).
        init_mem = c_weighted_matched_p
        if self.args.dropout_init_mem_emb > 0:
            init_mem = nn.functional.dropout(init_mem,
                                             p=self.args.dropout_init_mem_emb,
                                             training=self.training)
        c_infer_emb, self.mem_list, self.mem_gate_list = self.mtinfer(
            c_mfeature,
            c_mask,
            init_mem=init_mem,
            x_order=self.args.matching_order)
        self.c_infer_emb = c_infer_emb
        self.matched_q_self_weight, matched_q_self = self.q_self_attn(
            matched_q, q_mask)
        c_infer_hidden_ave = layers.ave_pooling(c_infer_emb, c_mask)
        c_infer_hidden_max = layers.max_pooling(c_infer_emb)

        infer_linear = self.c_infer_linear(
            torch.cat([
                c_infer_hidden_ave, c_infer_hidden_max, matched_p_self,
                matched_q_self
            ], -1))

        logits = self.logits_linear(infer_linear)
        proba = torch.sigmoid(logits.squeeze(1))

        if is_paint == 1:
            self.paint_data()

        return proba
Example #5
    def forward(self, p, p_pos, p_ner, p_mask, q, q_pos, q_ner, q_mask,
                c, c_pos, c_ner, c_mask, p_f_tensor, q_f_tensor, c_f_tensor,
                p_q_relation, p_c_relation, q_p_relation, q_c_relation,
                c_p_relation, c_q_relation, is_paint=0):

        p_rnn_input, q_rnn_input, c_rnn_input = self.add_embeddings(
            p, p_pos, p_ner, q, q_pos, q_ner, c, c_pos, c_ner,
            p_f_tensor, q_f_tensor, c_f_tensor, p_q_relation, p_c_relation,
            q_p_relation, q_c_relation, c_p_relation, c_q_relation)

        p_hiddens = self.context_rnn(p_rnn_input, p_mask)
        q_hiddens = self.context_rnn(q_rnn_input, q_mask)
        c_hiddens = self.context_rnn(c_rnn_input, c_mask)
        if self.args.dropout_rnn_output > 0:
            rnn_drop = lambda x: nn.functional.dropout(
                x, p=self.args.dropout_rnn_output, training=self.training)
            p_hiddens, q_hiddens, c_hiddens = map(
                rnn_drop, (p_hiddens, q_hiddens, c_hiddens))

        ####################################################
        #--------------naive attention
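        # First-order ("naive") attention in all six directions among the
        # passage, question, and choice hiddens.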
        _, p_q_weighted_hiddens = self.hidden_match(p_hiddens, q_hiddens,
                                                    q_mask)
        _, p_c_weighted_hiddens = self.hidden_match(p_hiddens, c_hiddens,
                                                    c_mask)

        _, q_p_weighted_hiddens = self.hidden_match(q_hiddens, p_hiddens,
                                                    p_mask)
        _, q_c_weighted_hiddens = self.hidden_match(q_hiddens, c_hiddens,
                                                    c_mask)

        _, c_p_weighted_hiddens = self.hidden_match(c_hiddens, p_hiddens,
                                                    p_mask)
        _, c_q_weighted_hiddens = self.hidden_match(c_hiddens, q_hiddens,
                                                    q_mask)

        #--------------compound attention
        # Second-order attention over already-attended views, unpacked like
        # the first-order calls above; note these six tensors are not consumed
        # by the code below.
        _, c_q_p_weighted_hiddens = self.hidden_match(
            c_hiddens, q_p_weighted_hiddens, q_mask)
        _, q_c_p_weighted_hiddens = self.hidden_match(
            q_hiddens, c_p_weighted_hiddens, c_mask)

        _, p_c_q_weighted_hiddens = self.hidden_match(
            p_hiddens, c_q_weighted_hiddens, c_mask)
        _, c_p_q_weighted_hiddens = self.hidden_match(
            c_hiddens, p_q_weighted_hiddens, p_mask)

        _, p_q_c_weighted_hiddens = self.hidden_match(
            p_hiddens, q_c_weighted_hiddens, q_mask)
        _, q_p_c_weighted_hiddens = self.hidden_match(
            q_hiddens, p_c_weighted_hiddens, p_mask)

        #------p_c_q--------------
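        # Each sequence is matched against the other two and aggregated by the
        # memory-based mtinfer, yielding per-token inference embeddings plus
        # the memory trace and gates.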
        p_infer_emb, p_mems, p_mem_gates = self.tri_matching(
            x=p_hiddens,
            x_y=p_q_weighted_hiddens,
            x_z=p_c_weighted_hiddens,
            agg_function=self.mtinfer,
            x_mask=p_mask)

        #------q_p_c--------------
        q_infer_emb, q_mems, q_mem_gates = self.tri_matching(
            x=q_hiddens,
            x_y=q_p_weighted_hiddens,
            x_z=q_c_weighted_hiddens,
            agg_function=self.mtinfer,
            x_mask=q_mask)

        #------c_p_q--------------
        c_infer_emb, c_mems, c_mem_gates = self.tri_matching(
            x=c_hiddens,
            x_y=c_p_weighted_hiddens,
            x_z=c_q_weighted_hiddens,
            agg_function=self.mtinfer,
            x_mask=c_mask)

        #------matched_self--------------
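        # Self-matched representations of each sequence, built from its raw
        # hiddens and the two attended views (cf. matched_q in Example #2).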
        matched_p = self.matched_self(x=p_hiddens,
                                      x_y=p_q_weighted_hiddens,
                                      x_z=p_c_weighted_hiddens,
                                      x_mask=p_mask)
        matched_q = self.matched_self(x=q_hiddens,
                                      x_y=q_p_weighted_hiddens,
                                      x_z=q_c_weighted_hiddens,
                                      x_mask=q_mask)
        matched_c = self.matched_self(x=c_hiddens,
                                      x_y=c_p_weighted_hiddens,
                                      x_z=c_q_weighted_hiddens,
                                      x_mask=c_mask)

        #------output-layer--------------
        _, matched_q_self = self.q_self_attn(matched_q, q_mask)
        _, matched_p_self = self.q_self_attn(matched_p, p_mask)
        _, matched_c_self = self.q_self_attn(matched_c, c_mask)

        p_infer_hidden_ave = layers.ave_pooling(p_infer_emb, p_mask)
        p_infer_hidden_max = layers.max_pooling(p_infer_emb)

        q_infer_hidden_ave = layers.ave_pooling(q_infer_emb, q_mask)
        q_infer_hidden_max = layers.max_pooling(q_infer_emb)

        c_infer_hidden_ave = layers.ave_pooling(c_infer_emb, c_mask)
        c_infer_hidden_max = layers.max_pooling(c_infer_emb)

        infer_linear = self.c_infer_linear(torch.cat([
            p_infer_hidden_ave, p_infer_hidden_max,
            q_infer_hidden_ave, q_infer_hidden_max,
            c_infer_hidden_ave, c_infer_hidden_max,
            matched_q_self, matched_p_self, matched_c_self
        ], -1))

        logits = self.logits_linear(infer_linear)
        proba = torch.sigmoid(logits.squeeze(1))

        return proba
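None of the examples show how these forward passes are driven, so the following is a hypothetical smoke test. The model construction, vocabulary sizes, relation-type count, feature width, and the mask convention (here: 1 marks padding) are all illustrative assumptions, not details taken from the source.

import torch

def fake_inputs(batch, length, n_feats=4):
    # Token/POS/NER ids, a padding mask, and hand-crafted features for one
    # sequence; every size here is made up for the sketch.
    ids = torch.randint(1, 100, (batch, length))
    pos = torch.randint(0, 10, (batch, length))
    ner = torch.randint(0, 10, (batch, length))
    mask = torch.zeros(batch, length, dtype=torch.bool)  # assumed: 1 = padding
    feats = torch.rand(batch, length, n_feats)
    return ids, pos, ner, mask, feats

batch = 4
p, p_pos, p_ner, p_mask, p_f = fake_inputs(batch, 50)  # passage
q, q_pos, q_ner, q_mask, q_f = fake_inputs(batch, 12)  # question
c, c_pos, c_ner, c_mask, c_f = fake_inputs(batch, 6)   # candidate answer

def rel(x):
    # Per-token relation ids; 5 relation types is an arbitrary choice.
    return torch.randint(0, 5, tuple(x.shape))

# `model` is any of the modules above, constructed elsewhere with matching
# embedding and feature sizes.
proba = model(p, p_pos, p_ner, p_mask, q, q_pos, q_ner, q_mask,
              c, c_pos, c_ner, c_mask, p_f, q_f, c_f,
              rel(p), rel(p), rel(q), rel(q), rel(c), rel(c))
assert proba.shape == (batch,)  # one correctness probability per candidate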