def _run_forward_pass(self, input, context, state, context_lengths=None):
    """
    Input-feed forward pass: the RNN is stepped once per slice of the
    embedded input along dim 0, and each step's RNN input is that slice
    concatenated (dim 1) with the previous step's output. The first step
    uses `state.input_feed` as the previous output.

    See StdRNNDecoder._run_forward_pass() for description
    of arguments and return values.

    Returns:
        (hidden, outputs, attns, coverage):

        * hidden: the RNN state after the last step.
        * outputs: list with one (dropout-applied) output per step.
        * attns: dict of per-step lists; always has `"std"`, plus
          `"copy"` / `"coverage"` when `self._copy` / `self._coverage`
          are set.
        * coverage: running sum of the step attentions when
          `self._coverage` is set, otherwise whatever the state carried
          (possibly None).
    """
    # The running input feed starts out as the one stored in the state.
    output = state.input_feed.squeeze(0)

    # Additional args check: the input feed must be 2-D, `input` 3-D, and
    # both must agree on the batch size.
    feed_batch, _ = output.size()
    _, tgt_batch, _ = input.size()
    aeq(tgt_batch, feed_batch)

    # Per-step results are collected as plain lists.
    outputs = []
    attns = {"std": []}
    if self._copy:
        attns["copy"] = []
    if self._coverage:
        attns["coverage"] = []

    emb = self.embeddings(input)
    assert emb.dim() == 3  # len x batch x embedding_dim

    hidden = state.hidden
    coverage = None
    if state.coverage is not None:
        coverage = state.coverage.squeeze(0)

    # The same transposed view of the encoder output is handed to the
    # attention layers at every step.
    src_context = context.transpose(0, 1)

    for step_emb in emb.split(1):
        # Input feed: embedding of this step + previous step's output.
        rnn_input = torch.cat([step_emb.squeeze(0), output], 1)

        rnn_output, hidden = self.rnn(rnn_input, hidden)
        attn_output, attn = self.attn(
            rnn_output,
            src_context,
            context_lengths=context_lengths)

        if self.context_gate is None:
            output = self.dropout(attn_output)
        else:
            # TODO: context gate should be employed
            # instead of second RNN transform.
            gated = self.context_gate(rnn_input, rnn_output, attn_output)
            output = self.dropout(gated)
        outputs.append(output)
        attns["std"].append(attn)

        # Accumulate the attention into the coverage vector.
        if self._coverage:
            coverage = attn if coverage is None else coverage + attn
            attns["coverage"].append(coverage)

        # Separate attention layer for the copy mechanism.
        if self._copy:
            _, copy_attn = self.copy_attn(output, src_context)
            attns["copy"].append(copy_attn)

    return hidden, outputs, attns, coverage
def _example_dict_iter(self, line):
    """
    Build the field dictionary for one raw text line.

    The line is whitespace-split, optionally cut to `self.line_truncate`
    tokens, and passed to `TextDataset.extract_text_features`.

    Args:
        line (str): one line of text.

    Returns:
        dict: maps `self.side` to the extracted words and `"indices"` to
        `self.line_index`; when features were extracted, also maps
        `<side>_feat_<j>` to the j-th feature sequence.
    """
    tokens = line.split()
    if self.line_truncate:
        tokens = tokens[:self.line_truncate]

    words, feats, n_feats = TextDataset.extract_text_features(tokens)
    example_dict = {self.side: words, "indices": self.line_index}
    if not feats:
        return example_dict

    # All examples must have same number of features.
    aeq(self.n_feats, n_feats)

    prefix = self.side + "_feat_"
    for j, feat in enumerate(feats):
        example_dict[prefix + str(j)] = feat
    return example_dict
