Exemplo n.º 1
0
 def getEmb(self, indexes, src_enc):
     """Gather head and dependent encoder embeddings for every sentence in the batch.

     Args:
         indexes: one index tensor per batch item. Entry ``j`` is the column
             (dim 1) of that sentence's encoder output to use for node ``j``.
             Per the original notes, this maps gold AMR nodes to recategorized
             AMR node positions.
         src_enc: double-packed encoder output. After ``doubleunpack`` it
             yields one tensor per sentence, assumed to be
             (g_amr_len x re_amr_len x dim) -- TODO confirm against caller.

     Returns:
         ``(head_packed, dep_double_packed, length_pairs)`` where:

         * ``head_packed = mypack(head_embs, lengths)``, and row ``j`` of a
           sentence's head embedding is ``enc[j, index[j], :]``;
         * ``dep_double_packed = mydoublepack(dep_embs, length_pairs)``, and a
           sentence's dependent embedding is ``enc[:, index, :]``;
         * ``length_pairs[i] == [len(indexes[i]), len(indexes[i])]``.
     """
     per_sentence_enc, _ = doubleunpack(src_enc)
     heads, deps, pairs = [], [], []
     for sent_idx, node_index in enumerate(indexes):
         enc = per_sentence_enc[sent_idx]
         n_rows, feat_dim = enc.size(0), enc.size(-1)
         # Dependent view: keep every row and select the mapped columns on dim 1.
         deps.append(enc.index_select(1, node_index))
         # Head view: row j takes column node_index[j]. gather needs an index
         # with the same rank as enc, hence [n] -> [n, 1, 1] -> [n, 1, dim].
         gather_idx = node_index.unsqueeze(1).unsqueeze(2).expand(n_rows, 1, feat_dim)
         heads.append(enc.gather(1, gather_idx).squeeze(1))
         n_nodes = len(node_index)
         pairs.append([n_nodes, n_nodes])
     head_lengths = [pair[0] for pair in pairs]
     return mypack(heads, head_lengths), mydoublepack(deps, pairs), pairs
Exemplo n.º 2
0
    def forward(self, input, index, src_enc):
        """Build PSD-feature plus encoder embeddings and project them for relation scoring.

        Args:
            input: MyPackedSequence of relation-node features. It unpacks to
                ``(features, lengths)``, and ``features[:, PSD_POS]`` is looked
                up in ``self.psd_target_pos_lut``. Assumed to be
                (total_nodes x n_features) -- TODO confirm.
            index: per-sentence index tensors, forwarded to ``self.getEmb``.
            src_enc: double-packed encoder output, forwarded to ``self.getEmb``.

        Returns:
            ``(head_psd_packed, head_packed, dep_packed)``:

            * ``head_psd_packed``: the PSD embeddings packed with ``lengths``;
            * ``head_packed``: the ``self.head`` projection of
              [psd_emb | head encoder emb], packed with ``lengths``;
            * ``dep_packed``: the ``self.dep`` projection of
              [psd grid | dep encoder emb], wrapped in a MyDoublePackedSequence
              that reuses the packing of the PSD dependent embeddings.
        """
        assert isinstance(input, MyPackedSequence), input
        features, lengths = input
        if self.alpha_dropout and self.training:
            features = data_dropout(features, self.alpha_dropout)

        # Only the PSD POS embedding feeds psd_emb. torch.cat over a single
        # tensor yields a copy of it with the same shape.
        psd_emb = torch.cat([self.psd_target_pos_lut(features[:, PSD_POS])], 1)

        head_enc, dep_enc, length_pairs = self.getEmb(index, src_enc)

        head_emb = torch.cat([psd_emb, head_enc.data], 1)

        # Repeat each sentence's (n x dim) node embeddings into an (n x n x dim)
        # grid so they can be concatenated with the dependent encoder
        # embeddings along the last dim.
        psd_square = [
            node_emb.unsqueeze(0).expand(
                node_emb.size(0), node_emb.size(0), node_emb.size(-1)
            )
            for node_emb in myunpack(*MyPackedSequence(psd_emb, lengths))
        ]
        psd_double = mydoublepack(psd_square, length_pairs)

        dep_emb = torch.cat([psd_double.data, dep_enc.data], -1)

        assert head_emb.size(-1) == self.inputSize, "wrong head  size {}".format(
            head_emb.size()
        )
        head_packed = MyPackedSequence(self.head(head_emb), lengths)
        head_psd_packed = MyPackedSequence(psd_emb, lengths)

        dep_shape = dep_emb.size()
        assert dep_emb.size(-1) == self.inputSize, "wrong dep size {}".format(
            dep_emb.size()
        )
        # Apply self.dep on a flattened 2-D view, then restore the two leading dims.
        flat_dep = self.dep(dep_emb.view(-1, dep_shape[-1]))
        dep = flat_dep.view(dep_shape[0], dep_shape[1], -1)

        dep_packed = MyDoublePackedSequence(
            MyPackedSequence(dep, psd_double[0][1]), psd_double[1], dep
        )
        return head_psd_packed, head_packed, dep_packed
