def getEmb(self, indexes, src_enc):
    """Select head and dependent embeddings for every sample in the batch.

    For each sample ``i`` the unpacked encoding ``enc`` is indexed along
    dim 1 with ``indexes[i]`` in two ways:

    * dependent side: ``enc.index_select(1, index)`` keeps, for every row,
      the columns listed in ``index``;
    * head side: row ``k`` of the result is ``enc[k, index[k], :]``
      (implemented with ``gather``), which requires
      ``len(index) == enc.size(0)``.

    Args:
        indexes: one 1-D index tensor per sample.  Per the original
            author's notes these map gold AMR nodes to positions in the
            recategorized AMR -- TODO confirm against the caller.
        src_enc: double-packed encoding accepted by ``doubleunpack``.
            Assumed to unpack to one ``(rows, cols, dim)`` tensor per
            sample -- TODO confirm.

    Returns:
        A 3-tuple ``(heads, deps, length_pairs)`` where ``heads`` is
        ``mypack`` of the head embeddings, ``deps`` is ``mydoublepack`` of
        the dependent embeddings, and ``length_pairs`` is a list of
        ``[len(index), len(index)]`` per sample.
    """
    # Only the unpacked tensors are needed; the second return value is ignored.
    per_sample_enc, _ = doubleunpack(src_enc)

    head_rows = []
    dep_blocks = []
    head_lengths = []
    length_pairs = []

    for sample_pos, node_index in enumerate(indexes):
        sample_enc = per_sample_enc[sample_pos]

        # Dependent side: keep the selected columns for every row.
        dep_blocks.append(sample_enc.index_select(1, node_index))

        # Head side: broadcast the index to (rows, 1, dim) so that gather
        # picks exactly one column per row, then drop the singleton axis.
        n_rows = sample_enc.size(0)
        feat_dim = sample_enc.size(-1)
        gather_index = node_index.unsqueeze(1).unsqueeze(2).expand(
            n_rows, 1, feat_dim)
        picked = sample_enc.gather(1, gather_index)
        head_rows.append(picked.squeeze(1))

        n_nodes = len(node_index)
        head_lengths.append(n_nodes)
        length_pairs.append([n_nodes, n_nodes])

    packed_heads = mypack(head_rows, head_lengths)
    packed_deps = mydoublepack(dep_blocks, length_pairs)
    return packed_heads, packed_deps, length_pairs
def forward(self, _, heads, deps):
    """Score head/dependent pairs and return the packed scores.

    The head data goes through ``self.bilinear`` and ``self.head_bias``,
    the dependent data through ``self.dep_bias``; all four pieces
    (bilinear-transformed heads, dependents, head bias, dependent bias) are
    unpacked per sample and handed to ``self.bilinearForParallel``, whose
    output is re-packed with ``mypack``.

    Args:
        _: unused.
        heads: packed head embeddings exposing ``.data`` and ``.lengths``.
            Original note: ``heads.data`` is ``amr_l x rel_dim`` -- TODO
            confirm.
        deps: double-packed dependent embeddings exposing ``.data`` and
            positional fields ``deps[0][1]`` / ``deps[1]``.  Original note:
            ``deps.data`` is ``amr_l x amr_l x rel_dim`` -- TODO confirm.

    Returns:
        ``mypack(output, l)`` built from the result of
        ``self.bilinearForParallel``.
    """
    head_tensor = heads.data
    dep_tensor = deps.data

    # Head side: bilinear projection first, then the head bias term.
    projected_heads = self.bilinear(head_tensor)
    head_bias_list = myunpack(self.head_bias(head_tensor), heads.lengths)

    # Dependent side: flatten to 2-D for the bias layer, then restore the
    # two leading dimensions.
    shape = dep_tensor.size()
    flat_deps = dep_tensor.view(-1, shape[-1])
    dep_bias = self.dep_bias(flat_deps).view(shape[0], shape[1], -1)

    # Re-wrap the bias tensor with the metadata taken from ``deps`` so the
    # same unpack helper can split it per sample.
    bias_inner = MyPackedSequence(dep_bias, deps[0][1])
    bias_packed = MyDoublePackedSequence(bias_inner, deps[1], dep_bias)
    dep_bias_list, _bias_pairs = mydoubleunpack(bias_packed)

    projected_list = myunpack(projected_heads, heads.lengths)
    dep_list, length_pairs = mydoubleunpack(deps)

    per_sample_inputs = zip(projected_list, dep_list, head_bias_list,
                            dep_bias_list)
    output, l = self.bilinearForParallel(per_sample_inputs, length_pairs)
    return mypack(output, l)