def forward(self, input, context, state, context_lengths=None):
    """
    Run the decoder over `input` and refresh `state` with the result.

    Args:
        input (`LongTensor`): sequences of padded tokens
                            `[tgt_len x batch x nfeats]`.
        context (`FloatTensor`): vectors from the encoder
             `[src_len x batch x hidden]`.
        state (:obj:`onmt.Models.DecoderState`):
             decoder state object to initialize the decoder; must be an
             `RNNDecoderState`.
        context_lengths (`LongTensor`): the padded source lengths
            `[batch]`.
    Returns:
        (`FloatTensor`,:obj:`onmt.Models.DecoderState`,`FloatTensor`):
            * outputs: output from the decoder
                     `[tgt_len x batch x hidden]`.
            * state: the same state object, updated in place via
                     `update_state`.
            * attns: dict of attention tensors, each stacked over the
                    target steps `[tgt_len x batch x src_len]`.
    """
    # Args check: right state type, 3-D tensors, matching batch sizes.
    assert isinstance(state, RNNDecoderState)
    _, tgt_batch, _ = input.size()
    _, src_batch, _ = context.size()
    aeq(tgt_batch, src_batch)

    # The subclass-specific pass does the actual decoding.
    hidden, outputs, attns, coverage = self._run_forward_pass(
        input, context, state, context_lengths=context_lengths)

    # The last step's output becomes the next input feed; both it and the
    # coverage (if any) get a leading singleton dimension.
    input_feed = outputs[-1].unsqueeze(0)
    new_coverage = None if coverage is None else coverage.unsqueeze(0)
    state.update_state(hidden, input_feed, new_coverage)

    # Stack the per-step sequences along a new leading dimension. The
    # dict is updated in place (keys are unchanged while iterating).
    outputs = torch.stack(outputs)
    for name, per_step in attns.items():
        attns[name] = torch.stack(per_step)

    return outputs, state, attns
def coalesce_datasets(datasets):
    """
    Merge a sequence of dataset instances into the first one.

    The first dataset is extended in place with the `src_vocabs` and
    `examples` of every following dataset and is then returned; no new
    dataset object is created.

    Args:
        datasets: non-empty, sliceable sequence of dataset instances.

    Returns:
        The first element of `datasets`, after merging.
    """
    merged = datasets[0]
    for other in datasets[1:]:
        # `src_vocabs` is a list of `torchtext.vocab.Vocab`, one per
        # sentence; gather them all in a single list.
        merged.src_vocabs += other.src_vocabs

        # Every dataset must carry the same number of features.
        aeq(merged.n_src_feats, other.n_src_feats)
        aeq(merged.n_tgt_feats, other.n_tgt_feats)

        # `examples` is a list of `torchtext.data.Example`; gather them
        # all in a single list as well.
        merged.examples += other.examples

        # Fields are shared by all datasets, so nothing to merge there.

    return merged
def forward(self, hidden, attn, src_map):
    """
    Compute a distribution over the target dictionary
    extended by the dynamic dictionary implied by copying
    source words.

    Args:
       hidden (`FloatTensor`): hidden outputs `[batch*tlen, input_size]`
       attn (`FloatTensor`): attention over source positions for each
           row of `hidden`, `[batch*tlen, src_len]`
       src_map (`FloatTensor`):
         A sparse indicator matrix mapping each source word to
         its index in the "extended" vocab containing.
         `[src_len, batch, extra_words]`

    Returns:
       `FloatTensor`: generation probabilities and copy probabilities
       concatenated along dim 1.
    """
    # CHECKS: row counts of `hidden`/`attn` and source lengths of
    # `attn`/`src_map` must agree.
    n_rows, _ = hidden.size()
    n_rows_, src_len = attn.size()
    src_len_, batch, cvocab = src_map.size()
    aeq(n_rows, n_rows_)
    aeq(src_len, src_len_)

    # Generation distribution; the padding word is forced to -inf
    # before the softmax.
    logits = self.linear(hidden)
    pad_index = self.tgt_dict.stoi[PAD_WORD]
    logits[:, pad_index] = -float('inf')
    gen_prob = F.softmax(logits)

    # Probability of copying, p(z=1), for each row.
    p_copy = F.sigmoid(self.linear_copy(hidden))

    # Probability of not copying: p_{word}(w) * (1 - p(z))
    out_prob = gen_prob * (1 - p_copy.expand_as(gen_prob))

    # Copy mass: attention scaled by p(z), then mapped per batch element
    # through `src_map` onto the extra-words axis.
    weighted_attn = attn * p_copy.expand_as(attn)
    per_batch_attn = weighted_attn.view(-1, batch, src_len).transpose(0, 1)
    copy_prob = torch.bmm(per_batch_attn, src_map.transpose(0, 1))
    copy_prob = copy_prob.transpose(0, 1).contiguous().view(-1, cvocab)

    return torch.cat([out_prob, copy_prob], 1)