Exemplo n.º 3
0
 def getEmb(self, indexes, src_enc):
     """Collect per-sentence head and dependent embeddings from the encoder output.

     For sentence ``i``, with encoder tensor ``enc`` and index tensor ``idx``:

     * the dependent embedding is ``enc.index_select(1, idx)``, i.e. ``enc[:, idx, :]``;
     * row ``j`` of the head embedding is ``enc[j, idx[j], :]``, taken via
       ``gather`` on dim 1.

     Returns:
         ``(mypack(heads, lens), mydoublepack(deps, pairs), pairs)``, where
         ``pairs[i] == [len(idx), len(idx)]`` and ``lens[i] == len(idx)``.
     """
     unpacked, _ = doubleunpack(src_enc)  # one encoder tensor per sentence
     heads, deps, pairs = [], [], []
     for i, idx in enumerate(indexes):
         enc = unpacked[i]
         deps.append(enc.index_select(1, idx))
         # gather needs an index of the same rank as enc: [n] -> [n, 1, dim]
         expanded = idx.unsqueeze(1).unsqueeze(2).expand(enc.size(0), 1, enc.size(-1))
         heads.append(enc.gather(1, expanded).squeeze(1))
         pairs.append([len(idx)] * 2)
     return (
         mypack(heads, [first for first, _ in pairs]),
         mydoublepack(deps, pairs),
         pairs,
     )
Exemplo n.º 4
0
    def forward(self, input, index, src_enc):
        """Combine AMR category/lemma embeddings with encoder features for relations.

        Args:
            input: MyPackedSequence of node features. It unpacks to
                ``(features, lengths)``; columns ``AMR_CAT`` and ``AMR_LE`` are
                looked up in ``self.cat_lut`` and ``self.lemma_lut``.
            index: per-sentence index tensors, forwarded to ``self.getEmb``.
            src_enc: double-packed encoder output, forwarded to ``self.getEmb``.

        Returns:
            ``(head_amr_packed, head_packed, dep_packed)``:

            * ``head_amr_packed``: the [cat | lemma] embeddings packed with ``lengths``;
            * ``head_packed``: the ``self.head`` projection of
              [amr_emb | head encoder emb], packed with ``lengths``;
            * ``dep_packed``: the ``self.dep`` projection of
              [amr grid | dep encoder emb], wrapped in a MyDoublePackedSequence
              that reuses the packing of the AMR dependent embeddings.
        """
        assert isinstance(input, MyPackedSequence), input
        feats, lengths = input
        if self.alpha and self.training:
            feats = data_dropout(feats, self.alpha)
        amr_emb = torch.cat(
            [self.cat_lut(feats[:, AMR_CAT]), self.lemma_lut(feats[:, AMR_LE])], 1
        )

        head_enc, dep_enc, length_pairs = self.getEmb(index, src_enc)
        head_emb = torch.cat([amr_emb, head_enc.data], 1)

        # Repeat each sentence's node embeddings into an (n x n x dim) grid so
        # they can be concatenated with the dependent encoder embeddings.
        square = []
        for node_emb in myunpack(*MyPackedSequence(amr_emb, lengths)):
            n = node_emb.size(0)
            square.append(node_emb.unsqueeze(0).expand(n, n, node_emb.size(-1)))
        amr_double = mydoublepack(square, length_pairs)

        dep_emb = torch.cat([amr_double.data, dep_enc.data], -1)

        head_packed = MyPackedSequence(self.head(head_emb), lengths)
        head_amr_packed = MyPackedSequence(amr_emb, lengths)

        # Apply self.dep on a flattened 2-D view, then restore the two leading dims.
        dims = dep_emb.size()
        dep = self.dep(dep_emb.view(-1, dims[-1])).view(dims[0], dims[1], -1)

        dep_packed = MyDoublePackedSequence(
            MyPackedSequence(dep, amr_double[0][1]), amr_double[1], dep
        )
        return head_amr_packed, head_packed, dep_packed