def forward(self, inputs, hidden=None):
    if isinstance(inputs, PackedSequence):
        bsizes = inputs.batch_sizes
        max_batch = int(bsizes[0])
        emb = PackedSequence(self.embedding_dropout(
            self.embedder(inputs.data)), bsizes)
        # Get padding mask
        time_dim = 1 if self.batch_first else 0
        range_batch = torch.arange(0, max_batch,
                                   dtype=bsizes.dtype, device=bsizes.device)
        range_batch = range_batch.unsqueeze(time_dim)
        bsizes = bsizes.unsqueeze(1 - time_dim)
        padding_mask = (bsizes - range_batch).le(0)
    else:
        padding_mask = inputs.eq(PAD)
        emb = self.embedding_dropout(self.embedder(inputs))
    outputs, hidden_t = self.rnn(emb, hidden)
    if isinstance(inputs, PackedSequence):
        outputs = unpack(outputs)[0]
    outputs = self.dropout(outputs)
    if hasattr(self, 'context_transform'):
        context = self.context_transform(outputs)
    else:
        context = None
    if hasattr(self, 'hidden_transform'):
        hidden_t = self.hidden_transform(hidden_t)
    state = State(outputs=outputs, hidden=hidden_t, context=context,
                  mask=padding_mask, batch_first=self.batch_first)
    return state

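# A minimal, self-contained check of the batch_sizes -> padding-mask trick used
# above (assuming batch_first=False, i.e. time_dim == 0). batch_sizes[t] is how
# many sequences are still active at step t, so position (t, b) is padding iff
# b >= batch_sizes[t]. The tensors below are illustrative, not from the original code.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

x = torch.randn(4, 3, 2)                       # time x batch x features
lengths = [4, 2, 1]
packed = pack_padded_sequence(x, lengths)
bsizes = packed.batch_sizes                    # tensor([3, 2, 1, 1])
max_batch = int(bsizes[0])

range_batch = torch.arange(0, max_batch, dtype=bsizes.dtype).unsqueeze(0)  # 1 x batch
padding_mask = (bsizes.unsqueeze(1) - range_batch).le(0)                   # time x batch
# padding_mask[t, b] is True where sequence b is already finished at step t.
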
def forward(self, inputs, lengths, hidden=None):
    lens, indices = torch.sort(inputs.data.new(lengths).long(), 0, True)
    inputs = inputs[indices] if self.batch_first else inputs[:, indices]
    outputs, (h, c) = self.rnn(pack(inputs, lens.tolist(),
                                    batch_first=self.batch_first), hidden)
    outputs = unpack(outputs, batch_first=self.batch_first)[0]
    _, _indices = torch.sort(indices, 0)
    outputs = outputs[_indices] if self.batch_first else outputs[:, _indices]
    # Restore the original batch order of both hidden and cell states.
    h, c = h[:, _indices, :], c[:, _indices, :]
    return outputs, (h, c)

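# A minimal standalone sketch of the sort -> pack -> RNN -> unpack -> unsort
# round trip used above. The module and sizes (toy_lstm, 8/16) are illustrative
# assumptions, not taken from the original code.
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as unpack

toy_lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(4, 10, 8)                      # batch x time x features
lengths = torch.tensor([10, 7, 5, 3])

lens, indices = torch.sort(lengths, 0, descending=True)   # sort by length
packed = pack(x[indices], lens.tolist(), batch_first=True)
out_packed, (h, c) = toy_lstm(packed)
out, _ = unpack(out_packed, batch_first=True)

_, rev = torch.sort(indices, 0)                # invert the permutation
out = out[rev]                                 # restore original batch order
h, c = h[:, rev, :], c[:, rev, :]              # states are layers x batch x dim
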
def forward(self, enc_input, hidden=None):
    if isinstance(enc_input, tuple):
        # Lengths data is wrapped inside a Variable.
        lengths = enc_input[1].data.view(-1).tolist()
        emb = pack(self.embedding(enc_input[0]), lengths)
    else:
        emb = self.embedding(enc_input)
    outputs, hidden_t = self.rnn(emb, hidden)
    if isinstance(enc_input, tuple):
        outputs = unpack(outputs)[0]
    return outputs, hidden_t

def forward(self, inputs, state):
    context, hidden = state.context, state.hidden
    if isinstance(inputs, PackedSequence):
        # PackedSequence stores its step sizes in `batch_sizes`.
        emb = PackedSequence(self.embedding_dropout(
            self.embedder(inputs.data)), inputs.batch_sizes)
    else:
        emb = self.embedding_dropout(self.embedder(inputs))
    x, hidden_t = self.rnn(emb, hidden)
    if isinstance(inputs, PackedSequence):
        x = unpack(x)[0]
    x = self.dropout(x)
    x = self.classifier(x)
    return x, State(hidden=hidden_t, context=context,
                    batch_first=self.batch_first)

def forward(self, src, lengths=None): """See :func:`onmt.encoders.encoder.EncoderBase.forward()`""" batch_size, _, nfft, t = src.size() src = src.transpose(0, 1).transpose(0, 3).contiguous() \ .view(t, batch_size, nfft) orig_lengths = lengths lengths = lengths.view(-1).tolist() for l in range(self.enc_layers): rnn = getattr(self, 'rnn_%d' % l) pool = getattr(self, 'pool_%d' % l) batchnorm = getattr(self, 'batchnorm_%d' % l) stride = self.enc_pooling[l] packed_emb = pack(src, lengths) memory_bank, tmp = rnn(packed_emb) memory_bank = unpack(memory_bank)[0] t, _, _ = memory_bank.size() memory_bank = memory_bank.transpose(0, 2) memory_bank = pool(memory_bank) lengths = [int(math.floor((length - stride) / stride + 1)) for length in lengths] memory_bank = memory_bank.transpose(0, 2) src = memory_bank t, _, num_feat = src.size() src = batchnorm(src.contiguous().view(-1, num_feat)) src = src.view(t, -1, num_feat) if self.dropout and l + 1 != self.enc_layers: src = self.dropout(src) memory_bank = memory_bank.contiguous().view(-1, memory_bank.size(2)) memory_bank = self.W(memory_bank).view(-1, batch_size, self.dec_rnn_size) state = memory_bank.new_full((self.dec_layers * self.num_directions, batch_size, self.dec_rnn_size_real), 0) if self.rnn_type == 'LSTM': # The encoder hidden is (layers*directions) x batch x dim. encoder_final = (state, state) else: encoder_final = state return encoder_final, memory_bank, orig_lengths.new_tensor(lengths)
def forward(self, src, lengths=None, encoder_state=None):
    "See :obj:`EncoderBase.forward()`"
    self._check_args(src, lengths, encoder_state)
    emb = self.embeddings(src)
    s_len, batch, emb_dim = emb.size()
    packed_emb = emb
    if lengths is not None and not self.no_pack_padded_seq:
        # Lengths data is wrapped inside a Variable.
        lengths = lengths.view(-1).tolist()
        packed_emb = pack(emb, lengths)
    memory_bank, encoder_final = self.rnn(packed_emb, encoder_state)
    if lengths is not None and not self.no_pack_padded_seq:
        memory_bank = unpack(memory_bank)[0]
    if self.use_bridge:
        encoder_final = self._bridge(encoder_final)
    return encoder_final, memory_bank

def forward(self, lang, src, lengths=None):
    """See :func:`EncoderBase.forward()`"""
    self._check_args(src, lengths)
    emb = self.embeddings[lang](src)
    # s_len, batch, emb_dim = emb.size()
    packed_emb = emb
    if lengths is not None and not self.no_pack_padded_seq:
        # Lengths data is wrapped inside a Tensor.
        lengths_list = lengths.view(-1).tolist()
        packed_emb = pack(emb, lengths_list)
    memory_bank, encoder_final = self.rnn(packed_emb)
    if lengths is not None and not self.no_pack_padded_seq:
        memory_bank = unpack(memory_bank)[0]
    if self.use_bridge:
        encoder_final = self._bridge(encoder_final)
    return encoder_final, memory_bank, lengths

def forward(self, input, lengths=None, hidden=None, FLAG=True):
    "See :obj:`EncoderBase.forward()`"
    self._check_args(input, lengths, hidden)
    if FLAG:
        emb = self.embeddings(input)
        s_len, batch, emb_dim = emb.size()
    else:
        emb = input
    packed_emb = emb
    if lengths is not None and not self.no_pack_padded_seq:
        # Lengths data is wrapped inside a Variable.
        lengths = lengths.view(-1).tolist()
        packed_emb = pack(emb, lengths)
    outputs, hidden_t = self.rnn(packed_emb, hidden)
    if lengths is not None and not self.no_pack_padded_seq:
        outputs = unpack(outputs)[0]
    return hidden_t, outputs

def forward(self, emb):
    '''
    Encodes the source text.
    - No need to handle individual time-steps: the whole sequence is passed in
      at once and the whole output is returned at once.
    - Sequences containing PAD are handled with pack/unpack so they can be
      processed efficiently in parallel.
    '''
    if isinstance(emb, tuple):
        # x = (batch_size, length, word_vec_size)
        # lengths = (batch_size) - length of each sentence (excluding PAD)
        x, lengths = emb
        x = pack(x, lengths.tolist(), batch_first=True)
    else:
        x = emb
    # y = (batch_size, length, hidden_size)
    # h[0] = (num_layers * 2, batch_size, hidden_size / 2)
    y, h = self.rnn(x)
    if isinstance(emb, tuple):
        y, _ = unpack(y, batch_first=True)
    return y, h

def _get_lstm_features(self, token_ids, lengths):
    # |token_ids| = [batch_size, token_length]
    # |lengths| = [batch_size]
    embeds = self.word_embeds(token_ids)
    # |embeds| = [batch_size, token_length, hidden_dim]
    packed_embeds = pack(embeds, lengths=lengths.tolist(),
                         batch_first=True, enforce_sorted=False)
    # Apply the RNN and get the hidden states of each word.
    last_hiddens, _ = self.rnn(packed_embeds)
    # Unpack the output of the RNN.
    last_hiddens, _ = unpack(last_hiddens, batch_first=True)
    # |last_hiddens| = [batch_size, max(token_length), hidden_size]
    lstm_feats = self.hidden2tag(self.tanh(last_hiddens))
    return lstm_feats

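# A small sketch of the same packing call in isolation. With
# enforce_sorted=False (available since PyTorch 1.1), pack_padded_sequence
# sorts the batch internally and pad_packed_sequence restores the original
# order, so no manual sort/unsort bookkeeping is needed. The module and sizes
# below are illustrative assumptions.
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as unpack

rnn = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
embeds = torch.randn(3, 12, 32)                # batch x token_length x dim
lengths = torch.tensor([12, 4, 9])             # may be unsorted

packed = pack(embeds, lengths.tolist(), batch_first=True, enforce_sorted=False)
hiddens, _ = rnn(packed)
hiddens, out_lengths = unpack(hiddens, batch_first=True)
assert hiddens.shape == (3, 12, 64) and out_lengths.tolist() == [12, 4, 9]
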
def forward(self, input, lengths, cnn=True):
    embs = pack(self.embedding(input), lengths)
    outputs, state = self.rnn(embs)
    outputs = unpack(outputs)[0]
    if not self.config.bidirec:
        return outputs, state
    else:
        outputs = outputs[:, :, :self.hidden_size] \
            + outputs[:, :, self.hidden_size:]
        state = (state[0][1::2], state[1][1::2])
        o_ = outputs
        if cnn:
            outputs = outputs.transpose(0, 1).transpose(1, 2)
            outputs = self.selu1(self.conv1(outputs))
            outputs = self.selu2(self.conv2(outputs))
            outputs = self.selu3(self.conv3(outputs))
            conv = outputs.transpose(1, 2).transpose(0, 1)
            # outputs = self.sigmoid(outputs) * o_
            outputs = o_
        return outputs, conv, state

def forward(self, inputs, lengths, hidden=None): """A pretrained MT-LSTM (McCann et. al. 2017). This LSTM was trained with 300d 840B GloVe on the WMT 2017 machine translation dataset. Arguments: inputs (Tensor): If MTLSTM handles embedding, a Long Tensor of size (batch_size, timesteps). Otherwise, a Float Tensor of size (batch_size, timesteps, features). lengths (Long Tensor): (batch_size, lengths) lenghts of each sequence for handling padding hidden (Float Tensor): initial hidden state of the LSTM """ if self.embed: inputs = self.vectors(inputs) lens, indices = torch.sort(lengths, 0, True) outputs, hidden_t = self.rnn( pack(inputs[indices], lens.tolist(), batch_first=True), hidden) outputs = unpack(outputs, batch_first=True)[0] _, _indices = torch.sort(indices, 0) outputs = outputs[_indices] if self.residual_embeddings: outputs = torch.cat([inputs, outputs], 2) return outputs
def forward(self, tokens_emb, length):
    batch_size = tokens_emb.size(0)
    tokens_emb = pack(tokens_emb, length, batch_first=True)
    outputs, states_t = self.rnn(tokens_emb)
    reps, _ = unpack(outputs, batch_first=True)
    # print 'reps', reps
    size = reps.size()
    compressed_reps = reps.contiguous().view(
        -1, size[2])  # (batch_size x seq_len) * mem_size
    hbar = self.tanh(
        self.ws1(compressed_reps))  # (batch_size x seq_len) * attn_size
    alphas = self.ws2(hbar).view(size[0], size[1], -1)  # batch_size * seq_len * hops
    alphas = torch.transpose(alphas, 1, 2).contiguous()  # batch_size * hops * seq_len
    mask = self.get_mask(length)
    # print 'mask', mask
    multi_mask = [mask.unsqueeze(1) for i in range(self.hops)]
    multi_mask = torch.cat(multi_mask, 1)
    # print 'multi_mask', multi_mask
    penalized_alphas = alphas + -1e7 * (1 - multi_mask)
    alphas = self.sm(penalized_alphas.view(
        -1, size[1]))  # (batch_size x hops) * seq_len
    alphas = alphas.view(size[0], self.hops, size[1])  # batch_size * hops * seq_len
    # print 'alphas', alphas
    reps = torch.bmm(alphas, reps)  # batch_size * hops * hidden_size
    # Here we use mean pooling over all hops.
    rep = reps.mean(1)
    assert len(rep.size()) == 2
    # batch_size * classes, batch_size * hops * seq_len
    return self.dropout(rep), alphas

def forward(self, enc_outs, target, source_rep_mask=None, target_length=None):
    lstm_in, _, lstm_states = self._prepare(enc_outs, target)
    # including EOE token
    # target_length += 1

    ### LSTM
    target_length, indices = torch.sort(target_length, 0, True)
    lstm_in_sorted = self.reorder_sequence(lstm_in, indices)
    lstm_in_packed = pack(lstm_in_sorted, target_length.tolist(),
                          batch_first=True)
    # lstm_out: batch_size x num_target x hidden_units
    lstm_out_packed, _ = self.lstm(lstm_in_packed, lstm_states)
    lstm_out, _ = unpack(lstm_out_packed, batch_first=True)
    _, reverse_indices = torch.sort(indices, 0)
    lstm_out = self.reorder_sequence(lstm_out, reverse_indices)
    # lstm_out, _ = self.lstm(lstm_in, lstm_states)

    ### glimpse attention
    # glimpse: batch_size x num_target x hidden_units
    glimpse, _ = self.glimpse_attn(lstm_out, enc_outs,
                                   rep_mask=source_rep_mask)

    ### point attention
    # probs: batch_size x num_target x num_sentence
    _, logits = self.point_attn(glimpse, enc_outs, rep_mask=source_rep_mask)
    return logits

def posteriorIndictedEmb(self, embs, posterior):
    # A real alignment is passed in as a list of indices;
    # a variational relaxed posterior is passed in as a MyPackedSequence.
    # out: (batch x amr_len) x src_len x (dim+1)
    embs, src_len = unpack(embs)
    if isinstance(posterior, MyPackedSequence):
        # print("posterior is packed")
        posterior = myunpack(*posterior)
        embs = embs.transpose(0, 1)
        out = []
        lengths = []
        amr_len = [len(p) for p in posterior]
        for i, emb in enumerate(embs):
            expanded_emb = emb.unsqueeze(0).expand(
                [amr_len[i]] + [s for s in emb.size()])  # amr_len x src_len x dim
            indicator = posterior[i].unsqueeze(2)  # amr_len x src_len x 1
            out.append(torch.cat([expanded_emb, indicator], 2))  # amr_len x src_len x (dim+1)
            lengths = lengths + [src_len[i]] * amr_len[i]
        data = torch.cat(out, dim=0)
        return pack(data, lengths, batch_first=True), amr_len
    elif isinstance(posterior, list):
        embs = embs.transpose(0, 1)
        src_l = embs.size(1)
        amr_len = [len(i) for i in posterior]
        out = []
        lengths = []
        for i, emb in enumerate(embs):
            amr_l = len(posterior[i])
            expanded_emb = emb.unsqueeze(0).expand(
                [amr_l] + [s for s in emb.size()])  # amr_len x src_len x dim
            indicator = emb.data.new(amr_l, src_l).zero_()
            indicator.scatter_(1, posterior[i].data.unsqueeze(1), 1.0)
            indicator = Variable(indicator.unsqueeze(2))  # amr_len x src_len x 1
            out.append(torch.cat([expanded_emb, indicator], 2))  # amr_len x src_len x (dim+1)
            lengths = lengths + [src_len[i]] * amr_l
        data = torch.cat(out, dim=0)
        return pack(data, lengths, batch_first=True), amr_len

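# A minimal sketch of the scatter_-based indicator built in the list branch
# above: one row per AMR node, with a 1.0 at the aligned source position.
# The alignment values here are made up for illustration.
import torch

amr_l, src_l = 3, 5
alignment = torch.tensor([4, 0, 2])            # aligned source index per node
indicator = torch.zeros(amr_l, src_l)
indicator.scatter_(1, alignment.unsqueeze(1), 1.0)
# indicator:
# [[0., 0., 0., 0., 1.],
#  [1., 0., 0., 0., 0.],
#  [0., 0., 1., 0., 0.]]
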
def enc(self, fbank, len=None):
    res = self.enc_fnn_lyr(fbank)
    for rnn, skip, drop in zip(self.enc_rnn_lyr, cf.enc_rnn_skip,
                               cf.enc_rnn_drop):
        if len is not None:
            res = pack(res, len, batch_first=True)
        res, (h, c) = rnn(res)
        if len is not None:
            res, _ = unpack(res, batch_first=True)
            len = [x // skip for x in len]
        res = F.dropout(res[:, ::skip], drop, self.training)
        # res = F.normalize(res, dim=-1)
    res = self.enc_dec_conection(res)
    if rnn.bidirectional:
        c = self.enc_dec_conection(torch.cat((c[0], c[1]), -1))
        h = self.enc_dec_conection(torch.cat((h[0], h[1]), -1))
    else:
        c = self.enc_dec_conection(c[0])
        h = self.enc_dec_conection(h[1])
    return res, (h, c)

def get_vectors(self, input, lengths=None):
    embed_input = self.embed(input)
    packed_emb = embed_input
    if lengths is not None:
        lengths = lengths.view(-1).tolist()
        packed_emb = nn.utils.rnn.pack_padded_sequence(embed_input, lengths)
    output, hidden = self.encoder(packed_emb)  # embed_input
    if lengths is not None:
        output = unpack(output)[0]
    # MUST apply negative mapping, so max pooling will not take padding elements
    batch_mask = self.create_mask(lengths)  # (time, batch_size)
    batch_mask = batch_mask.view(
        -1, len(lengths), 1)  # not sure if here broadcasting is right
    output = self.exp_mask(output, batch_mask)  # now pads will never be chosen...
    return output

def get_hypothesis_logits(self, words_hyp):
    mask_hyp = torch.ne(words_hyp, constants.PAD_ID)
    h_hyp = self.word_emb(words_hyp)
    lengths = mask_hyp.int().sum(dim=-1)
    h_hyp = pack(h_hyp, lengths, batch_first=True, enforce_sorted=False)
    h_hyp, hidden_hyp = self.rnn_hyp(h_hyp)
    h_hyp, _ = unpack(h_hyp, batch_first=True)
    if self.rnn_type == 'lstm':
        hidden_hyp = hidden_hyp[0]
    if self.is_bidir:
        hidden_states = [hidden_hyp[0], hidden_hyp[1]]
    else:
        hidden_states = [hidden_hyp[0]]
    hyp_logits = torch.cat(hidden_states, dim=-1).unsqueeze(1)
    # get last valid outputs instead of the last hidden state:
    # last_valid_idx = mask_hyp.int().sum(dim=-1) - 1
    # arange_vector = torch.arange(h_hyp.shape[0]).to(h_hyp.device)
    # hyp_logits = h_hyp[arange_vector, last_valid_idx].unsqueeze(1)
    return hyp_logits

def encode_batch(self, inputs, trans, lengths):
    bsz, max_len = inputs.size()
    in_embs = self.word_embs(inputs)
    lens, indices = torch.sort(lengths, 0, True)

    # concat word embs with trans hid
    if self.use_input_parse:
        in_embs = torch.cat([in_embs, trans.unsqueeze(1).expand(
            bsz, max_len, self.d_trans)], 2)

    e_hid_init = self.e_hid_init.expand(2, bsz, self.d_hid).contiguous()
    e_cell_init = self.e_cell_init.expand(2, bsz, self.d_hid).contiguous()
    all_hids, (enc_last_hid, _) = self.encoder(
        pack(in_embs[indices], lens.tolist(), batch_first=True),
        (e_hid_init, e_cell_init))
    _, _indices = torch.sort(indices, 0)
    all_hids = unpack(all_hids, batch_first=True)[0][_indices]
    all_hids = self.encoder_proj(
        all_hids.view(-1, self.d_hid * 2)).view(bsz, max_len, self.d_hid)
    enc_last_hid = torch.cat([enc_last_hid[0], enc_last_hid[1]], 1)
    enc_last_hid = self.encoder_proj(enc_last_hid)[_indices]
    return all_hids, enc_last_hid

def forward(self, input, heads, lengths=None, hidden=None):
    """
    See EncoderBase.forward() for description of args and returns.
    inputs: [L, B, H], including the -ROOT-
    heads: [heads] * B
    """
    emb = self.dropout(input)
    packed_emb = emb
    if lengths is not None:
        # Lengths data is wrapped inside a Variable.
        packed_emb = pack(emb, lengths)
    outputs, hidden_t = self.rnn(packed_emb, hidden)
    if lengths is not None:
        outputs = unpack(outputs)[0]
    outputs = self.dropout(self.transform(outputs))
    max_length, batch_size, input_dim = outputs.size()
    trees = []
    indexes = np.full((max_length, batch_size), -1,
                      dtype=np.int32)  # a col is a sentence
    for b, head in enumerate(heads):
        root, tree = creatTree(head)  # head: a sentence's heads; sentence base
        root.traverse()  # traverse the tree
        for step, index in enumerate(root.order):
            indexes[step, b] = index
        trees.append(tree)
    dt_outputs, dt_hidden_ts = self.dt_tree.forward(outputs, indexes, trees)
    td_outputs, td_hidden_ts = self.td_tree.forward(outputs, indexes, trees)
    outputs = torch.cat([dt_outputs, td_outputs], dim=2).transpose(0, 1)
    output_t = torch.cat([dt_hidden_ts, td_hidden_ts], dim=1).unsqueeze(0)
    return outputs, output_t

def enc(self, src, len=None):
    res = F.dropout(self.enc_emb_lyr(src), enc_emb_drop, self.training)
    h, c = None, None
    # import pdb; pdb.set_trace()
    for rnn, skip, drop in zip(self.enc_rnn_lyr, enc_rnn_skip, enc_rnn_drop):
        res = res if len is None else pack(res, len, batch_first=True)
        res, (h, c) = rnn(res)
        if len is not None:
            # import pdb; pdb.set_trace()
            res, _ = unpack(res, batch_first=True)
            len = [x // skip for x in len]
        res = F.dropout(res[:, ::skip], drop, self.training)
        # res = F.normalize(res, dim=-1)
    res = self.enc_dec_conection(res)
    if rnn.bidirectional:
        c = self.enc_dec_conection(torch.cat((c[0], c[1]), -1))
        h = self.enc_dec_conection(torch.cat((h[0], h[1]), -1))
    else:
        c = self.enc_dec_conection(c[0])
        h = self.enc_dec_conection(h[1])
    return res, (h, c)

def forward(self, x, lens, k, kx):
    # The model takes as input the text, aspect, and location, and runs a BLSTM
    # over the text using embedding(location, aspect) as the initial hidden
    # state (as opposed to a different LSTM for every pair???), then outputs
    # the sentiment.
    # DBG
    words = x
    emb = self.drop(self.lut(x))
    p_emb = pack(emb, lens, True)
    l, a = k
    N = x.shape[0]
    T = x.shape[1]
    y_idx = l * len(self.A) + a if self.L is not None else a
    s = (self.lut_la(y_idx)
         .view(N, 2, 2 * self.nlayers, self.rnn_sz)
         .permute(1, 2, 0, 3)
         .contiguous())
    state = (s[0], s[1])
    x, (h, c) = self.rnn(p_emb, state)
    # h: L * D x N x H
    x = unpack(x, True)[0]
    phi_s = self.proj_s(x)
    idxs = torch.arange(0, max(lens)).to(lens.device)
    # mask: N x R x 1
    mask = (idxs.repeat(len(lens), 1) >= lens.unsqueeze(-1))
    phi_s[:, :, -1].masked_fill_(1 - mask, float("-inf"))
    phi_s[:, :, :3].masked_fill_(mask.unsqueeze(-1), float("-inf"))
    phi_y = torch.zeros(N, len(self.S)).to(self.psi_ys.device)
    psi_ys = torch.cat(
        [torch.diag(self.psi_ys),
         torch.zeros(len(self.S), 1).to(self.psi_ys)],
        dim=-1,
    ).expand(T, len(self.S), len(self.S) + 1)
    Z, hy = ubersum("nts,tys,ny->n,ny", phi_s, psi_ys, phi_y,
                    batch_dims="t", modulo_total=True)
    return hy

def encode(self, input, hidden):
    '''
    input :: bs, sl
    return:
        output :: bs, sl, nh * directions
        hidden :: n_layers * directions, bs, nh
    '''
    mask = torch.gt(input.data, 0)
    input_length = torch.sum(mask.long(), dim=1)  # batch first = True, (batch, sl)
    lengths, indices = torch.sort(input_length, dim=0, descending=True)
    _, ind = torch.sort(indices, dim=0)
    input_length = torch.unbind(lengths, dim=0)
    embedded = self.embedding(
        torch.index_select(input, dim=0, index=Variable(indices)))
    output, hidden = self.gru(
        pack(embedded, input_length, batch_first=True), hidden)
    output = (torch.index_select(unpack(output, batch_first=True)[0],
                                 dim=0, index=Variable(ind))
              * Variable(torch.unsqueeze(mask.float(), -1)))
    hidden = torch.index_select(hidden[-1], dim=0, index=Variable(ind))
    # hidden = torch.unbind(hidden, dim=0)
    # hidden = torch.cat(hidden, 1)
    direction = 2 if self.bidirectional else 1
    assert hidden.size() == (input.size()[0], self.hidden_size) and \
        output.size() == (input.size()[0], input.size()[1],
                          self.hidden_size * direction)
    return output, hidden

def forward(self, input, hidden=None):
    if isinstance(input, tuple):
        emb_ = self.word_lut(input[0])
        if self.src_fix_emb:
            emb = pack(self.emb2input(emb_), list(input[1]))
        else:
            emb = pack(emb_, list(input[1]))
    else:
        emb_ = self.word_lut(input)
        if self.src_fix_emb:
            emb = self.emb2input(emb_)
        else:
            emb = emb_
    # if isinstance(input, tuple):
    #     emb = pack(self.word_lut(input[0]), list(input[1]))
    # else:
    #     emb = self.word_lut(input)
    outputs, hidden_t = self.rnn(emb, hidden)
    if isinstance(input, tuple):
        outputs = unpack(outputs)[0]
    return hidden_t, outputs

def forward(self, word, context, pool_type="max"):
    word_emb = self.embedding(word)
    context_emb = self.embedding(context)
    lengths = (context != constants.PAD_IDX).sum(dim=0).detach().cpu()

    # Sort by length (keep idx)
    context_len_sorted = np.sort(lengths.numpy())[::-1]
    idx_sort = np.argsort(-lengths.numpy())
    context_len_sorted = torch.from_numpy(context_len_sorted.copy())
    idx_unsort = np.argsort(idx_sort)
    context_emb = context_emb.index_select(
        1, torch.from_numpy(idx_sort).to(self.device))
    context_emb = pack(context_emb, context_len_sorted, batch_first=False)
    context_vec, _ = self.rnn(context_emb, None)
    context_vec = unpack(context_vec, batch_first=False)[0]

    # Un-sort by length
    context_vec = context_vec.index_select(
        1, torch.from_numpy(idx_unsort).to(self.device))

    # Pooling
    if pool_type == "mean":
        lengths = torch.FloatTensor(lengths.numpy().copy()).unsqueeze(1)
        emb = torch.sum(context_vec, 0).squeeze(0)
        if emb.ndimension() == 1:
            emb = emb.unsqueeze(0)
        emb = emb / lengths.expand_as(emb).to(self.device)
    elif pool_type == "max":
        emb = torch.max(context_vec, 0)[0]
        if emb.ndimension() == 3:
            emb = emb.squeeze(0)
        assert emb.ndimension() == 2

    V = F.relu(self.linear_q(word_emb.unsqueeze(1)))
    C = F.relu(self.linear_k(torch.transpose(context_vec, 0, 1)))
    T = torch.transpose(context_vec, 0, 1)
    scale = (C.size(-1)) ** -0.5
    att = torch.bmm(V, C.transpose(1, 2)) * scale
    att = self.softmax(att)
    att = self.dropout(att)
    c = torch.bmm(att, T)
    c = c.squeeze(1)
    return c, emb, att

def forward(self, x, y, x_len=None, softmax=True):
    """
    Args:
        x (tensor): batch, frame, dim
        y (tensor): batch, frame
        x_len (tensor): batch
    """
    if self.pack_seq and x_len is not None:
        packed_x = pack(x, x_len, batch_first=True, enforce_sorted=True)
        packed_x, _ = self.encoder(packed_x)
        x = unpack(packed_x, batch_first=True)[0]
    else:
        x = self.encoder(x)

    # prepend SOS, blk=0
    SOS = Variable(torch.zeros(y.shape[0], 1).long())
    SOS = SOS.cuda()
    y = torch.cat((SOS, y), dim=1)
    if self.decoder_type == 'rnn':
        y = self.embed(y)
        y, _ = self.decoder(y)
    else:
        y = self.decoder(y)

    T = x.size()[1]
    U = y.size()[1]
    # x: batch, T, U, dim
    # y: batch, T, U, dim
    x = x.unsqueeze(2).expand(-1, -1, U, -1)
    y = y.unsqueeze(1).expand(-1, T, -1, -1)
    # x_gate = F.glu(torch.cat((x, y), dim=-1), dim=-1)
    # y_gate = F.glu(torch.cat((y, x), dim=-1), dim=-1)
    # out = torch.cat((x_gate, y_gate), dim=-1)
    out = torch.cat((x, y), dim=-1)
    out = self.fc2(F.tanh(self.fc1(out)) * F.sigmoid(self.fc_gate(out)))
    # out = self.fc2(F.selu(self.fc1(out)))
    if softmax:
        out = F.log_softmax(out, dim=-1)
    return out

def forward(self, input, lengths=None, hidden=None):
    packed_emb = input
    packed_emb = pack(input, lengths, enforce_sorted=False)
    outputs, hidden_t = self.rnn(packed_emb, hidden)
    outputs = unpack(outputs)[0]
    # consider both directions
    if self.bidirectional:
        if self.rnn_type == 'LSTM':
            h_n, c_n = hidden_t
            h_n = torch.cat([h_n[0:h_n.size(0):2], h_n[1:h_n.size(0):2]], 2)
            c_n = torch.cat([c_n[0:c_n.size(0):2], c_n[1:c_n.size(0):2]], 2)
            hidden_t = (h_n, c_n)
        else:
            hidden_t = torch.cat([hidden_t[0:hidden_t.size(0):2],
                                  hidden_t[1:hidden_t.size(0):2]], 2)
    return outputs, hidden_t

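# A minimal sketch of the direction merge above. For a bidirectional RNN,
# h_n is (num_layers * 2) x batch x hidden with directions interleaved per
# layer, so h_n[0::2] are the forward states and h_n[1::2] the backward ones;
# concatenating them on the feature dim gives num_layers x batch x 2*hidden.
# The module and sizes are illustrative assumptions.
import torch
from torch import nn

rnn = nn.GRU(input_size=8, hidden_size=16, num_layers=2, bidirectional=True)
x = torch.randn(5, 3, 8)                       # time x batch x features
_, h_n = rnn(x)                                # (2 layers * 2 dirs) x 3 x 16
merged = torch.cat([h_n[0::2], h_n[1::2]], dim=2)
assert merged.shape == (2, 3, 32)              # num_layers x batch x 2*hidden
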
def sinkhorn_score_regularizor(score):
    '''
    probBatch: tuple (src_len x batch x n_out, lengths)
    tgtBatch: amr_len x batch x n_feature, lengths
    score = packed(amr_len x batch x src_len, lengths)
    total_loss, total_data
    '''
    scores, lengths = unpack(score)
    S = 0
    r = opt.prior_t / opt.sink_t
    gamma_r = math.gamma(1 + r)
    for i, l in enumerate(lengths):
        # scores[:l, i, :l].data = torch.clamp(scores[:l, i, :l].data, 0,
        #                                      torch.max(scores[:l, i, :l].data))
        # print("scores", torch.max(scores[:l, i, :l].data),
        #       torch.min(scores[:l, i, :l].data))
        # aa = Variable(torch.randn(3, 5) * 50)
        # print(scores[:l, i, :l])
        scores[:l, i, :l] = torch.clamp(scores[:l, i, :l], min=-1)
        S = S + r / scores[:l, i, :l].sum() + gamma_r * torch.exp(
            -scores[:l, i, :l] * r).sum()
    return S  # + activation_loss

def apply_packed_sequence(rnn, embedding, lengths):
    """Runs a forward pass of embeddings through an rnn using packed sequence.

    Args:
        rnn: The RNN that we want to compute a forward pass with.
        embedding (FloatTensor b x seq x dim): A batch of sequence embeddings.
        lengths (LongTensor batch): The length of each sequence in the batch.

    Returns:
        output: The output of the RNN `rnn` with input `embedding`
    """
    # Sort batch by sequence length
    lengths_sorted, permutation = torch.sort(lengths, descending=True)
    embedding_sorted = embedding[permutation]

    # Use packed sequence
    embedding_packed = pack(embedding_sorted, lengths_sorted, batch_first=True)
    outputs_packed, (hidden, cell) = rnn(embedding_packed)
    outputs_sorted, _ = unpack(outputs_packed, batch_first=True)

    # Restore original order
    _, permutation_rev = torch.sort(permutation, descending=False)
    outputs = outputs_sorted[permutation_rev]
    hidden, cell = hidden[:, permutation_rev], cell[:, permutation_rev]
    return outputs, (hidden, cell)

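# A possible usage of apply_packed_sequence with a toy batch-first LSTM; the
# sizes and module below are illustrative assumptions, not from the original code.
import torch
from torch import nn

rnn = nn.LSTM(input_size=16, hidden_size=32, batch_first=True)
embedding = torch.randn(4, 9, 16)              # batch x seq x dim
lengths = torch.tensor([9, 3, 7, 5])           # unsorted is fine: sorted internally

outputs, (hidden, cell) = apply_packed_sequence(rnn, embedding, lengths)
assert outputs.shape == (4, 9, 32)             # padded to max length, original order
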
def forward(self, input, hidden=None, is_fert=True):
    if isinstance(input, tuple):
        emb_ = self.word_lut(input[0])
        if self.src_fix_emb:
            emb_ = self.emb2input(emb_)
            emb = pack(emb_, list(input[1]))
        else:
            emb = pack(emb_, list(input[1]))
    else:
        emb_ = self.word_lut(input)
        if self.src_fix_emb:
            emb = self.emb2input(emb_)
        else:
            emb = emb_
    # if isinstance(input, tuple):
    #     emb = pack(self.word_lut(input[0]), list(input[1]))
    # else:
    #     emb = self.word_lut(input)
    outputs, hidden_t = self.rnn(emb, hidden)
    if isinstance(input, tuple):
        outputs = unpack(outputs)[0]
    cov = None
    if self.use_fert:
        if self.fert_mode == "emb":
            cov_inp = emb_
        elif self.fert_mode == "emh":
            cov_inp = torch.cat([emb_, outputs], -1)
        else:
            cov_inp = outputs
        if is_fert:
            cov = self.forward_cov(cov_inp)
    hidden_t = (self._fix_enc_hidden(hidden_t[0]),
                self._fix_enc_hidden(hidden_t[1]))
    return hidden_t, outputs, cov, cov_inp

def forward(self, inputs, hidden=None):
    if isinstance(inputs, PackedSequence):
        emb = PackedSequence(
            self.embedding_dropout(self.embedder(inputs.data)),
            inputs.batch_sizes)
        bsizes = inputs.batch_sizes.to(device=inputs.data.device)
        max_batch = int(bsizes[0])
        # Get padding mask
        time_dim = 1 if self.batch_first else 0
        range_batch = torch.arange(0, max_batch,
                                   dtype=bsizes.dtype, device=bsizes.device)
        range_batch = range_batch.unsqueeze(time_dim)
        bsizes = bsizes.unsqueeze(1 - time_dim)
        padding_mask = (bsizes - range_batch).le(0)
    else:
        padding_mask = inputs.eq(PAD)
        emb = self.embedding_dropout(self.embedder(inputs))
    outputs, hidden_t = self.rnn(emb, hidden)
    if isinstance(inputs, PackedSequence):
        outputs = unpack(outputs)[0]
    outputs = self.dropout(outputs)
    if hasattr(self, 'context_transform'):
        context = self.context_transform(outputs)
    else:
        context = None
    if hasattr(self, 'hidden_transform'):
        hidden_t = self.hidden_transform(hidden_t)
    state = State(outputs=outputs, hidden=hidden_t, context=context,
                  mask=padding_mask, batch_first=self.batch_first)
    return state

def forward(self, src, lengths=None): """See :func:`EncoderBase.forward()`""" self._check_args(src, lengths) emb = self.embeddings(src) # s_len, batch, emb_dim = emb.size() packed_emb = emb if lengths is not None and not self.no_pack_padded_seq: # Lengths data is wrapped inside a Tensor. lengths_list = lengths.view(-1).tolist() packed_emb = pack(emb, lengths_list) memory_bank, encoder_final = self.layers[0](packed_emb) if lengths is not None and not self.no_pack_padded_seq: memory_bank = unpack(memory_bank)[0] bottom_layers = 1 if self.gnmt: memory_bank = self.dropout(memory_bank) memory_bank, enc_final = self.layers[1](memory_bank) #print(encoder_final[0].size()) encoder_final0 = torch.cat((encoder_final[0], enc_final[0]), 0) encoder_final1 = torch.cat((encoder_final[1], enc_final[1]), 0) encoder_final = (encoder_final0, encoder_final1) bottom_layers = 2 for i in range(bottom_layers, self.num_layers): residual = memory_bank memory_bank = self.dropout(memory_bank) memory_bank, enc_final = self.layers[i](memory_bank) encoder_final0 = torch.cat((encoder_final[0], enc_final[0]), 0) encoder_final1 = torch.cat((encoder_final[1], enc_final[1]), 0) encoder_final = (encoder_final0, encoder_final1) if self.num_layers >= 4: memory_bank = memory_bank + residual return encoder_final, memory_bank, lengths
def forward(self, inputs, lengths):
    inputs = inputs.t()
    lengths = lengths.tolist()
    embs = pack(self.embedding(inputs), lengths)
    self.rnn.flatten_parameters()
    outputs, state = self.rnn(embs)
    outputs = unpack(outputs)[0]
    if self.bidirectional:
        outputs = outputs[:, :, :self.hidden_size] \
            + outputs[:, :, self.hidden_size:]
    # Batch_size * Length * Hidden_size
    if self.inception:
        outputs = outputs.transpose(0, 1).transpose(1, 2)
        conv1 = self.sw1(outputs)
        conv3 = self.sw3(outputs)
        conv33 = self.sw33(outputs)
        conv = torch.cat((conv1, conv3, conv33), 1)
        conv = self.filter_linear(conv.transpose(1, 2))
        # conv = self.sw3(outputs).transpose(1, 2)
        outputs = outputs.transpose(1, 2).transpose(0, 1)  # seq_len, batch, dim
        outputs = outputs.transpose(0, 1)
        if self.encoding_gate:
            if self.gtu:
                # conv = "weight norm"
                # outputs "weight norm"
                gate = self.sigmoid(conv)
                tan_conv = torch.tanh(outputs)
                gtu_out = tan_conv * gate
                return self.layer_normalization(gtu_out + outputs)
            else:
                gate = self.sigmoid(conv)
                outputs = outputs * gate
                return outputs
        else:
            return conv
    else:
        return outputs.transpose(0, 1)

def forward(self, inputs, lengths):
    embs = pack(self.embedding(inputs), lengths)
    outputs, state = self.rnn(embs)
    outputs = unpack(outputs)[0]
    if self.config.bidirectional:
        if self.config.swish:
            outputs = self.linear(outputs)
        else:
            outputs = outputs[:, :, :self.config.hidden_size] \
                + outputs[:, :, self.config.hidden_size:]
    if self.config.swish:
        outputs = outputs.transpose(0, 1).transpose(1, 2)
        conv1 = self.sw1(outputs)
        conv3 = self.sw3(outputs)
        conv33 = self.sw33(outputs)
        conv = torch.cat((conv1, conv3, conv33), 1)
        conv = self.filter_linear(conv.transpose(1, 2))
        if self.config.selfatt:
            conv = conv.transpose(0, 1)
            outputs = outputs.transpose(1, 2).transpose(0, 1)
        else:
            gate = self.sigmoid(conv)
            outputs = outputs * gate.transpose(1, 2)
            outputs = outputs.transpose(1, 2).transpose(0, 1)
    if self.config.selfatt:
        self.attention.init_context(context=conv)
        out_attn, weights = self.attention(conv, selfatt=True)
        gate = self.sigmoid(out_attn)
        outputs = outputs * gate
    if self.config.cell == 'gru':
        state = state[:self.config.dec_num_layers]
    else:
        state = (state[0][::2], state[1][::2])
    return outputs, state

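# A companion sketch to the bidirectional-state merge shown earlier: here,
# instead of concatenating directions, state[0][::2] / state[1][::2] simply keep
# the forward-direction hidden/cell state of each layer, which matches a
# unidirectional decoder's expected num_layers x batch x hidden shape.
# The module and sizes are illustrative assumptions.
import torch
from torch import nn

enc = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, bidirectional=True)
x = torch.randn(5, 3, 8)                       # time x batch x features
_, (h, c) = enc(x)                             # each: (2 layers * 2 dirs) x 3 x 16
dec_init = (h[::2], c[::2])                    # forward direction of every layer
assert dec_init[0].shape == (2, 3, 16)
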
def encode(self, inputs, lengths, fr=0):
    bsz, max_len = inputs.size()
    e_hidden_init = self.e_hidden_init.expand(
        2, bsz, self.hidden_dim).contiguous()
    e_cell_init = self.e_cell_init.expand(
        2, bsz, self.hidden_dim).contiguous()
    lens, indices = torch.sort(lengths, 0, True)
    if fr and not self.share_vocab:
        in_embs = self.embedding_fr(inputs)
    else:
        in_embs = self.embedding(inputs)
    if fr and not self.share_encoder:
        if self.dropout > 0:
            in_embs = F.dropout(in_embs, p=self.dropout,
                                training=self.training)
        all_hids, (enc_last_hid, _) = self.lstm_fr(
            pack(in_embs[indices], lens.tolist(), batch_first=True),
            (e_hidden_init, e_cell_init))
    else:
        if self.dropout > 0:
            in_embs = F.dropout(in_embs, p=self.dropout,
                                training=self.training)
        all_hids, (enc_last_hid, _) = self.lstm(
            pack(in_embs[indices], lens.tolist(), batch_first=True),
            (e_hidden_init, e_cell_init))
    _, _indices = torch.sort(indices, 0)
    all_hids = unpack(all_hids, batch_first=True)[0][_indices]
    if self.pool == "max":
        embs = utils.max_pool(all_hids, lengths, self.gpu)
    elif self.pool == "mean":
        embs = utils.mean_pool(all_hids, lengths, self.gpu)
    return embs