def forward(self, input):
    """
    Computes the embeddings for words and features.

    Args:
        input (`LongTensor`): index tensor `[len x batch x nfeat]`;
            `nfeat` must equal `len(self.emb_luts)`.
    Return:
        `FloatTensor`: word embeddings `[len x batch x embedding_size]`
    """
    src_len, src_batch, nfeat = input.size()
    aeq(nfeat, len(self.emb_luts))

    emb = self.make_embedding(input)

    # The embedding must keep length and batch, and end in the
    # configured embedding size.
    emb_len, emb_batch, emb_size = emb.size()
    aeq(src_len, emb_len)
    aeq(src_batch, emb_batch)
    aeq(emb_size, self.embedding_size)

    return emb
def forward(self, base_target_emb, input, encoder_out_top,
            encoder_out_combine):
    """
    Attend over the encoder outputs for every decoder position.

    Args:
        base_target_emb: target emb tensor (4-D)
        input: output of decode conv (4-D, same batch and height as
            `base_target_emb`)
        encoder_out_top: the key matrix for calculation of attention
            weight, which is the top output of encode conv (3-D)
        encoder_out_combine:
            the value matrix for the attention-weighted sum,
            which is the combination of base emb and top output of encode
            (3-D, same first and last sizes as `encoder_out_top`)

    Returns:
        (context_output, attn): the attention-weighted sum of
        `encoder_out_combine`, reshaped back to 4-D, and the attention
        weights.
    """
    # checks
    batch, _, height, _ = base_target_emb.size()
    batch_, _, height_, _ = input.size()
    aeq(batch, batch_)
    aeq(height, height_)

    enc_batch, _, enc_height = encoder_out_top.size()
    enc_batch_, _, enc_height_ = encoder_out_combine.size()
    aeq(enc_batch, enc_batch_)
    aeq(enc_height, enc_height_)

    # Query: projected conv output added to the target embedding, scaled.
    query = (base_target_emb
             + seq_linear(self.linear_in, input)) * SCALE_WEIGHT
    # Drop dim 3 (presumably of size 1 -- confirm with caller) and swap
    # dims 1 and 2 so the query can be batch-multiplied with the keys.
    query = query.squeeze(3).transpose(1, 2)
    scores = torch.bmm(query, encoder_out_top)

    if self.mask is not None:
        scores.data.masked_fill_(self.mask, -float('inf'))

    # NOTE(review): softmax is called without an explicit dim on the
    # tensor with dims 0 and 2 swapped, then swapped back -- presumably
    # so that normalisation runs over the encoder positions under
    # torch's implicit-dim rule for 3-D input; confirm before touching.
    attn = F.softmax(scores.transpose(0, 2))
    attn = attn.transpose(0, 2).contiguous()

    context_output = torch.bmm(attn, encoder_out_combine.transpose(1, 2))
    # Restore the trailing singleton dim and the original dim order.
    context_output = context_output.unsqueeze(3).transpose(1, 2)
    return context_output, attn