def getEmb(self, indexes, src_enc):
    """Pick the indexed rows of each sample's encoding and re-pack them.

    Args:
        indexes: one index object per sample; ``indexes[i]`` is used to
            index the i-th unpacked encoding and its ``len`` becomes the
            packed length for that sample.
        src_enc: packed encoding; it is star-expanded into ``myunpack``.
            Presumably unpacks to one ``(src_l, dim)`` tensor per sample --
            TODO confirm.

    Returns:
        ``mypack`` of the selected rows together with the per-sample
        lengths.
    """
    per_sample_enc = myunpack(*src_enc)

    # One pass over the batch: (selected rows, number of selected rows).
    selections = [
        (per_sample_enc[pos][node_index], len(node_index))
        for pos, node_index in enumerate(indexes)
    ]
    head_emb = [rows for rows, _ in selections]
    lengths = [count for _, count in selections]
    return mypack(head_emb, lengths)
def getEmb(self, indexes, src_enc):
    """Build packed head embeddings and double-packed dependent embeddings.

    For every sample the unpacked encoding is indexed along dim 1 with the
    sample's index tensor: ``index_select`` yields the dependent block,
    while ``gather`` yields one entry per row (``enc[k, index[k], :]``)
    for the head side.

    Args:
        indexes: one 1-D index tensor per sample.
        src_enc: double-packed encoding accepted by ``doubleunpack``.
            Original note: unpacks to ``amr_l x src_l x dim`` per sample --
            TODO confirm.

    Returns:
        ``(mypack(heads, lengths), mydoublepack(deps, length_pairs),
        length_pairs)`` where ``length_pairs[i]`` is
        ``[len(indexes[i]), len(indexes[i])]``.
    """

    def _select(sample_enc, node_index):
        # Dependent block: the chosen columns for every row.
        dep_block = sample_enc.index_select(1, node_index)
        # Head rows: expand the index to (rows, 1, dim) so gather picks a
        # single column per row, then remove the singleton axis.
        gather_index = node_index.unsqueeze(1).unsqueeze(2).expand(
            sample_enc.size(0), 1, sample_enc.size(-1))
        head_rows = sample_enc.gather(1, gather_index).squeeze(1)
        n_nodes = len(node_index)
        return head_rows, dep_block, [n_nodes, n_nodes]

    # The second value returned by doubleunpack is not used here.
    per_sample_enc, _ = doubleunpack(src_enc)

    selected = [
        _select(per_sample_enc[pos], node_index)
        for pos, node_index in enumerate(indexes)
    ]
    head_emb = [item[0] for item in selected]
    dep_emb = [item[1] for item in selected]
    length_pairs = [item[2] for item in selected]

    head_lengths = [pair[0] for pair in length_pairs]
    return (mypack(head_emb, head_lengths),
            mydoublepack(dep_emb, length_pairs),
            length_pairs)
def getEmb(self, indexes, src_enc):
    """Re-index each sample's encoding and pack the result.

    Args:
        indexes: one index object per sample.  Per the original author's
            notes its entries are positions in the recategorized AMR, one
            per gold AMR node -- TODO confirm against the caller.
        src_enc: packed encoding that is star-expanded into ``myunpack``.
            Original note: ``MyPackedSequence`` whose data is
            ``packed_re_amr_len x txt_enc_size`` -- TODO confirm.

    Returns:
        ``mypack(selected, counts)`` where ``selected[i]`` is the i-th
        unpacked encoding indexed by ``indexes[i]`` and ``counts[i]`` is
        ``len(indexes[i])``.
    """
    per_sample_enc = myunpack(*src_enc)

    selected = []
    counts = []
    for batch_pos, node_index in enumerate(indexes):
        sample_enc = per_sample_enc[batch_pos]
        # Indexing with node_index keeps only the referenced rows.
        selected.append(sample_enc[node_index])
        counts.append(len(node_index))

    return mypack(selected, counts)