def forward(self, input, lengths=None, hidden=None):
    """
    Args:
        input (LongTensor): len x batch x nfeat
        lengths (LongTensor): batch
        hidden: Initial hidden state.

    Returns:
        hidden_t (FloatTensor): Pair of layers x batch x rnn_size - final
                                Encoder state
        outputs (FloatTensor): len x batch x rnn_size - Memory bank
    """
    # CHECKS
    s_len, n_batch, n_feats = input.size()
    if lengths is not None:
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)
    # END CHECKS

    emb = self.embeddings(input)
    s_len, n_batch, vec_size = emb.size()

    if self.encoder_type == "mean":
        # No RNN, just take mean as final state.
        mean = emb.mean(0) \
               .expand(self.layers, n_batch, vec_size)
        return (mean, mean), emb

    elif self.encoder_type == "transformer":
        # Self-attention transformer.
        out = emb.transpose(0, 1).contiguous()
        for i in range(self.layers):
            out = self.transformer[i](out, input[:, :, 0].transpose(0, 1))
        return Variable(emb.data), out.transpose(0, 1).contiguous()

    else:
        # Standard RNN encoder.
        packed_emb = emb
        if lengths is not None:
            # Lengths data is wrapped inside a Variable.
            lengths = lengths.view(-1).tolist()
            packed_emb = pack(emb, lengths)
        outputs, hidden_t = self.rnn(packed_emb, hidden)
        if lengths:
            outputs = unpack(outputs)[0]
        return hidden_t, outputs

def forward(self, input, context, state):
    """
    Forward through the decoder.

    Args:
        input (LongTensor): a sequence of input tokens tensors
                            of size (len x batch x nfeats).
        context (FloatTensor): output(tensor sequence) from the Encoder
                               RNN of size (src_len x batch x hidden_size).
        state (FloatTensor): hidden state from the Encoder RNN for
                             initializing the decoder.

    Returns:
        outputs (FloatTensor): a Tensor sequence of output from the Decoder
                               of shape (len x batch x hidden_size).
        state (FloatTensor): final hidden state from the Decoder.
        attns (dict of (str, FloatTensor)): a dictionary of different
                            type of attention Tensor from the Decoder
                            of shape (src_len x batch).
    """
    # Args Check
    assert isinstance(state, RNNDecoderState)
    input_len, input_batch, _ = input.size()
    contxt_len, contxt_batch, _ = context.size()
    aeq(input_batch, contxt_batch)
    # END Args Check

    # Run the forward pass of the RNN.
    hidden, outputs, attns, coverage = \
        self._run_forward_pass(input, context, state)

    # Update the DecoderState with the result.
    final_output = outputs[-1]
    state = RNNDecoderState(
        hidden, final_output.unsqueeze(0),
        coverage.unsqueeze(0) if coverage is not None else None)

    # Concatenates sequence of tensors along a new dimension.
    outputs = torch.stack(outputs)
    for k in attns:
        attns[k] = torch.stack(attns[k])

    return outputs, state, attns

def score(self, h_t, h_s):
    """
    h_t (FloatTensor): batch x dim
    h_s (FloatTensor): batch x src_len x dim
    returns scores (FloatTensor): batch x src_len:
        raw attention scores for each src index
    """
    # Check input sizes
    src_batch, _, src_dim = h_s.size()
    tgt_batch, tgt_dim = h_t.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)
    aeq(self.dim, src_dim)

    if self.attn_type in ["general", "dot"]:
        if self.attn_type == "general":
            h_t = self.linear_in(h_t)
        return torch.bmm(h_s, h_t.unsqueeze(2)).squeeze(2)
    else:
        # MLP
        # batch x 1 x dim
        wq = self.linear_query(h_t).unsqueeze(1)
        # batch x src_len x dim
        uh = self.linear_context(h_s.contiguous())
        # batch x src_len x dim
        wquh = uh + wq.expand_as(uh)
        # batch x src_len x dim
        wquh = self.tanh(wquh)
        # batch x src_len
        return self.v(wquh.contiguous()).squeeze(2)

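# Shape sketch for the "dot"/"general" branch above (standalone, with
# illustrative sizes only): scoring a single decoder state against the source
# bank is a batched matrix-vector product, so bmm of (batch, src_len, dim)
# with (batch, dim, 1) gives (batch, src_len, 1), squeezed to (batch, src_len).
import torch

batch, src_len, dim = 4, 7, 16
h_s = torch.randn(batch, src_len, dim)   # source memory bank
h_t = torch.randn(batch, dim)            # one decoder state per example
scores = torch.bmm(h_s, h_t.unsqueeze(2)).squeeze(2)
assert scores.size() == (batch, src_len)
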
def forward(self, input):
    """
    Return the embeddings for words, and features if there are any.

    Args:
        input (LongTensor): len x batch x nfeat

    Return:
        emb (FloatTensor): len x batch x self.embedding_size
    """
    in_length, in_batch, nfeat = input.size()
    aeq(nfeat, len(self.emb_luts))

    emb = self.make_embedding(input)

    out_length, out_batch, emb_size = emb.size()
    aeq(in_length, out_length)
    aeq(in_batch, out_batch)
    aeq(emb_size, self.embedding_size)

    return emb

def forward(self, src_input):
    """
    Embed the words or utilize features and MLP.

    Args:
        src_input (LongTensor): len x batch x nfeat

    Return:
        emb (FloatTensor): len x batch x emb_size
            emb_size is word_vec_size if there are no features or the merge
            action is sum. It is the sum of all feature dimensions if the
            merge action is concatenate.
    """
    in_length, in_batch, nfeat = src_input.size()
    aeq(nfeat, len(self.emb_luts))

    if len(self.emb_luts) == 1:
        emb = self.word_lut(src_input.squeeze(2))
    else:
        feat_inputs = (feat.squeeze(2)
                       for feat in src_input.split(1, dim=2))
        features = [lut(feat)
                    for lut, feat in zip(self.emb_luts, feat_inputs)]
        emb = self.merge(features)

    if self.positional_encoding:
        emb = emb + Variable(
            self.pe[:emb.size(0), :1, :emb.size(2)].expand_as(emb))
        emb = self.dropout(emb)

    out_length, out_batch, emb_size = emb.size()
    aeq(in_length, out_length)
    aeq(in_batch, out_batch)
    aeq(emb_size, self.embedding_size)

    return emb

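# A minimal, self-contained sketch of the sinusoidal table assumed to back
# `self.pe` above (function name and construction here are illustrative, not
# the module's actual buffer): position p, channel 2i gets sin(p / 10000^(2i/d))
# and channel 2i+1 the matching cosine. The slice self.pe[:len, :1, :dim] is
# then broadcast over the batch dimension.
import math
import torch

def sinusoidal_table(max_len, dim):
    pe = torch.zeros(max_len, 1, dim)
    position = torch.arange(0, max_len).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, dim, 2).float()
                         * -(math.log(10000.0) / dim))
    pe[:, 0, 0::2] = torch.sin(position * div_term)
    pe[:, 0, 1::2] = torch.cos(position * div_term)
    return pe

# e.g. for emb of shape (len x batch x dim):
#   emb = emb + sinusoidal_table(emb.size(0), emb.size(2)).expand_as(emb)
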
def forward(self, src_input):
    """
    Embed the words or utilize features and MLP.

    Args:
        src_input (LongTensor): len x batch x nfeat

    Return:
        emb (FloatTensor): len x batch x emb_size
            emb_size is word_vec_size
    """
    in_length, in_batch, nfeat = src_input.size()
    aeq(nfeat, len(self.emb_luts))

    if len(self.emb_luts) == 1:
        emb = self.word_lut(src_input.squeeze(2))

    out_length, out_batch, emb_size = emb.size()
    aeq(in_length, out_length)
    aeq(in_batch, out_batch)
    aeq(emb_size, self.embedding_size)

    return emb

def forward(self, input, context, src_words, tgt_words):
    # CHECKS
    n_batch, t_len, _ = input.size()
    n_batch_, s_len, _ = context.size()
    n_batch__, s_len_ = src_words.size()
    n_batch___, t_len_ = tgt_words.size()
    aeq(n_batch, n_batch_, n_batch__, n_batch___)
    aeq(s_len, s_len_)
    aeq(t_len, t_len_)
    # END CHECKS

    attn_mask = get_attn_padding_mask(tgt_words, tgt_words)
    dec_mask = torch.gt(
        attn_mask
        + self.mask[:, :attn_mask.size(1), :attn_mask.size(1)]
        .expand_as(attn_mask),
        0)
    pad_mask = get_attn_padding_mask(tgt_words, src_words)
    query, attn = self.self_attn(input, input, input, mask=dec_mask)
    mid, attn = self.context_attn(context, context, query, mask=pad_mask)
    output = self.feed_forward(mid)
    return output, attn

def score(self, h_t, h_s):
    """
    h_t (FloatTensor): batch x tgt_len x dim
    h_s (FloatTensor): batch x src_len x dim
    returns scores (FloatTensor): batch x tgt_len x src_len:
        raw attention scores for each src index
    """
    # Check input sizes
    src_batch, src_len, src_dim = h_s.size()
    tgt_batch, tgt_len, tgt_dim = h_t.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)
    aeq(self.dim, src_dim)

    if self.attn_type in ["general", "dot"]:
        if self.attn_type == "general":
            h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
            h_t_ = self.linear_in(h_t_)
            h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
        h_s_ = h_s.transpose(1, 2)
        # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
        return torch.bmm(h_t, h_s_)
    else:
        dim = self.dim
        wq = self.linear_query(h_t.view(-1, dim))
        wq = wq.view(tgt_batch, tgt_len, 1, dim)
        wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

        uh = self.linear_context(h_s.contiguous().view(-1, dim))
        uh = uh.view(src_batch, 1, src_len, dim)
        uh = uh.expand(src_batch, tgt_len, src_len, dim)

        # (batch, t_len, s_len, d)
        wquh = self.tanh(wq + uh)

        return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)

def forward(self, input, src, context, state, fertility_vals=None,
            fert_dict=None, fert_sents=None, upper_bounds=None, test=False):
    """
    Forward through the decoder.

    Args:
        input (LongTensor): (len x batch) -- Input tokens
        src (LongTensor)
        context: (src_len x batch x rnn_size) -- Memory bank
        state: an object initializing the decoder.

    Returns:
        outputs: (len x batch x rnn_size)
        final_states: an object of the same form as above
        attns: Dictionary of (src_len x batch)
    """
    # CHECKS
    t_len, n_batch = input.size()
    s_len, n_batch_, _ = src.size()
    s_len_, n_batch__, _ = context.size()
    aeq(n_batch, n_batch_, n_batch__)
    # aeq(s_len, s_len_)
    # END CHECKS

    if self.decoder_layer == "transformer":
        if state.previous_input:
            input = torch.cat([state.previous_input.squeeze(2), input], 0)

    emb = self.embeddings(input.unsqueeze(2))

    # n.b. you can increase performance if you compute W_ih * x for all
    # iterations in parallel, but that's only possible if
    # self.input_feed=False
    outputs = []

    # Setup the different types of attention.
    attns = {"std": []}
    if self._copy:
        attns["copy"] = []
    if self._coverage:
        attns["coverage"] = []
    if self.exhaustion_loss:
        attns["upper_bounds"] = []
    if self.fertility_loss:
        attns["predicted_fertility_vals"] = []
        attns["true_fertility_vals"] = []

    if self.decoder_layer == "transformer":
        # Transformer Decoder.
        assert isinstance(state, TransformerDecoderState)
        output = emb.transpose(0, 1).contiguous()
        src_context = context.transpose(0, 1).contiguous()
        for i in range(self.layers):
            output, attn \
                = self.transformer[i](output, src_context,
                                      src[:, :, 0].transpose(0, 1),
                                      input.transpose(0, 1))
        outputs = output.transpose(0, 1).contiguous()
        if state.previous_input:
            outputs = outputs[state.previous_input.size(0):]
            attn = attn[:, state.previous_input.size(0):].squeeze()
            attn = torch.stack([attn])
        attns["std"] = attn
        if self._copy:
            attns["copy"] = attn
        state = TransformerDecoderState(input.unsqueeze(2))
    else:
        assert isinstance(state, RNNDecoderState)
        output = state.input_feed.squeeze(0)
        hidden = state.hidden
        # CHECKS
        n_batch_, _ = output.size()
        aeq(n_batch, n_batch_)
        # END CHECKS

        coverage = state.coverage.squeeze(0) \
            if state.coverage is not None else None

        # NOTE: something goes wrong when I try to define an "upper_bounds"
        # variable here -- memory blows up. Apparently the presence of such
        # a variable prevents the computation graph from being deleted after
        # processing each batch. I need to investigate this further.
        # A workaround for now is to do one round of softmax (without
        # upper bound constraints) followed by several rounds of constrained
        # softmax.
        # upper_bounds = Variable(torch.ones(attn.size()).cuda())

        # Standard RNN decoder.
        for i, emb_t in enumerate(emb.split(1)):
            # Initialize upper bounds for the current batch
            if upper_bounds is None:
                # if not test:
                #     tgt_lengths = [torch.nonzero(input[:, i].data).size(0)
                #                    for i in range(n_batch_)]
                #     tgt_lengths = torch.Tensor(tgt_lengths).cuda()
                # else:
                #     # Maybe the ratio of tgt_len and src_len from the
                #     # training set would be a better estimate
                #     tgt_lengths = torch.ones(n_batch_).cuda()
                if self.predict_fertility:
                    # comp_tensor = torch.Tensor([float(emb.size(0)) / context.size(0)]).repeat(n_batch_, s_len_).cuda()
                    # comp_tensor = (tgt_lengths / s_len_).unsqueeze(1).repeat(1, s_len_).cuda()
                    # print("fertility_vals:", fertility_vals.data)
                    # max_word_coverage = Variable(torch.max(fertility_vals.data, comp_tensor))
                    max_word_coverage = fertility_vals.clone()
                elif self.guided_fertility:
                    # comp_tensor = torch.Tensor([float(emb.size(0)) / context.size(0)]).repeat(n_batch_, s_len_).cuda()
                    # comp_tensor = (tgt_lengths / s_len_).unsqueeze(1).repeat(1, s_len_).cuda()
                    # import pdb; pdb.set_trace()
                    fertility_vals = Variable(
                        evaluation.getBatchFertilities(fert_dict, src)
                        .transpose(1, 0).contiguous())
                    max_word_coverage = fertility_vals
                    # max_word_coverage = Variable(torch.max(fertility_vals, comp_tensor))
                elif self.supervised_fertility:
                    # k should be the index of the first sentence in the batch
                    predicted_fertility_vals = fertility_vals
                    true_fertility_vals = fert_sents[k:k + n_batch_]
                    if test:
                        max_word_coverage = predicted_fertility_vals
                    else:
                        max_word_coverage = true_fertility_vals
                else:
                    # max_word_coverage = max(
                    #     self.fertility, float(emb.size(0)) / context.size(0))
                    max_word_coverage = Variable(
                        torch.Tensor([self.fertility]).repeat(
                            n_batch_, s_len_)).cuda()
                    # max_word_coverage = Variable(torch.max(torch.FloatTensor([self.fertility]).repeat(n_batch_).cuda(),
                    #                                        tgt_lengths / s_len_).unsqueeze(1).repeat(1, s_len_))
                # upper_bounds = -attn + max_word_coverage
                # else:
                #     upper_bounds -= attn
                upper_bounds = max_word_coverage
                # Use <SINK> token for absorbing remaining attention weight
                # import pdb; pdb.set_trace()
                upper_bounds[:, -1] = Variable(
                    100. * torch.ones(upper_bounds.size(0)))
                # if (upper_bounds.size(0) > torch.sum(torch.sum(upper_bounds, 1)).cpu().data.numpy())[0]:
                #     print("inv sum:", torch.sum(upper_bounds, 1))
                #     print("att:", attn)

            emb_t = emb_t.squeeze(0)
            if self.input_feed:
                emb_t = torch.cat([emb_t, output], 1)

            rnn_output, hidden = self.rnn(emb_t, hidden)
            attn_output, attn = self.attn(rnn_output,
                                          context.transpose(0, 1),
                                          upper_bounds=upper_bounds)
            # import pdb; pdb.set_trace()
            # print_attention = True
            # if print_attention:
            #     attn_probs = attn.data.cpu().numpy()
            #     for k in range(attn_probs.shape[0]):
            #         print('\t'.join(str(val) for val in list(attn_probs[k, :])))
            upper_bounds -= attn
            # k_attn = 1
            # upper_bounds = torch.max(upper_bounds - k_attn * attn,
            #                          Variable(torch.zeros(upper_bounds.size(0),
            #                                               upper_bounds.size(1)).cuda()))
            # if np.any(upper_bounds.cpu().data.numpy() < 1):
            #     print("upper bounds less than 1.0")
            #     print("attn: ", attn)
            #     print("upper_bounds: ", upper_bounds)
            if self.context_gate is not None:
                output = self.context_gate(emb_t, rnn_output, attn_output)
                output = self.dropout(output)
            else:
                output = self.dropout(attn_output)
            outputs += [output]
            attns["std"] += [attn]

            # COVERAGE
            if self._coverage:
                coverage = (coverage + attn) if coverage is not None else attn
                attns["coverage"] += [coverage]

            # COPY
            if self._copy:
                _, copy_attn = self.copy_attn(output,
                                              context.transpose(0, 1))
                attns["copy"] += [copy_attn]
            if self.exhaustion_loss:
                attns["upper_bounds"] += [upper_bounds]
            if self.supervised_fertility:
                attns["true_fertility_vals"] += [true_fertility_vals]
                attns["predicted_fertility_vals"] += [predicted_fertility_vals]

        state = RNNDecoderState(
            hidden, output.unsqueeze(0),
            coverage.unsqueeze(0) if coverage is not None else None,
            upper_bounds)
        outputs = torch.stack(outputs)
        for k in attns:
            attns[k] = torch.stack(attns[k])

    return outputs, state, attns, upper_bounds

def forward(self, input, lengths=None, hidden=None):
    """
    Args:
        input (LongTensor): len x batch x nfeat
        lengths (LongTensor): batch
        hidden: Initial hidden state.

    Returns:
        hidden_t (FloatTensor): Pair of layers x batch x rnn_size - final
                                Encoder state
        outputs (FloatTensor): len x batch x rnn_size - Memory bank
    """
    # CHECKS
    s_len, n_batch, n_feats = input.size()
    if lengths is not None:
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)
    # END CHECKS

    emb = self.embeddings(input)
    s_len, n_batch, emb_dim = emb.size()

    if self.encoder_type == "mean":
        # No RNN, just take mean as final state.
        mean = emb.mean(0).expand(self.num_layers, n_batch, emb_dim)
        return (mean, mean), emb

    elif self.encoder_type == "transformer":
        # Self-attention transformer.
        out = emb.transpose(0, 1).contiguous()
        words = input[:, :, 0].transpose(0, 1)
        # CHECKS
        out_batch, out_len, _ = out.size()
        w_batch, w_len = words.size()
        aeq(out_batch, w_batch)
        aeq(out_len, w_len)
        # END CHECKS

        # Make mask.
        padding_idx = self.embeddings.padding_idx
        mask = words.data.eq(padding_idx).unsqueeze(1) \
            .expand(w_batch, w_len, w_len)

        # Run the forward pass of every layer of the transformer.
        for i in range(self.num_layers):
            out = self.transformer[i](out, mask)
        return Variable(emb.data), out.transpose(0, 1).contiguous()

    elif self.encoder_type == "cnn":
        out = emb.transpose(0, 1).contiguous()
        out, emb_remap = self.cnn(out)
        return emb_remap.transpose(0, 1).contiguous(), \
            out.transpose(0, 1).contiguous()

    else:
        # Standard RNN encoder.
        packed_emb = emb
        if lengths is not None:
            # Lengths data is wrapped inside a Variable.
            lengths = lengths.view(-1).tolist()
            packed_emb = pack(emb, lengths)
        outputs, hidden_t = self.rnn(packed_emb, hidden)
        if lengths:
            outputs = unpack(outputs)[0]
        return hidden_t, outputs

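# Self-contained sketch of the encoder-side self-attention mask built above
# (padding index and sizes are illustrative): positions equal to the pad id
# are marked True, and the (batch x 1 x src_len) row mask is broadcast to
# (batch x src_len x src_len), so every query position ignores padded keys.
import torch

padding_idx = 1
words = torch.LongTensor([[4, 5, 6, 1, 1],
                          [7, 8, 1, 1, 1]])          # batch x src_len
w_batch, w_len = words.size()
mask = words.eq(padding_idx).unsqueeze(1).expand(w_batch, w_len, w_len)
# mask[b, q, k] is True wherever source token k in example b is padding.
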
def forward(self, input, src, context, state, fertility_vals=None,
            fert_dict=None, fert_sents=None, upper_bounds=None, test=False):
    """
    Forward through the decoder.

    Args:
        input (LongTensor): (len x batch) -- Input tokens
        src (LongTensor)
        context: (src_len x batch x rnn_size) -- Memory bank
        state: an object initializing the decoder.

    Returns:
        outputs: (len x batch x rnn_size)
        final_states: an object of the same form as above
        attns: Dictionary of (src_len x batch)
    """
    # CHECKS
    t_len, n_batch = input.size()
    s_len, n_batch_, _ = src.size()
    s_len_, n_batch__, _ = context.size()
    aeq(n_batch, n_batch_, n_batch__)
    # aeq(s_len, s_len_)
    # END CHECKS

    emb = self.embeddings(input.unsqueeze(2))

    # n.b. you can increase performance if you compute W_ih * x for all
    # iterations in parallel, but that's only possible if
    # self.input_feed=False
    outputs = []

    # Setup the different types of attention.
    attns = {"std": []}
    if self._copy:
        attns["copy"] = []
    if self._coverage:
        attns["coverage"] = []
    if self.exhaustion_loss:
        attns["upper_bounds"] = []

    assert isinstance(state, RNNDecoderState)
    output = state.input_feed.squeeze(0)
    hidden = state.hidden
    # CHECKS
    n_batch_, _ = output.size()
    aeq(n_batch, n_batch_)
    # END CHECKS

    coverage = state.coverage.squeeze(0) \
        if state.coverage is not None else None

    for i, emb_t in enumerate(emb.split(1)):
        # Initialize upper bounds for the current batch
        if upper_bounds is None:
            upper_bounds = Variable(
                torch.Tensor([self.fertility]).repeat(n_batch_, s_len_)).cuda()
            # Use <SINK> token for absorbing remaining attention weight
            upper_bounds[:, -1] = Variable(
                100. * torch.ones(upper_bounds.size(0)))

        emb_t = emb_t.squeeze(0)
        if self.input_feed:
            emb_t = torch.cat([emb_t, output], 1)

        rnn_output, hidden = self.rnn(emb_t, hidden)
        attn_output, attn = self.attn(rnn_output,
                                      context.transpose(0, 1),
                                      upper_bounds=upper_bounds)
        upper_bounds -= attn

        if self.context_gate is not None:
            output = self.context_gate(emb_t, rnn_output, attn_output)
            output = self.dropout(output)
        else:
            output = self.dropout(attn_output)
        outputs += [output]
        attns["std"] += [attn]

        # COVERAGE
        if self._coverage:
            coverage = (coverage + attn) if coverage is not None else attn
            attns["coverage"] += [coverage]

        # COPY
        if self._copy:
            _, copy_attn = self.copy_attn(output, context.transpose(0, 1))
            attns["copy"] += [copy_attn]
        if self.exhaustion_loss:
            attns["upper_bounds"] += [upper_bounds]

    state = RNNDecoderState(
        hidden, output.unsqueeze(0),
        coverage.unsqueeze(0) if coverage is not None else None,
        upper_bounds)
    outputs = torch.stack(outputs)
    for k in attns:
        attns[k] = torch.stack(attns[k])
    return outputs, state, attns, upper_bounds

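# Standalone sketch of the upper-bound bookkeeping used above (illustrative
# numbers; the real constrained softmax/sparsemax transform lives in
# self.attn): each source position starts with `fertility` units of attention
# budget, the last column (<SINK>) gets an effectively unlimited budget, and
# every decoding step spends whatever attention mass it assigned.
import torch
import torch.nn.functional as F

batch, src_len, fertility = 2, 5, 2.0
upper_bounds = fertility * torch.ones(batch, src_len)
upper_bounds[:, -1] = 100.                             # <SINK> absorbs leftover mass
attn = F.softmax(torch.randn(batch, src_len), dim=1)   # stand-in for one step's attention
upper_bounds = upper_bounds - attn                     # spend this step's budget
assert (upper_bounds[:, :-1] <= fertility).all()
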
def forward(self, input, context, coverage=None):
    """
    input (FloatTensor): batch x dim: decoder's rnn's output.
    context (FloatTensor): batch x src_len x dim: src hidden states
    coverage (FloatTensor): batch x src_len
    """
    # Check input sizes
    batch, sourceL, dim = context.size()
    batch_, dim_ = input.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, sourceL_ = coverage.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
    if self.mask is not None:
        beam_, batch_, sourceL_ = self.mask.size()
        aeq(batch, batch_ * beam_)
        aeq(sourceL, sourceL_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        context += self.linear_cover(cover).view_as(context)
        context = self.tanh(context)

    # compute attention scores, as in Luong et al.
    a_t = self.score(input, context)

    if self.mask is not None:
        a_t.data.masked_fill_(self.mask, -float('inf'))

    # Softmax to normalize attention weights
    align_vector = self.sm(a_t)

    # the context vector c_t is the weighted average
    # over all the source hidden states
    c_t = torch.bmm(align_vector.unsqueeze(1), context).squeeze(1)

    # concatenate
    attn_h_t = self.linear_out(torch.cat([c_t, input], 1))
    if self.attn_type in ["general", "dot"]:
        attn_h_t = self.tanh(attn_h_t)

    # Check output sizes
    batch_, sourceL_ = align_vector.size()
    aeq(batch, batch_)
    aeq(sourceL, sourceL_)
    batch_, dim_ = attn_h_t.size()
    aeq(batch, batch_)
    aeq(dim, dim_)

    return attn_h_t, align_vector

def forward(self, input, context, coverage=None):
    """
    input (FloatTensor): batch x dim
    context (FloatTensor): batch x sourceL x dim
    coverage (FloatTensor): batch x sourceL
    """
    # Check input sizes
    batch, sourceL, dim = context.size()
    batch_, dim_ = input.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, sourceL_ = coverage.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
    if self.mask is not None:
        beam_, batch_, sourceL_ = self.mask.size()
        aeq(batch, batch_ * beam_)
        aeq(sourceL, sourceL_)

    if coverage is not None:
        context += self.linear_cover(coverage.view(-1).unsqueeze(1)) \
            .view_as(context)
        context = self.tanh(context)

    # Alignment/Attention Function
    if self.attn_type == "dotprod":
        # batch x dim x 1
        targetT = self.linear_in(input).unsqueeze(2)
        # batch x sourceL
        attn = torch.bmm(context, targetT).squeeze(2)
    elif self.attn_type == "mlp":
        # batch x 1 x dim
        wq = self.linear_query(input).unsqueeze(1)
        # batch x sourceL x dim
        uh = self.linear_context(context.contiguous())
        # batch x sourceL x dim
        wquh = uh + wq.expand_as(uh)
        # batch x sourceL x dim
        wquh = self.mlp_tanh(wquh)
        # batch x sourceL
        attn = self.v(wquh.contiguous()).squeeze(2)

    if self.mask is not None:
        attn.data.masked_fill_(self.mask, -float('inf'))

    # SoftMax
    attn = self.sm(attn)

    # Compute context weighted by attention.
    # batch x 1 x sourceL
    attn3 = attn.view(attn.size(0), 1, attn.size(1))
    # batch x dim
    weightedContext = torch.bmm(attn3, context).squeeze(1)

    # Concatenate the input to context (Luong only)
    weightedContext = torch.cat((weightedContext, input), 1)
    weightedContext = self.linear_out(weightedContext)
    # if self.attn_type == "dotprod":
    weightedContext = self.tanh(weightedContext)

    # Check output sizes
    batch_, sourceL_ = attn.size()
    aeq(batch, batch_)
    aeq(sourceL, sourceL_)
    batch_, dim_ = weightedContext.size()
    aeq(batch, batch_)
    aeq(dim, dim_)

    return weightedContext, attn

def forward(self, input, context, coverage=None):
    """
    input (FloatTensor): batch x tgt_len x dim: decoder's rnn's output.
    context (FloatTensor): batch x src_len x dim: src hidden states
    coverage (FloatTensor): None (not supported yet)
    """
    # one step input
    if input.dim() == 2:
        one_step = True
        input = input.unsqueeze(1)
    else:
        one_step = False

    batch, sourceL, dim = context.size()
    batch_, targetL, dim_ = input.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, sourceL_ = coverage.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
    if self.mask is not None:
        beam_, batch_, sourceL_ = self.mask.size()
        aeq(batch, batch_ * beam_)
        aeq(sourceL, sourceL_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        context += self.linear_cover(cover).view_as(context)
        context = self.tanh(context)

    # compute attention scores, as in Luong et al.
    align = self.score(input, context)

    if self.mask is not None:
        mask_ = self.mask.view(batch, 1, sourceL)  # make it broadcastable
        align.data.masked_fill_(mask_, -float('inf'))

    # Softmax to normalize attention weights
    align_vectors = self.sm(align.view(batch * targetL, sourceL))
    align_vectors = align_vectors.view(batch, targetL, sourceL)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, context)

    # concatenate
    concat_c = torch.cat([c, input], 2).view(batch * targetL, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = self.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)

        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, sourceL_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()

        # Check output sizes
        targetL_, batch_, dim_ = attn_h.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        targetL_, batch_, sourceL_ = align_vectors.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

    return attn_h, align_vectors

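# Shape sketch for the normalization step above (illustrative sizes): scores of
# shape (batch, tgt_len, src_len) are flattened to (batch*tgt_len, src_len) so a
# 2-D softmax normalizes each target position over the source, then reshaped
# back before the bmm with the memory bank yields one context vector per step.
import torch
import torch.nn.functional as F

batch, tgt_len, src_len, dim = 2, 3, 5, 8
align = torch.randn(batch, tgt_len, src_len)
context = torch.randn(batch, src_len, dim)
align_vectors = F.softmax(align.view(batch * tgt_len, src_len), dim=-1) \
    .view(batch, tgt_len, src_len)
c = torch.bmm(align_vectors, context)       # batch x tgt_len x dim
assert c.size() == (batch, tgt_len, dim)
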
def forward(self, input, context, src_pad_mask, tgt_pad_mask):
    # Args Checks
    input_batch, input_len, _ = input.size()
    contxt_batch, contxt_len, _ = context.size()
    aeq(input_batch, contxt_batch)

    src_batch, t_len, s_len = src_pad_mask.size()
    tgt_batch, t_len_, t_len__ = tgt_pad_mask.size()
    aeq(input_batch, contxt_batch, src_batch, tgt_batch)
    aeq(t_len, t_len_, t_len__, input_len)
    aeq(s_len, contxt_len)
    # END Args Checks

    dec_mask = torch.gt(
        tgt_pad_mask
        + self.mask[:, :tgt_pad_mask.size(1), :tgt_pad_mask.size(1)]
        .expand_as(tgt_pad_mask),
        0)
    query, attn = self.self_attn(input, input, input, mask=dec_mask)
    mid, attn = self.context_attn(context, context, query, mask=src_pad_mask)
    output = self.feed_forward(mid)

    # CHECKS
    output_batch, output_len, _ = output.size()
    aeq(input_len, output_len)
    aeq(contxt_batch, output_batch)

    n_batch_, t_len_, s_len_ = attn.size()
    aeq(input_batch, n_batch_)
    aeq(contxt_len, s_len_)
    aeq(input_len, t_len_)
    # END CHECKS

    return output, attn

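# Standalone sketch of the decoder self-attention mask combined above
# (illustrative sizes): the padding mask hides pad tokens, an upper-triangular
# "subsequent" mask hides future positions, and torch.gt(sum, 0) marks a
# position as masked if either source says so.
import torch

tgt_len = 4
tgt_pad_mask = torch.ByteTensor([[[0, 0, 0, 1]]]).expand(1, tgt_len, tgt_len)
subsequent = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1) \
    .byte().unsqueeze(0)                      # 1 x tgt_len x tgt_len
dec_mask = torch.gt(tgt_pad_mask + subsequent, 0)
# dec_mask[0, q, k] is True if key k is padding or lies after query q.
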
def _run_forward_pass(self, input, context, state):
    """
    Private helper for running the specific RNN forward pass.
    Must be overridden by all subclasses.

    Args:
        input (LongTensor): a sequence of input tokens tensors
                            of size (len x batch x nfeats).
        context (FloatTensor): output(tensor sequence) from the Encoder
                               RNN of size (src_len x batch x hidden_size).
        state (FloatTensor): hidden state from the Encoder RNN for
                             initializing the decoder.

    Returns:
        hidden (FloatTensor): final hidden state from the Decoder.
        outputs ([FloatTensor]): an array of output of every time
                                 step from the Decoder.
        attns (dict of (str, [FloatTensor]): a dictionary of different
                            type of attention Tensor array of every time
                            step from the Decoder.
        coverage (FloatTensor, optional): coverage from the Decoder.
    """
    assert not self._copy  # TODO, no support yet.
    assert not self._coverage  # TODO, no support yet.

    # Initialize local and return variables.
    outputs = []
    attns = {"std": []}
    coverage = None

    emb = self.embeddings(input)
    assert emb.dim() == 3  # len x batch x embedding_dim

    # Run the forward pass of the RNN.
    rnn_output, hidden = self.rnn(emb, state.hidden)

    # Result Check
    input_len, input_batch, _ = input.size()
    output_len, output_batch, _ = rnn_output.size()
    aeq(input_len, output_len)
    aeq(input_batch, output_batch)
    # END Result Check

    # Calculate the attention.
    attn_outputs, attn_scores = self.attn(
        rnn_output.transpose(0, 1).contiguous(),  # (batch, output_len, d)
        context.transpose(0, 1)                   # (batch, contxt_len, d)
    )
    attns["std"] = attn_scores

    # Calculate the context gate.
    if self.context_gate is not None:
        outputs = self.context_gate(
            emb.view(-1, emb.size(2)),
            rnn_output.view(-1, rnn_output.size(2)),
            attn_outputs.view(-1, attn_outputs.size(2)))
        outputs = outputs.view(input_len, input_batch, self.hidden_size)
        outputs = self.dropout(outputs)
    else:
        outputs = self.dropout(attn_outputs)  # (input_len, batch, d)

    # Return result.
    return hidden, outputs, attns, coverage

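# `self.context_gate` above is a separate module not shown in this section.
# A minimal hypothetical sketch of one common formulation (gate between the
# attentional "source" signal and the decoder-side "target" signal); the class
# name, projections, and gating below are assumptions for illustration, not
# the repository's actual ContextGate implementation.
import torch
import torch.nn as nn

class SimpleContextGate(nn.Module):
    def __init__(self, emb_size, rnn_size, attn_size, out_size):
        super(SimpleContextGate, self).__init__()
        self.gate = nn.Linear(emb_size + rnn_size + attn_size, out_size)
        self.proj_source = nn.Linear(attn_size, out_size)
        self.proj_target = nn.Linear(emb_size + rnn_size, out_size)

    def forward(self, emb, rnn_output, attn_output):
        # z in (0, 1) decides how much of the source context vs. the
        # decoder state flows into the output.
        z = torch.sigmoid(
            self.gate(torch.cat([emb, rnn_output, attn_output], 1)))
        source = self.proj_source(attn_output)
        target = self.proj_target(torch.cat([emb, rnn_output], 1))
        return z * source + (1. - z) * target
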
def forward(self, input, lengths=None, hidden=None):
    """
    Args:
        input (LongTensor): len x batch x nfeat
        lengths (LongTensor): batch
        hidden: Initial hidden state.

    Returns:
        hidden_t (FloatTensor): Pair of layers x batch x rnn_size - final
                                Encoder state
        outputs (FloatTensor): len x batch x rnn_size - Memory bank
    """
    # CHECKS
    s_len, n_batch, n_feats = input.size()
    if lengths is not None:
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)
    # END CHECKS

    emb = self.embeddings(input)
    s_len, n_batch, vec_size = emb.size()

    if self.encoder_layer == "mean":
        # No RNN, just take mean as final state.
        mean = emb.mean(0) \
               .expand(self.layers, n_batch, vec_size)
        return (mean, mean), emb

    elif self.encoder_layer == "transformer":
        # Self-attention transformer.
        out = emb.transpose(0, 1).contiguous()
        for i in range(self.layers):
            out = self.transformer[i](out, input[:, :, 0].transpose(0, 1))
        return Variable(emb.data), out.transpose(0, 1).contiguous()

    else:
        # import pdb; pdb.set_trace()
        # Standard RNN encoder.
        packed_emb = emb
        if lengths is not None:
            # Lengths data is wrapped inside a Variable.
            lengths = lengths.data.view(-1).tolist()
            packed_emb = pack(emb, lengths)
        outputs, hidden_t = self.rnn(packed_emb, hidden)
        if lengths:
            outputs = unpack(outputs)[0]

        if self.predict_fertility:
            if self.use_sigmoid_fertility:
                fertility_vals = self.fertility * F.sigmoid(
                    self.fertility_out(
                        torch.cat([
                            outputs.view(
                                -1, self.hidden_size * self.num_directions),
                            emb.view(-1, vec_size)
                        ], dim=1)))
            else:
                fertility_vals = F.relu(
                    self.fertility_linear(
                        torch.cat([
                            outputs.view(
                                -1, self.hidden_size * self.num_directions),
                            emb.view(-1, vec_size)
                        ], dim=1)))
                fertility_vals = F.relu(
                    self.fertility_linear_2(fertility_vals))
                fertility_vals = 1 + torch.exp(
                    self.fertility_out(fertility_vals))
            fertility_vals = fertility_vals.view(n_batch, s_len)
            # fertility_vals = fertility_vals / torch.sum(fertility_vals, 1).repeat(1, s_len) * s_len
        elif self.guided_fertility:
            fertility_vals = None  # evaluation.get_fertility()
        elif self.supervised_fertility:
            fertility_vals = F.relu(
                self.sup_linear(
                    outputs.view(-1,
                                 self.hidden_size * self.num_directions)))
            fertility_vals = F.relu(self.sup_linear_2(fertility_vals))
            fertility_vals = 1 + torch.exp(fertility_vals)
        else:
            fertility_vals = None

        return hidden_t, outputs, fertility_vals

def forward(self, key, value, query, mask=None):
    # CHECKS
    batch, k_len, d = key.size()
    batch_, k_len_, d_ = value.size()
    aeq(batch, batch_)
    aeq(k_len, k_len_)
    aeq(d, d_)
    batch_, q_len, d_ = query.size()
    aeq(batch, batch_)
    aeq(d, d_)
    aeq(self.d_model % 8, 0)
    if mask is not None:
        batch_, q_len_, k_len_ = mask.size()
        aeq(batch_, batch)
        aeq(k_len_, k_len)
        aeq(q_len_, q_len)
    # END CHECKS

    def shape_projection(x):
        b, l, d = x.size()
        return x.view(b, l, self.heads, self.d_k) \
            .transpose(1, 2).contiguous() \
            .view(b * self.heads, l, self.d_k)

    def unshape_projection(x, q):
        b, l, d = q.size()
        return x.view(b, self.heads, l, self.d_k) \
            .transpose(1, 2).contiguous() \
            .view(b, l, self.heads * self.d_k)

    residual = query
    key_up = shape_projection(self.linear_keys(key))
    value_up = shape_projection(self.linear_values(value))
    query_up = shape_projection(self.linear_query(query))

    scaled = torch.bmm(query_up, key_up.transpose(1, 2))
    scaled = scaled / math.sqrt(self.d_k)
    bh, l, d_k = scaled.size()
    b = bh // self.heads
    if mask is not None:
        scaled = scaled.view(b, self.heads, l, d_k)
        mask = mask.unsqueeze(1).expand_as(scaled)
        scaled = scaled.masked_fill(Variable(mask), -float('inf')) \
            .view(bh, l, d_k)
    attn = self.sm(scaled)

    # Return one attn
    top_attn = attn.view(b, self.heads, l, d_k)[:, 0, :, :].contiguous()
    drop_attn = self.dropout(self.sm(scaled))

    # values : (batch * 8) x qlen x dim
    out = unshape_projection(torch.bmm(drop_attn, value_up), residual)

    # Residual and layer norm
    res = self.res_dropout(out) + residual
    ret = self.layer_norm(res)

    # CHECK
    batch_, q_len_, d_ = ret.size()
    aeq(q_len, q_len_)
    aeq(batch, batch_)
    aeq(d, d_)
    # END CHECK
    return ret, top_attn

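# Standalone shape sketch of the head splitting/merging performed by
# shape_projection / unshape_projection above (illustrative sizes):
# (batch, len, heads*d_k) is folded into (batch*heads, len, d_k) so a single
# bmm scores all heads at once, then folded back after attending.
import torch

batch, length, heads, d_k = 2, 6, 8, 64
x = torch.randn(batch, length, heads * d_k)
up = x.view(batch, length, heads, d_k).transpose(1, 2) \
    .contiguous().view(batch * heads, length, d_k)
down = up.view(batch, heads, length, d_k).transpose(1, 2) \
    .contiguous().view(batch, length, heads * d_k)
assert torch.equal(x, down)   # splitting then merging heads is lossless
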
def forward(self, input, context, state):
    """
    Forward through the TransformerDecoder.

    Args:
        input (LongTensor): a sequence of input tokens tensors
                            of size (len x batch x nfeats).
        context (FloatTensor): output(tensor sequence) from the Encoder
                               RNN of size (src_len x batch x hidden_size).
        state (FloatTensor): hidden state from the Encoder RNN for
                             initializing the decoder.

    Returns:
        outputs (FloatTensor): a Tensor sequence of output from the Decoder
                               of shape (len x batch x hidden_size).
        state (FloatTensor): final hidden state from the Decoder.
        attns (dict of (str, FloatTensor)): a dictionary of different
                            type of attention Tensor from the Decoder
                            of shape (src_len x batch).
    """
    # CHECKS
    assert isinstance(state, TransformerDecoderState)
    input_len, input_batch, _ = input.size()
    contxt_len, contxt_batch, _ = context.size()
    aeq(input_batch, contxt_batch)

    if state.previous_input is not None:
        input = torch.cat([state.previous_input, input], 0)

    src = state.src
    src_words = src[:, :, 0].transpose(0, 1)
    tgt_words = input[:, :, 0].transpose(0, 1)
    src_batch, src_len = src_words.size()
    tgt_batch, tgt_len = tgt_words.size()
    aeq(input_batch, contxt_batch, src_batch, tgt_batch)
    aeq(contxt_len, src_len)
    aeq(input_len, tgt_len)
    # END CHECKS

    # Initialize return variables.
    outputs = []
    attns = {"std": []}
    if self._copy:
        attns["copy"] = []

    # Run the forward pass of the TransformerDecoder.
    emb = self.embeddings(input)
    output = emb.transpose(0, 1).contiguous()
    src_context = context.transpose(0, 1).contiguous()

    padding_idx = self.embeddings.padding_idx
    src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
        .expand(src_batch, tgt_len, src_len)
    tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \
        .expand(tgt_batch, tgt_len, tgt_len)

    for i in range(self.num_layers):
        output, attn \
            = self.transformer[i](output, src_context,
                                  src_pad_mask, tgt_pad_mask)

    # Process the result and update the attentions.
    outputs = output.transpose(0, 1).contiguous()
    if state.previous_input is not None:
        outputs = outputs[state.previous_input.size(0):]
        attn = attn[:, state.previous_input.size(0):].squeeze()
        attn = torch.stack([attn])
    attns["std"] = attn
    if self._copy:
        attns["copy"] = attn

    # Update the TransformerDecoderState.
    state = TransformerDecoderState(src, input)

    return outputs, state, attns

def forward(self, input, src, context, state):
    """
    Forward through the decoder.

    Args:
        input (LongTensor): (len x batch) -- Input tokens
        src (LongTensor)
        context: (src_len x batch x rnn_size) -- Memory bank
        state: an object initializing the decoder.

    Returns:
        outputs: (len x batch x rnn_size)
        final_states: an object of the same form as above
        attns: Dictionary of (src_len x batch)
    """
    # CHECKS
    t_len, n_batch = input.size()
    s_len, n_batch_, _ = src.size()
    s_len_, n_batch__, _ = context.size()
    aeq(n_batch, n_batch_, n_batch__)
    # aeq(s_len, s_len_)
    # END CHECKS

    emb = self.embeddings(input.unsqueeze(2))

    # n.b. you can increase performance if you compute W_ih * x for all
    # iterations in parallel, but that's only possible if
    # self.input_feed=False
    outputs = []

    # Setup the different types of attention.
    attns = {"std": []}
    if self._coverage:
        attns["coverage"] = []

    assert isinstance(state, RNNDecoderState)
    output = state.input_feed.squeeze(0)
    hidden = state.hidden
    # CHECKS
    n_batch_, _ = output.size()
    aeq(n_batch, n_batch_)
    # END CHECKS

    coverage = state.coverage.squeeze(0) \
        if state.coverage is not None else None

    # Standard RNN decoder.
    for i, emb_t in enumerate(emb.split(1)):
        emb_t = emb_t.squeeze(0)
        if self.input_feed:
            emb_t = torch.cat([emb_t, output], 1)

        rnn_output, hidden = self.rnn(emb_t, hidden)
        attn_output, attn = self.attn(rnn_output, context.transpose(0, 1))
        output = self.dropout(attn_output)
        outputs += [output]
        attns["std"] += [attn]

        # COVERAGE
        if self._coverage:
            coverage = (coverage + attn) if coverage is not None else attn
            attns["coverage"] += [coverage]

    state = RNNDecoderState(
        hidden, output.unsqueeze(0),
        coverage.unsqueeze(0) if coverage is not None else None)
    outputs = torch.stack(outputs)
    for k in attns:
        attns[k] = torch.stack(attns[k])

    return outputs, state, attns

def forward(self, input, src, context, state):
    """
    Forward through the decoder.

    Args:
        input (LongTensor): (len x batch) -- Input tokens
        src (LongTensor)
        context: (src_len x batch x rnn_size) -- Memory bank
        state: an object initializing the decoder.

    Returns:
        outputs: (len x batch x rnn_size)
        final_states: an object of the same form as above
        attns: Dictionary of (src_len x batch)
    """
    # CHECKS
    t_len, n_batch = input.size()
    s_len, n_batch_, _ = src.size()
    s_len_, n_batch__, _ = context.size()
    aeq(n_batch, n_batch_, n_batch__)
    # aeq(s_len, s_len_)
    # END CHECKS

    if self.decoder_type == "transformer":
        if state.previous_input:
            input = torch.cat([state.previous_input.squeeze(2), input], 0)

    emb = self.embeddings(input.unsqueeze(2))

    # n.b. you can increase performance if you compute W_ih * x for all
    # iterations in parallel, but that's only possible if
    # self.input_feed=False
    outputs = []

    # Setup the different types of attention.
    attns = {"std": []}
    if self._copy:
        attns["copy"] = []
    if self._coverage:
        attns["coverage"] = []

    if self.decoder_type == "transformer":
        # Transformer Decoder.
        assert isinstance(state, TransformerDecoderState)
        output = emb.transpose(0, 1).contiguous()
        src_context = context.transpose(0, 1).contiguous()
        for i in range(self.layers):
            output, attn \
                = self.transformer[i](output, src_context,
                                      src[:, :, 0].transpose(0, 1),
                                      input.transpose(0, 1))
        outputs = output.transpose(0, 1).contiguous()
        if state.previous_input:
            outputs = outputs[state.previous_input.size(0):]
            attn = attn[:, state.previous_input.size(0):].squeeze()
            attn = torch.stack([attn])
        attns["std"] = attn
        if self._copy:
            attns["copy"] = attn
        state = TransformerDecoderState(input.unsqueeze(2))
    elif self.input_feed:
        assert isinstance(state, RNNDecoderState)
        output = state.input_feed.squeeze(0)
        hidden = state.hidden
        # CHECKS
        n_batch_, _ = output.size()
        aeq(n_batch, n_batch_)
        # END CHECKS

        coverage = state.coverage.squeeze(0) \
            if state.coverage is not None else None

        # Standard RNN decoder.
        for i, emb_t in enumerate(emb.split(1)):
            emb_t = emb_t.squeeze(0)
            if self.input_feed:
                emb_t = torch.cat([emb_t, output], 1)

            rnn_output, hidden = self.rnn(emb_t, hidden)
            attn_output, attn = self.attn(rnn_output,
                                          context.transpose(0, 1))
            if self.context_gate is not None:
                output = self.context_gate(emb_t, rnn_output, attn_output)
                output = self.dropout(output)
            else:
                output = self.dropout(attn_output)
            outputs += [output]
            attns["std"] += [attn]

            # COVERAGE
            if self._coverage:
                coverage = coverage + attn \
                    if coverage is not None else attn
                attns["coverage"] += [coverage]

            # COPY
            if self._copy:
                _, copy_attn = self.copy_attn(output,
                                              context.transpose(0, 1))
                attns["copy"] += [copy_attn]

        state = RNNDecoderState(
            hidden, output.unsqueeze(0),
            coverage.unsqueeze(0) if coverage is not None else None)
        outputs = torch.stack(outputs)
        for k in attns:
            attns[k] = torch.stack(attns[k])
    else:
        assert isinstance(state, RNNDecoderState)
        assert emb.dim() == 3
        assert not self._coverage
        assert state.coverage is None

        # TODO: copy
        assert not self._copy

        hidden = state.hidden
        rnn_output, hidden = self.rnn(emb, hidden)

        # CHECKS
        t_len_, n_batch_, _ = rnn_output.size()
        aeq(n_batch, n_batch_)
        aeq(t_len, t_len_)
        # END CHECKS

        attn_outputs, attn_scores = self.attn(
            rnn_output.transpose(0, 1).contiguous(),  # (batch, t_len, d)
            context.transpose(0, 1)                   # (batch, s_len, d)
        )

        if self.context_gate is not None:
            outputs = self.context_gate(
                emb.view(-1, emb.size(2)),
                rnn_output.view(-1, rnn_output.size(2)),
                attn_outputs.view(-1, attn_outputs.size(2)))
            outputs = outputs.view(t_len, n_batch, self.hidden_size)
            outputs = self.dropout(outputs)
        else:
            outputs = self.dropout(attn_outputs)  # (t_len, batch, d)

        state = RNNDecoderState(hidden, outputs[-1].unsqueeze(0), None)
        attns["std"] = attn_scores

    return outputs, state, attns

def forward(self, input, context, coverage=None, upper_bounds=None):
    """
    input (FloatTensor): batch x dim
    context (FloatTensor): batch x sourceL x dim
    coverage (FloatTensor): batch x sourceL
    upper_bounds (FloatTensor): batch x sourceL
    """
    # Check input sizes
    batch, sourceL, dim = context.size()
    batch_, dim_ = input.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, sourceL_ = coverage.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
    if self.mask is not None:
        beam_, batch_, sourceL_ = self.mask.size()
        aeq(batch, batch_ * beam_)
        aeq(sourceL, sourceL_)

    if coverage is not None:
        context += self.linear_cover(coverage.view(-1).unsqueeze(1)) \
            .view_as(context)
        context = self.tanh(context)

    # Alignment/Attention Function
    if self.attn_type == "dotprod":
        # batch x dim x 1
        targetT = self.linear_in(input).unsqueeze(2)
        # batch x sourceL
        attn = torch.bmm(context, targetT).squeeze(2)
    elif self.attn_type == "mlp":
        # batch x 1 x dim
        wq = self.linear_query(input).unsqueeze(1)
        # batch x sourceL x dim
        uh = self.linear_context(context.contiguous())
        # batch x sourceL x dim
        wquh = uh + wq.expand_as(uh)
        # batch x sourceL x dim
        wquh = self.tanh(wquh)
        # batch x sourceL
        # print("self.v: ", self.v.weight)
        attn = self.v(wquh.contiguous()).squeeze(2)

    # EXPERIMENTAL
    if upper_bounds is not None and \
            'constrained' in self.attn_transform and self.c_attn != 0.0:
        indices = torch.arange(0, upper_bounds.size(1) - 1).cuda().long()
        uu = torch.index_select(upper_bounds.data, 1, indices)
        attn = attn + self.c_attn * Variable(
            torch.cat((uu, torch.zeros(upper_bounds.size(0), 1).cuda()), 1))

    if self.mask is not None:
        attn.data.masked_fill_(self.mask, -float('inf'))

    if self.attn_transform == 'constrained_softmax':
        if upper_bounds is None:
            attn = nn.Softmax()(attn)
        else:
            # assert round(np.sum(upper_bounds.cpu().data.numpy()), 5) >= 1.0, pdb.set_trace()
            attn = self.sm(attn, upper_bounds)
    elif self.attn_transform == 'constrained_sparsemax':
        if upper_bounds is None:
            attn = Sparsemax()(attn)
        else:
            attn = self.sm(attn, upper_bounds)
    else:
        attn = self.sm(attn)
    # if upper_bounds is None:
    #     attn = self.sm(attn)
    # else:
    #     attn = self.sm(attn - upper_bounds)

    # Compute context weighted by attention.
    # batch x 1 x sourceL
    attn3 = attn.view(attn.size(0), 1, attn.size(1))
    # batch x dim
    weightedContext = torch.bmm(attn3, context).squeeze(1)

    # Concatenate the input to context (Luong only)
    if self.attn_type == "dotprod":
        weightedContext = torch.cat((weightedContext, input), 1)
        weightedContext = self.linear_out(weightedContext)
        weightedContext = self.tanh(weightedContext)

    # Check output sizes
    batch_, sourceL_ = attn.size()
    aeq(batch, batch_)
    aeq(sourceL, sourceL_)
    batch_, dim_ = weightedContext.size()
    aeq(batch, batch_)
    aeq(dim, dim_)

    return weightedContext, attn