def score(self, h_t, h_s):
    """
    Compute unnormalized attention scores between queries and sources.

    Args:
      h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
      h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`

    Returns:
      :obj:`FloatTensor`:
       raw attention scores (unnormalized) for each src index
      `[batch x tgt_len x src_len]`
    """
    # Check input sizes
    src_batch, src_len, src_dim = h_s.size()
    tgt_batch, tgt_len, tgt_dim = h_t.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)
    aeq(self.dim, src_dim)

    # "general" is "dot" preceded by a linear map of the queries.
    if self.attn_type == "general":
        flat_queries = h_t.view(tgt_batch * tgt_len, tgt_dim)
        h_t = self.linear_in(flat_queries).view(tgt_batch, tgt_len, tgt_dim)

    if self.attn_type in ("general", "dot"):
        # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
        return torch.bmm(h_t, h_s.transpose(1, 2))

    # Any other attention type: additive scoring, v(tanh(Wq + Uh)),
    # evaluated for every (query, source) pair.
    dim = self.dim
    wq = self.linear_query(h_t.view(-1, dim))
    wq = wq.view(tgt_batch, tgt_len, 1, dim) \
        .expand(tgt_batch, tgt_len, src_len, dim)

    uh = self.linear_context(h_s.contiguous().view(-1, dim))
    uh = uh.view(src_batch, 1, src_len, dim) \
        .expand(src_batch, tgt_len, src_len, dim)

    # (batch, t_len, s_len, d)
    wquh = self.tanh(wq + uh)

    pair_scores = self.v(wquh.view(-1, dim))
    return pair_scores.view(tgt_batch, tgt_len, src_len)
def _check_args(self, input, lengths=None, hidden=None): s_len, n_batch, n_feats = input.size() if lengths is not None: n_batch_, = lengths.size() aeq(n_batch, n_batch_)
def _run_forward_pass(self, input, context, state, context_lengths=None):
    """
    Forward pass that feeds the whole embedded input to the RNN in one
    call and then applies attention to all RNN outputs at once.

    Copy attention and coverage are not supported here (asserted).

    Args:
        input (LongTensor): a sequence of input tokens tensors
                            of size (len x batch x nfeats).
        context (FloatTensor): output(tensor sequence) from the encoder
                    RNN of size (src_len x batch x hidden_size).
        state: decoder state object; its `hidden` attribute initializes
                    the RNN (only `hidden[0]` when the RNN is a GRU).
        context_lengths (LongTensor): the source context lengths.
    Returns:
        hidden (Variable): final hidden state from the decoder.
        outputs (FloatTensor): dropout-applied decoder output for all
                    time steps, as a single tensor.
        attns (dict): `{"std": <attention scores returned by self.attn>}`.
        coverage: always None.
    """
    assert not self._copy  # TODO, no support yet.
    assert not self._coverage  # TODO, no support yet.

    emb = self.embeddings(input)

    # Run the forward pass of the RNN.
    if isinstance(self.rnn, nn.GRU):
        rnn_state = state.hidden[0]
    else:
        rnn_state = state.hidden
    rnn_output, hidden = self.rnn(emb, rnn_state)

    # Result check: the RNN must preserve length and batch size.
    input_len, input_batch, _ = input.size()
    output_len, output_batch, _ = rnn_output.size()
    aeq(input_len, output_len)
    aeq(input_batch, output_batch)

    # Attention takes both sequences with dims 0 and 1 swapped.
    attn_outputs, attn_scores = self.attn(
        rnn_output.transpose(0, 1).contiguous(),
        context.transpose(0, 1),
        context_lengths=context_lengths
    )
    attns = {"std": attn_scores}

    if self.context_gate is None:
        outputs = self.dropout(attn_outputs)    # (input_len, batch, d)
    else:
        # The gate works on 2-D inputs: fold the leading dims of each
        # tensor, then restore (len, batch, hidden) afterwards.
        gated = self.context_gate(
            emb.view(-1, emb.size(2)),
            rnn_output.view(-1, rnn_output.size(2)),
            attn_outputs.view(-1, attn_outputs.size(2))
        )
        gated = gated.view(input_len, input_batch, self.hidden_size)
        outputs = self.dropout(gated)

    coverage = None
    return hidden, outputs, attns, coverage
def forward(self, input, context, state, context_lengths=None):
    """
    Convolutional decoder forward pass.

    See :obj:`onmt.modules.RNNDecoderBase.forward()` for the argument and
    return conventions. `state` must be a `CNNDecoderState`; when it holds
    a previous input, that input is prepended to `input` before decoding
    and the corresponding leading positions are cut from the results.
    `context_lengths` is accepted but not used in this method.
    """
    # CHECKS
    assert isinstance(state, CNNDecoderState)
    _, tgt_batch, _ = input.size()
    _, src_batch, _ = context.size()
    aeq(tgt_batch, src_batch)
    # END CHECKS

    previous_input = state.previous_input
    if previous_input is not None:
        input = torch.cat([previous_input, input], 0)

    assert not self._copy, "Copy mechanism not yet tested in conv2conv"

    emb = self.embeddings(input)
    assert emb.dim() == 3  # len x batch x embedding_dim

    tgt_emb = emb.transpose(0, 1).contiguous()
    # The output of CNNEncoder.
    src_context_t = context.transpose(0, 1).contiguous()
    # The combination of output of CNNEncoder and source embeddings.
    src_context_c = state.init_src.transpose(0, 1).contiguous()

    # Project the embeddings, then reshape for the convolution stack.
    n_batch, n_steps = tgt_emb.size(0), tgt_emb.size(1)
    projected = self.linear(tgt_emb.view(n_batch * n_steps, -1))
    x = shape_transform(projected.view(n_batch, n_steps, -1))

    # Zero block of `cnn_kernel_width - 1` positions, concatenated in
    # front of the input along dim 2 before every convolution.
    pad = Variable(torch.zeros(x.size(0), x.size(1),
                               self.cnn_kernel_width - 1, 1))
    pad = pad.type_as(x)
    base_target_emb = x

    for conv, attention in zip(self.conv_layers, self.attn_layers):
        out = conv(torch.cat([pad, x], 2))
        c, attn = attention(base_target_emb, out,
                            src_context_t, src_context_c)
        x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT

    # Back to (len, batch, ...) layout.
    outputs = x.squeeze(3).transpose(1, 2).transpose(0, 1).contiguous()

    # Only the attention of the last layer is reported.
    if previous_input is not None:
        n_prev = previous_input.size(0)
        outputs = outputs[n_prev:]
        attn = attn[:, n_prev:].squeeze()
        attn = torch.stack([attn])

    attns = {"std": attn}
    if self._copy:
        attns["copy"] = attn

    # Update the state.
    state.update_state(input)

    return outputs, state, attns
def forward(self, key, value, query, mask=None):
    """
    Compute the context vector and the attention vectors.

    Args:
       key (`FloatTensor`): set of `key_len`
            key vectors `[batch, key_len, dim]`
       value (`FloatTensor`): set of `key_len`
            value vectors `[batch, key_len, dim]`
       query (`FloatTensor`): set of `query_len`
             query vectors  `[batch, query_len, dim]`
       mask: binary mask indicating which keys have
             non-zero attention `[batch, query_len, key_len]`;
             positions where the mask is set are filled with -1e18
             before the softmax.
    Returns:
       (`FloatTensor`, `FloatTensor`) :

       * output context vectors `[batch, query_len, dim]`
       * one of the attention vectors (head 0, before dropout)
         `[batch, query_len, key_len]`
    """
    # CHECKS
    batch, k_len, d = key.size()
    batch_, k_len_, d_ = value.size()
    aeq(batch, batch_)
    aeq(k_len, k_len_)
    aeq(d, d_)
    batch_, q_len, d_ = query.size()
    aeq(batch, batch_)
    aeq(d, d_)

    # NOTE(review): the divisor 8 is hard-coded here rather than taken
    # from `self.head_count` -- confirm whether that is intended.
    aeq(self.model_dim % 8, 0)
    if mask is not None:
        batch_, q_len_, k_len_ = mask.size()
        aeq(batch_, batch)
        aeq(k_len_, k_len)
        # Compare the two lengths as separate arguments; passing the
        # single value `q_len_ == q_len` gives `aeq` nothing to compare.
        aeq(q_len_, q_len)
    # END CHECKS

    def shape_projection(x):
        """[b, len, heads * dim_per_head] -> [b * heads, len, dim_per_head]."""
        b, length, _ = x.size()
        return x.view(b, length, self.head_count, self.dim_per_head) \
            .transpose(1, 2).contiguous() \
            .view(b * self.head_count, length, self.dim_per_head)

    def unshape_projection(x, q):
        """Inverse of `shape_projection`; batch and length come from `q`."""
        b, length, _ = q.size()
        return x.view(b, self.head_count, length, self.dim_per_head) \
            .transpose(1, 2).contiguous() \
            .view(b, length, self.head_count * self.dim_per_head)

    key_up = shape_projection(self.linear_keys(key))
    value_up = shape_projection(self.linear_values(value))
    query_up = shape_projection(self.linear_query(query))

    # Scaled dot-product scores: (batch * heads) x query_len x key_len.
    scaled = torch.bmm(query_up, key_up.transpose(1, 2))
    scaled = scaled / math.sqrt(self.dim_per_head)
    bh, query_len, key_len = scaled.size()
    b = bh // self.head_count
    if mask is not None:
        # Broadcast the mask over the head axis, fill, and fold back.
        scaled = scaled.view(b, self.head_count, query_len, key_len)
        mask = mask.unsqueeze(1).expand_as(scaled)
        scaled = scaled.masked_fill(Variable(mask), -1e18) \
                       .view(bh, query_len, key_len)
    attn = self.sm(scaled)

    # Return one attn: the first head only.
    top_attn = attn \
        .view(b, self.head_count, query_len, key_len)[:, 0, :, :] \
        .contiguous()

    # Reuse the softmax computed above instead of evaluating it a second
    # time on the same scores (assumes `self.dropout` is not in-place).
    drop_attn = self.dropout(attn)

    # values : (batch * heads) x qlen x dim_per_head
    out = unshape_projection(torch.bmm(drop_attn, value_up), query)

    # Only dropout is applied to the merged heads here; no residual add
    # or layer norm happens in this method.
    ret = self.res_dropout(out)

    # CHECK: output must have the shape of `query`.
    batch_, q_len_, d_ = ret.size()
    aeq(q_len, q_len_)
    aeq(batch, batch_)
    aeq(d, d_)
    # END CHECK
    return ret, top_attn
def forward(self, input, context, context_lengths=None, coverage=None):
    """
    Attend over `context` with the queries in `input`.

    Args:
      input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`,
          or `[batch x dim]` for a single decoding step
      context (`FloatTensor`): source vectors `[batch x src_len x dim]`;
          not modified by this method
      context_lengths (`LongTensor`): the source context lengths `[batch]`
      coverage (`FloatTensor`): optional `[batch x src_len]` tensor; when
          given, its projection through `self.linear_cover` is added to
          `context`, followed by `tanh`

    Returns:
      (`FloatTensor`, `FloatTensor`):

      * Computed vector `[tgt_len x batch x dim]`
        (`[batch x dim]` for single-step input)
      * Attention distributions for each query
         `[tgt_len x batch x src_len]`
        (`[batch x src_len]` for single-step input)
    """
    # A 2-D input is a single step: give it a length-1 target dimension
    # for the computation and squeeze it away again at the end.
    one_step = input.dim() == 2
    if one_step:
        input = input.unsqueeze(1)

    batch, sourceL, dim = context.size()
    batch_, targetL, dim_ = input.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)

    if coverage is not None:
        batch_, sourceL_ = coverage.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

        cover = coverage.view(-1).unsqueeze(1)
        # Out-of-place add: `context +=` would write into the caller's
        # tensor (possibly a view of the encoder output that is reused
        # on later calls).
        context = context + self.linear_cover(cover).view_as(context)
        context = self.tanh(context)

    # compute attention scores, as in Luong et al.
    align = self.score(input, context)

    if context_lengths is not None:
        mask = sequence_mask(context_lengths)
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        # Positions where the mask is 0 get -inf before the softmax.
        align.data.masked_fill_(1 - mask, -float('inf'))

    # Softmax to normalize attention weights
    align_vectors = self.sm(align.view(batch * targetL, sourceL))
    align_vectors = align_vectors.view(batch, targetL, sourceL)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, context)

    # concatenate context vector and query, then project back to `dim`
    concat_c = torch.cat([c, input], 2).view(batch * targetL, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = self.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)

        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, sourceL_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
    else:
        # Target-length-major layout for the multi-step case.
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()

        # Check output sizes
        targetL_, batch_, dim_ = attn_h.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        targetL_, batch_, sourceL_ = align_vectors.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

    return attn_h, align_vectors