def _set_input_buffer(self, incremental_state, buffer):
    utils.set_incremental_state(
        self,
        incremental_state,
        'attn_state',
        buffer,
    )
def _set_input_buffer(self, incremental_state, buffer, incremental_clone_id=""):
    self.incremental_clone_ids.add(incremental_clone_id)
    utils.set_incremental_state(
        self, incremental_state, "attn_state" + incremental_clone_id, buffer
    )
def forward(self, x, incremental_state=None, encoder_lstm_states=None, **unused):
    residual = x
    x = self.maybe_layer_norm(x, before=True)
    seqlen, bsz, _ = x.size()

    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is not None:
        prev_hiddens, prev_cells, input_feed = cached_state
    else:
        num_layers = len(self.layers)
        if encoder_lstm_states is not None:
            encoder_hiddens, encoder_cells = encoder_lstm_states
            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
            prev_cells = [encoder_cells[i] for i in range(num_layers)]
        else:
            state_size = bsz, self.hidden_dim
            prev_hiddens = [x.new_zeros(*state_size) for _ in range(num_layers)]
            prev_cells = [x.new_zeros(*state_size) for _ in range(num_layers)]
        input_feed = x.new_zeros(bsz, self.hidden_dim)

    outs = []
    for j in range(seqlen):
        input = torch.cat((x[j, :, :], input_feed), dim=1)
        for i, rnn in enumerate(self.layers):
            hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
            # input = F.dropout(hidden, p=self.dropout, training=self.training)
            input = hidden
            prev_hiddens[i] = hidden
            prev_cells[i] = cell
        out = hidden
        input_feed = out
        outs.append(out)

    utils.set_incremental_state(
        self, incremental_state, 'cached_state',
        (prev_hiddens, prev_cells, input_feed))

    x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_dim)
    x = self.linear(x)
    x = F.dropout(x, p=self.dropout, training=self.training)
    x = x + residual
    x = self.maybe_layer_norm(x, after=True)
    return x, None
def reorder_incremental_state(self, incremental_state, new_order):
    super().reorder_incremental_state(incremental_state, new_order)
    encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
    if encoder_out is not None:
        encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out)
        utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)
def forward_unprojected(self, input_tokens, encoder_out, incremental_state=None):
    padded_tokens = F.pad(
        input_tokens,
        (self.history_len - 1, 0, 0, 0),
        "constant",
        self.dst_dict.eos(),
    )
    # We use incremental_state only to check whether we are decoding or not
    # self.training is false even for the forward pass through validation
    if incremental_state is not None:
        padded_tokens = padded_tokens[:, -self.history_len:]
        utils.set_incremental_state(self, incremental_state, "incremental_marker", True)

    bsz, seqlen = padded_tokens.size()
    seqlen -= self.history_len - 1

    # get outputs from encoder
    (encoder_outs, final_hidden, _, src_lengths, _) = encoder_out

    # padded_tokens has shape [batch_size, seq_len+history_len]
    x = self.embed_tokens(padded_tokens)
    x = F.dropout(x, p=self.dropout_in, training=self.training)

    # Convolution needs shape [batch_size, channels, seq_len]
    x = self.history_conv(x.transpose(1, 2)).transpose(1, 2)
    x = F.dropout(x, p=self.dropout_out, training=self.training)

    # x has shape [batch_size, seq_len, channels]
    for i, layer in enumerate(self.layers):
        prev_x = x
        x = layer(x)
        x = F.dropout(x, p=self.dropout_out, training=self.training)
        if self.residual_level is not None and i >= self.residual_level:
            x = x + prev_x

    # Attention
    attn_out, attn_scores = self.attention(
        x.transpose(0, 1).contiguous().view(-1, self.hidden_dim),
        encoder_outs.repeat(1, seqlen, 1),
        src_lengths.repeat(seqlen),
    )
    if attn_out is not None:
        attn_out = attn_out.view(seqlen, bsz, -1).transpose(1, 0)
        attn_scores = attn_scores.view(-1, seqlen, bsz).transpose(0, 2)
    x = maybe_cat((x, attn_out), dim=2)

    # bottleneck layer
    if hasattr(self, "additional_fc"):
        x = self.additional_fc(x)
        x = F.dropout(x, p=self.dropout_out, training=self.training)
    return x, attn_scores
def reorder_incremental_state(self, incremental_state, new_order):
    super().reorder_incremental_state(incremental_state, new_order)
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is None:
        return

    def reorder_state(state):
        if isinstance(state, list):
            return [reorder_state(state_i) for state_i in state]
        return state.index_select(0, new_order) if state is not None else None

    new_state = tuple(map(reorder_state, cached_state))
    utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
    encoder_padding_mask = encoder_out['encoder_padding_mask'].t()
    encoder_out = encoder_out['encoder_out']

    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
    bsz, seqlen = prev_output_tokens.size()
    srclen = encoder_out.size(0)

    # initialize previous states (or get from cache during incremental generation)
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is not None:
        prev_hiddens, prev_cells, _ = cached_state
    else:
        prev_hiddens = [encoder_out.mean(dim=0) for i in range(self.num_layers)]
        prev_cells = [encoder_out.mean(dim=0) for i in range(self.num_layers)]

    # get outputs from encoder
    # encoder_out = self.layer_norm(encoder_out)
    x = torch.stack(
        [encoder_out[:, i, :].index_select(0, prev_output_tokens[i])
         for i in range(encoder_out.size(1))],
        dim=1,
    )
    x = self.layer_norm(x)
    x = F.dropout(x, p=self.dropout, training=self.training)
    # encoder_out = F.dropout(encoder_out, p=self.dropout, training=self.training)

    attn_scores = x.new_zeros(bsz, seqlen, srclen)
    for j in range(seqlen):
        input = x[j, :, :]
        for i, rnn in enumerate(self.layers):
            # recurrent cell
            hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
            # hidden state becomes the input to the next layer
            input = F.dropout(hidden, p=self.dropout, training=self.training)
            # save state for next time step
            prev_hiddens[i] = hidden
            prev_cells[i] = cell
        # apply attention using the last layer's hidden state
        attn_scores[:, j, :] = self.attention(hidden, encoder_out, encoder_padding_mask)

    # cache previous states (no-op except during incremental generation)
    utils.set_incremental_state(
        self,
        incremental_state,
        'cached_state',
        (prev_hiddens, prev_cells, None),
    )
    return attn_scores, None
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
    if incremental_state is not None:
        # If the *incremental_state* argument is not ``None`` then we are
        # in incremental inference mode. While *prev_output_tokens* will
        # still contain the entire decoded prefix, we will only use the
        # last step and assume that the rest of the state is cached.
        prev_output_tokens = prev_output_tokens[:, -1:]

    # This remains the same as before.
    bsz, tgt_len = prev_output_tokens.size()
    final_encoder_hidden = encoder_out['final_hidden']
    final_encoder_hidden = final_encoder_hidden[0:bsz, :]
    x = self.embed_tokens(prev_output_tokens)
    x = self.dropout(x)
    x = torch.cat(
        [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
        dim=2,
    )

    # We will now check the cache and load the cached previous hidden and
    # cell states, if they exist, otherwise we will initialize them to
    # zeros (as before). We will use the ``utils.get_incremental_state()``
    # and ``utils.set_incremental_state()`` helpers.
    initial_state = utils.get_incremental_state(
        self, incremental_state, 'prev_state',
    )
    if initial_state is None:
        # first time initialization, same as the original version
        initial_state = (
            final_encoder_hidden.unsqueeze(0),  # hidden
            torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
        )

    # Run one step of our LSTM.
    output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

    # Update the cache with the latest hidden and cell states.
    utils.set_incremental_state(
        self, incremental_state, 'prev_state', latest_state,
    )

    # This remains the same as before
    x = output.transpose(0, 1)
    x = self.output_projection(x)
    return x, None
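# A minimal sketch (not from the source) of how a decoder like the one above is
# typically driven at inference time with the dict-based incremental API: one
# ``incremental_state`` dict is created up front and passed on every step, so the
# decoder only has to embed and process the last token of the growing prefix.
# The names ``model``, ``encoder_out`` and ``max_len`` are assumptions made for
# illustration only.
import torch

incremental_state = {}  # shared cache, reused across time steps
max_len = 10
prev_tokens = torch.full((1, 1), model.decoder.dictionary.eos(), dtype=torch.long)
for _ in range(max_len):
    # forward only sees the full prefix; internally it keeps just the last step
    logits, _ = model.decoder(prev_tokens, encoder_out, incremental_state=incremental_state)
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    prev_tokens = torch.cat([prev_tokens, next_token], dim=1)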
def reorder_incremental_state(self, incremental_state, new_order):
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is None:
        return

    def reorder_state(state):
        if isinstance(state, list):
            return [reorder_state(state_i) for state_i in state]
        return state.index_select(0, new_order)

    if not isinstance(new_order, Variable):
        new_order = Variable(new_order)
    new_state = tuple(map(reorder_state, cached_state))
    utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
def set_pointer(self, incremental_state, p_choose):
    curr_pointer = self.get_pointer(incremental_state)
    if len(curr_pointer) == 0:
        buffer = torch.zeros_like(p_choose)
    else:
        buffer = self.get_pointer(incremental_state)["step"]

    buffer += (p_choose < 0.5).type_as(buffer)

    utils.set_incremental_state(
        self,
        incremental_state,
        'monotonic',
        {"step": buffer},
    )
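# A plausible matching getter for the pointer buffer written above (an assumption,
# not shown in the source): it reads the same 'monotonic' slot back and falls back
# to an empty dict the first time it is called, which is what makes the
# ``len(curr_pointer) == 0`` check in ``set_pointer`` work.
def get_pointer(self, incremental_state):
    return utils.get_incremental_state(
        self,
        incremental_state,
        'monotonic',
    ) or {}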
def reorder_incremental_state(self, incremental_state, new_order):
    """Reorder buffered internal state (for incremental generation)."""
    cached_state = utils.get_incremental_state(self, incremental_state, "cached_state")
    if cached_state is None:
        return

    def reorder_state(state):
        if isinstance(state, list):
            return [reorder_state(state_i) for state_i in state]
        return state.index_select(0, new_order)

    new_state = tuple(map(reorder_state, cached_state))
    utils.set_incremental_state(self, incremental_state, "cached_state", new_state)
def reorder_incremental_state(self, incremental_state, new_order):
    super().reorder_incremental_state(incremental_state, new_order)

    cumsum_probs = utils.get_incremental_state(self, incremental_state, 'cumsum_probs')
    if cumsum_probs is not None:
        new_cumsum_probs = cumsum_probs.index_select(0, new_order)
        utils.set_incremental_state(self, incremental_state, 'cumsum_probs', new_cumsum_probs)

    nodes = utils.get_incremental_state(self, incremental_state, 'nodes')
    if nodes is not None:
        new_nodes = nodes.index_select(0, new_order)
        utils.set_incremental_state(self, incremental_state, 'nodes', new_nodes)
def reorder_incremental_state(self, incremental_state, new_order):
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is not None:
        prev_hiddens, prev_cells, input_feed = cached_state
        prev_hiddens = [hidden.index_select(0, new_order) for hidden in prev_hiddens]
        prev_cells = [cell.index_select(0, new_order) for cell in prev_cells]
        input_feed = input_feed.index_select(0, new_order)
        utils.set_incremental_state(
            self, incremental_state, 'cached_state',
            (prev_hiddens, prev_cells, input_feed))
def reorder_incremental_state(self, incremental_state, new_order):
    # Load the cached state.
    prev_state = utils.get_incremental_state(
        self, incremental_state, 'prev_state',
    )

    # Reorder batches according to *new_order*.
    reordered_state = (
        prev_state[0].index_select(1, new_order),  # hidden
        prev_state[1].index_select(1, new_order),  # cell
    )

    # Update the cached state.
    utils.set_incremental_state(
        self, incremental_state, 'prev_state', reordered_state,
    )
def _set_input_buffer(
    self,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    new_buffer,
):
    return utils.set_incremental_state(self, incremental_state, "input_buffer", new_buffer)
def reorder_incremental_state(self, incremental_state, new_order):
    super().reorder_incremental_state(incremental_state, new_order)

    cumsum_probs = utils.get_incremental_state(self, incremental_state, 'cumsum_probs')
    if cumsum_probs is not None:
        new_cumsum_probs = cumsum_probs.index_select(0, new_order)
        utils.set_incremental_state(self, incremental_state, 'cumsum_probs', new_cumsum_probs)

    nodes = utils.get_incremental_state(self, incremental_state, 'nodes')
    if nodes is not None:
        new_order_list = new_order.tolist()
        new_nodes = [nodes[i] for i in new_order_list]
        utils.set_incremental_state(self, incremental_state, 'nodes', new_nodes)
def reorder_incremental_state(self, incremental_state, new_order):
    # parent reorders attention model
    super().reorder_incremental_state(incremental_state, new_order)

    cached_state = utils.get_incremental_state(self, incremental_state, "cached_state")
    if cached_state is None:
        return

    # Last 2 elements of prev_states are encoder projections
    # used for ONNX export
    for i, state in enumerate(cached_state[:-2]):
        cached_state[i] = state.index_select(1, new_order)

    utils.set_incremental_state(self, incremental_state, "cached_state", cached_state)
def _split_encoder_out(self, encoder_out, incremental_state):
    """Split and transpose encoder outputs.

    This is cached when doing incremental inference.
    """
    cached_result = utils.get_incremental_state(self, incremental_state, 'encoder_out')
    if cached_result is not None:
        return cached_result

    # transpose only once to speed up attention layers
    encoder_a, encoder_b = encoder_out
    encoder_a = encoder_a.transpose(1, 2).contiguous()
    result = (encoder_a, encoder_b)

    if incremental_state is not None:
        utils.set_incremental_state(self, incremental_state, 'encoder_out', result)
    return result
def reorder_incremental_state(self, incremental_state, new_order):
    super().reorder_incremental_state(incremental_state, new_order)

    for state_name in ['wordlm_logprobs', 'out_logprobs', 'subword_cumlogprobs']:
        state = utils.get_incremental_state(self, incremental_state, state_name)
        if state is not None:
            new_state = state.index_select(0, new_order)
            utils.set_incremental_state(
                self, incremental_state, state_name, new_state,
            )

    nodes = utils.get_incremental_state(self, incremental_state, 'nodes')
    if nodes is not None:
        new_order_list = new_order.tolist()
        new_nodes = [nodes[i] for i in new_order_list]
        utils.set_incremental_state(
            self, incremental_state, 'nodes', new_nodes,
        )
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
    bbsz = prev_output_tokens.size(0)
    vocab = len(self.dictionary)
    src_len = encoder_out.encoder_out.size(1)
    tgt_len = prev_output_tokens.size(1)

    # determine number of steps
    if incremental_state is not None:
        # cache step number
        step = utils.get_incremental_state(self, incremental_state, "step")
        if step is None:
            step = 0
        utils.set_incremental_state(self, incremental_state, "step", step + 1)
        steps = [step]
    else:
        steps = list(range(tgt_len))

    # define output in terms of raw probs
    if hasattr(self.args, "probs"):
        assert (
            self.args.probs.dim() == 3
        ), "expected probs to have size bsz*steps*vocab"
        probs = self.args.probs.index_select(1, torch.LongTensor(steps))
    else:
        probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
        for i, step in enumerate(steps):
            # args.beam_probs gives the probability for every vocab element,
            # starting with eos, then unknown, and then the rest of the vocab
            if step < len(self.args.beam_probs):
                probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
            else:
                probs[:, i, self.dictionary.eos()] = 1.0

    # random attention
    attn = torch.rand(bbsz, tgt_len, src_len)

    dev = prev_output_tokens.device
    return probs.to(dev), {"attn": [attn.to(dev)]}
def reorder_incremental_state(self, incremental_state, new_order):
    super().reorder_incremental_state(incremental_state, new_order)
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is None:
        return

    # EDITED
    def reorder_state(state, idx):
        if isinstance(state, list) or isinstance(state, tuple):
            return [reorder_state(state_i, idx) for state_i in state]
        return state.index_select(idx, new_order)

    new_state = [
        reorder_state(sub, idx)
        for (sub, idx) in zip(cached_state, [0, 0, 0, 1])
    ]
    utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
def reorder_incremental_state(self, incremental_state, new_order):
    """
    The ``FairseqIncrementalDecoder`` interface also requires implementing a
    ``reorder_incremental_state()`` method, which is used during beam search
    to select and reorder the incremental state.
    """
    # Load the cached state.
    prev_state = utils.get_incremental_state(
        self, incremental_state, 'prev_state',
    )

    # Reorder batches according to *new_order*.
    reordered_state = (
        prev_state[0].index_select(1, new_order),  # hidden
        prev_state[1].index_select(1, new_order),  # cell
    )

    # Update the cached state.
    utils.set_incremental_state(
        self, incremental_state, 'prev_state', reordered_state,
    )
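# A minimal sketch (hypothetical, not from the source) of how beam search invokes
# the method above: after hypotheses are re-ranked, ``new_order`` holds, for each
# surviving hypothesis, the index of the old hypothesis whose cached state it
# should inherit, and every cached buffer is reordered in a single call. The
# names ``decoder`` and ``incremental_state`` are assumed for illustration.
import torch

new_order = torch.tensor([0, 0, 2, 1])  # e.g. new beams 0 and 1 both continue old beam 0
decoder.reorder_incremental_state(incremental_state, new_order)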
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
    bbsz = prev_output_tokens.size(0)
    vocab = len(self.dictionary)
    src_len = encoder_out.size(1)
    tgt_len = prev_output_tokens.size(1)

    # determine number of steps
    if incremental_state is not None:
        # cache step number
        step = utils.get_incremental_state(self, incremental_state, 'step')
        if step is None:
            step = 0
        utils.set_incremental_state(self, incremental_state, 'step', step + 1)
        steps = [step]
    else:
        steps = list(range(tgt_len))

    # define output in terms of raw probs
    if hasattr(self.args, 'probs'):
        assert self.args.probs.dim() == 3, \
            'expected probs to have size bsz*steps*vocab'
        probs = self.args.probs.index_select(1, torch.LongTensor(steps))
    else:
        probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
        for i, step in enumerate(steps):
            # args.beam_probs gives the probability for every vocab element,
            # starting with eos, then unknown, and then the rest of the vocab
            if step < len(self.args.beam_probs):
                probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
            else:
                probs[:, i, self.dictionary.eos()] = 1.0

    # random attention
    attn = torch.rand(bbsz, tgt_len, src_len)

    return probs, attn
def reorder_incremental_state(self, incremental_state, new_order):
    def apply_reorder_incremental_state(module):
        if module != self and hasattr(module, 'reorder_incremental_state'):
            module.reorder_incremental_state(
                incremental_state,
                new_order,
            )
    self.apply(apply_reorder_incremental_state)

    # document decoder
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state_rnn')
    if cached_state is None:
        return

    def reorder_state(state):
        if isinstance(state, list):
            return [reorder_state(state_i) for state_i in state]
        return state.index_select(0, new_order)

    new_state = tuple(map(reorder_state, cached_state))
    utils.set_incremental_state(self, incremental_state, 'cached_state_rnn', new_state)
def masked_copy_incremental_state(self, incremental_state, another_state, mask):
    state = utils.get_incremental_state(self, incremental_state, 'encoder_out')
    if state is None:
        assert another_state is None
        return

    def mask_copy_state(state, another_state):
        if isinstance(state, list):
            assert isinstance(another_state, list) and len(state) == len(another_state)
            return [
                mask_copy_state(state_i, another_state_i)
                for state_i, another_state_i in zip(state, another_state)
            ]
        if state is not None:
            assert state.size(0) == mask.size(0) and another_state is not None and \
                state.size() == another_state.size()
            # expand the mask with trailing singleton dims so it broadcasts over state
            mask_unsqueezed = mask
            for _ in range(1, len(state.size())):
                mask_unsqueezed = mask_unsqueezed.unsqueeze(-1)
            return torch.where(mask_unsqueezed, state, another_state)
        else:
            assert another_state is None
            return None

    new_state = tuple(map(mask_copy_state, state, another_state))
    utils.set_incremental_state(self, incremental_state, 'encoder_out', new_state)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
    bsz, seqlen = prev_output_tokens.size()

    # get outputs from encoder
    encoder_outs, _, _ = encoder_out[:3]
    srclen = encoder_outs.size(0)

    # embed tokens
    x = self.embed_tokens(prev_output_tokens)
    x = F.dropout(x, p=self.dropout_in, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # initialize previous states (or get from cache during incremental generation)
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is not None:
        prev_hiddens, prev_cells, input_feed = cached_state
    else:
        _, encoder_hiddens, encoder_cells = encoder_out[:3]
        num_layers = len(self.layers)
        prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
        prev_cells = [encoder_cells[i] for i in range(num_layers)]
        input_feed = Variable(x.data.new(bsz, self.encoder_output_units).zero_())

    attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
    outs = []
    for j in range(seqlen):
        # input feeding: concatenate context vector from previous time step
        input = torch.cat((x[j, :, :], input_feed), dim=1)

        for i, rnn in enumerate(self.layers):
            # recurrent cell
            hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

            # hidden state becomes the input to the next layer
            input = F.dropout(hidden, p=self.dropout_out, training=self.training)

            # save state for next time step
            prev_hiddens[i] = hidden
            prev_cells[i] = cell

        # apply attention using the last layer's hidden state
        if self.attention is not None:
            out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
        else:
            out = hidden
        out = F.dropout(out, p=self.dropout_out, training=self.training)

        # input feeding
        input_feed = out

        # save final output
        outs.append(out)

    # cache previous states (no-op except during incremental generation)
    utils.set_incremental_state(
        self, incremental_state, 'cached_state',
        (prev_hiddens, prev_cells, input_feed))

    # collect outputs across time steps
    x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)

    # T x B x C -> B x T x C
    x = x.transpose(1, 0)

    # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
    attn_scores = attn_scores.transpose(0, 2)

    # project back to size of vocabulary
    if hasattr(self, 'additional_fc'):
        x = self.additional_fc(x)
        x = F.dropout(x, p=self.dropout_out, training=self.training)
    x = self.fc_out(x)

    return x, attn_scores
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
    padded_tokens = F.pad(
        prev_output_tokens,
        (self.history_len - 1, 0, 0, 0),
        "constant",
        self.dst_dict.eos(),
    )
    # We use incremental_state only to check whether we are decoding or not
    # self.training is false even for the forward pass through validation
    if incremental_state is not None:
        padded_tokens = padded_tokens[:, -self.history_len - 1:]
        utils.set_incremental_state(self, incremental_state, "incremental_marker", True)

    bsz, seqlen = padded_tokens.size()
    seqlen -= self.history_len - 1

    # get outputs from encoder
    (encoder_outs, final_hidden, _, src_lengths, _) = encoder_out

    # padded_tokens has shape [batch_size, seq_len+history_len]
    x = self.embed_tokens(padded_tokens)
    x = F.dropout(x, p=self.dropout_in, training=self.training)

    # Convolution needs shape [batch_size, channels, seq_len]
    x = self.history_conv(x.transpose(1, 2)).transpose(1, 2)
    x = F.dropout(x, p=self.dropout_out, training=self.training)

    # x has shape [batch_size, seq_len, channels]
    for i, layer in enumerate(self.layers):
        prev_x = x
        x = layer(x)
        x = F.dropout(x, p=self.dropout_out, training=self.training)
        if self.residual_level is not None and i >= self.residual_level:
            x = x + prev_x

    # Attention
    attn_out, attn_scores = self.attention(
        x.transpose(0, 1).contiguous().view(-1, self.hidden_dim),
        encoder_outs.repeat(1, seqlen, 1),
        src_lengths.repeat(seqlen),
    )
    attn_out = attn_out.view(seqlen, bsz, -1).transpose(1, 0)
    attn_scores = attn_scores.view(-1, seqlen, bsz).transpose(0, 2)
    x = torch.cat((x, attn_out), dim=2)

    # bottleneck layer
    if hasattr(self, "additional_fc"):
        x = self.additional_fc(x)
        x = F.dropout(x, p=self.dropout_out, training=self.training)

    output_projection_w = self.output_projection_w
    output_projection_b = self.output_projection_b

    # avoiding transpose of projection weights during ONNX tracing
    batch_time_hidden = torch.onnx.operators.shape_as_tensor(x)
    x_flat_shape = torch.cat(
        (torch.LongTensor([-1]), batch_time_hidden[2].view(1)))
    x_flat = torch.onnx.operators.reshape_from_tensor_shape(x, x_flat_shape)
    projection_flat = torch.matmul(output_projection_w, x_flat.t()).t()
    logits_shape = torch.cat(
        (batch_time_hidden[:2], torch.LongTensor([-1])))
    logits = (
        torch.onnx.operators.reshape_from_tensor_shape(projection_flat, logits_shape)
        + output_projection_b
    )
    return logits, attn_scores, None
def setCached(self, incremental_state, key, value):
    utils.set_incremental_state(self, incremental_state, key, value)
def forward(self, decoder_in_dict, encoder_out_dict, incremental_state=None,
            incr_doc_step=False, batch_idxs=None, new_incr_cached=None):
    prev_output_tokens = decoder_in_dict['prev_output_tokens']
    encoder_padding_mask = encoder_out_dict['encoder_padding_mask']  # [b x w]
    if encoder_padding_mask is not None:
        encoder_padding_mask = encoder_padding_mask.transpose(0, 1)
    srcbsz, srcdoclen, srcdim = encoder_out_dict['encoder_out'][0].size()  # [b x w x d]

    # summarise whole input for h0 decoder use, verbose but clearer
    src_summary_h0 = encoder_out_dict['encoder_out'][0].mean(1)  # [b x d]

    bsz, doclen, sentlen = prev_output_tokens.size()  # these sizes are target ones
    start_doc = 0
    if incremental_state is not None:
        doclen = 1

    # get initial input embedding for document RNN decoder
    x = prev_output_tokens.data.new(bsz).fill_(self.sod_idx)
    x = self.decoder.embed_tokens(x)

    ## Decode sentence states ##

    # initialize previous states (or get from cache during incremental generation)
    cached_state_rnn = utils.get_incremental_state(self, incremental_state, 'cached_state_rnn')
    if incr_doc_step and cached_state_rnn is not None:
        # doing the first step of the ith (i>1) sentence in incremental generation
        prev_hiddens, prev_cells, input = cached_state_rnn
        outs = [input]
    elif incremental_state is not None and new_incr_cached is not None:
        # doing subsequent steps of a sentence in incremental generation
        bidxs, old_bsz, reorder_state = batch_idxs
        if reorder_state is not None:
            # need to do this when some hypotheses have been finished when generating
            # reducing decoding to lower nb of hypotheses
            new_incr_cached = new_incr_cached.index_select(0, reorder_state)
            outs = [new_incr_cached]
        else:
            outs = [new_incr_cached]
    else:
        # first state of first sentence in incremental generation
        # or first coming here to generate the whole sentence in training/scoring;
        # previous is h0 with encoder output summary
        outs = []
        encoder_hiddens_cells = src_summary_h0  # [b x d]
        prev_hiddens = []
        prev_cells = []
        for i in range(len(self.layers)):
            prev_hiddens.append(encoder_hiddens_cells)
            prev_cells.append(encoder_hiddens_cells)
        input = x

    # attn of document decoder over input aggregated units (e.g. encoded sequence of paragraphs)
    attn_scores = x.data.new(srcdoclen, doclen, bsz).zero_()

    if (incremental_state is not None and incr_doc_step) or incremental_state is None:
        for j in range(start_doc, doclen):
            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
                # hidden state becomes the input to the next layer
                input = hidden
                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # apply attention using the last layer's hidden state (sentence vector)
            if self.wordAttention is not None:
                # inputs to attention are of the form
                #   input: bsz x input_embed_dim
                #   source_hids: srclen x bsz x output_embed_dim
                # either attend to the input representation by the cnn encoder
                # or to its combination with the input embeddings
                if self.hidemb:
                    attn_h, attn_scores_out = self.wordAttention(
                        hidden,
                        encoder_out_dict['encoder_out'][1].transpose(0, 1),
                        encoder_padding_mask)
                else:
                    attn_h, attn_scores_out = self.wordAttention(
                        hidden,
                        encoder_out_dict['encoder_out'][0].transpose(0, 1),
                        encoder_padding_mask)
                out = attn_h  # [b x d]
            else:
                out = hidden

            # input to next time step
            input = out
            new_incr_cached = out.clone()

            # save final output
            if incremental_state is not None:
                outs.append(out)
            else:
                outs.append(out.unsqueeze(1))

    ## Decode sentences ##
    # When training/validation, make all sentence s_t decoding steps in parallel here
    sent_states = None
    if incremental_state is not None:
        dec_outs = x.data.new(bsz, doclen, sentlen, len(self.decoder.dictionary)).zero_()
        # decode by sentence s_j
        for j in range(doclen):
            sp = self.embed_sent_positions(decoder_in_dict['sentence_position'])
            dec_out, word_atte_scores = self.decoder(
                prev_output_tokens[:, j, :], outs[j], sp, encoder_out_dict,
                incremental_state, firstfeed=self.firstfeed, normpos=self.normpos)
            # prev_output_tokens is [b x s x w], at each time step decode sentence j [b x w]
            # dec_out is [b x w x vocabulary]
            if j == 0:
                dec_outs = dec_out
            else:
                dec_outs = torch.cat((dec_outs, dec_out), 1)
            # dec_outs is [bxs x w x vocabulary], dim=0
            # dec_outs is [b x s*w x vocabulary], dim=1
    else:
        # decode everything in parallel
        sent_states = torch.cat(outs, dim=1).view(bsz * doclen, -1)
        ys = prev_output_tokens.view(bsz * doclen, -1)
        sp = make_sent_positions(prev_output_tokens, self.padding_idx).view(bsz * doclen)
        sp = self.embed_sent_positions(sp)

        # Replicate encoder_out_dict for the new nb of batches to do all in parallel
        ebsz, esrclen, edim = encoder_out_dict['encoder_out'][0].size()
        new_enc_out_dict = {}
        # repeat input for each target
        new_enc_out_dict['encoder_out'] = (
            encoder_out_dict['encoder_out'][0].view(ebsz, 1, esrclen, edim)
            .expand(ebsz, doclen, esrclen, edim)
            .contiguous().view(ebsz * doclen, esrclen, edim),
            encoder_out_dict['encoder_out'][1].view(ebsz, 1, esrclen, edim)
            .expand(ebsz, doclen, esrclen, edim)
            .contiguous().view(ebsz * doclen, esrclen, edim),
        )
        new_enc_out_dict['encoder_padding_mask'] = None
        if encoder_out_dict['encoder_padding_mask'] is not None:
            new_enc_out_dict['encoder_padding_mask'] = (
                encoder_out_dict['encoder_padding_mask']
                .view(ebsz, 1, esrclen).expand(ebsz, doclen, esrclen)
                .contiguous().view(ebsz * doclen, -1)
            )

        # decode all target sentences of all documents in parallel
        dec_out, word_atte_scores = self.decoder(
            ys, sent_states, sp, new_enc_out_dict,
            firstfeed=self.firstfeed, normpos=self.normpos)
        dec_outs = dec_out.view(bsz, doclen * sentlen, len(self.decoder.dictionary))

    if incremental_state is not None and incr_doc_step:
        # only if we moved to the next document sentence
        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self, incremental_state, 'cached_state_rnn',
            (prev_hiddens, prev_cells, out))

    # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
    attn_scores = attn_scores.transpose(0, 2)

    # topic label predictions are None here
    return dec_outs, (attn_scores, word_atte_scores), new_incr_cached, None
def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
    """
    Similar to *forward* but only return features.

    Returns:
        tuple:
            - the decoder's features of shape `(batch, tgt_len, embed_dim)`
            - attention weights of shape `(batch, tgt_len, src_len)`
    """
    if self.attention is not None:
        assert encoder_out is not None
        encoder_padding_mask = encoder_out['encoder_padding_mask']
        encoder_out = encoder_out['encoder_out']
        # get outputs from encoder
        encoder_outs = encoder_out[0]
        srclen = encoder_outs.size(0)

    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
    bsz, seqlen = prev_output_tokens.size()

    # embed tokens
    x = self.embed_tokens(prev_output_tokens)
    x = F.dropout(x, p=self.dropout_in, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # initialize previous states (or get from cache during incremental generation)
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is not None:
        prev_hiddens, prev_cells, input_feed = cached_state
    else:
        num_layers = len(self.layers)
        prev_hiddens = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)]
        prev_cells = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)]
        input_feed = x.new_zeros(bsz, self.encoder_output_units) \
            if self.attention is not None else None

    if self.attention is not None:
        attn_scores = x.new_zeros(srclen, seqlen, bsz)
    outs = []
    for j in range(seqlen):
        # input feeding: concatenate context vector from previous time step
        input = torch.cat((x[j, :, :], input_feed), dim=1) \
            if input_feed is not None else x[j, :, :]

        for i, rnn in enumerate(self.layers):
            # recurrent cell
            hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
            if self.residual and i > 0:  # residual connection starts from the 2nd layer
                prev_layer_hidden = input[:, :hidden.size(1)]

            # compute and apply attention using the 1st layer's hidden state
            if self.attention is not None:
                if i == 0:
                    context, attn_scores[:, j, :], _ = self.attention(
                        hidden, encoder_outs, encoder_padding_mask)

                # hidden state concatenated with context vector becomes the
                # input to the next layer
                input = torch.cat((hidden, context), dim=1)
            else:
                input = hidden
            input = F.dropout(input, p=self.dropout_out, training=self.training)
            if self.residual and i > 0:
                if self.attention is not None:
                    hidden_sum = input[:, :hidden.size(1)] + prev_layer_hidden
                    input = torch.cat((hidden_sum, input[:, hidden.size(1):]), dim=1)
                else:
                    input = input + prev_layer_hidden

            # save state for next time step
            prev_hiddens[i] = hidden
            prev_cells[i] = cell

        # input feeding
        input_feed = context if self.attention is not None else None

        # save final output
        outs.append(input)

    # cache previous states (no-op except during incremental generation)
    utils.set_incremental_state(
        self, incremental_state, 'cached_state',
        (prev_hiddens, prev_cells, input_feed),
    )

    # collect outputs across time steps
    x = torch.cat(outs, dim=0).view(seqlen, bsz, -1)
    assert x.size(2) == self.hidden_size + self.encoder_output_units

    # T x B x C -> B x T x C
    x = x.transpose(1, 0)

    # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
    if not self.training and self.attention is not None and self.need_attn:
        attn_scores = attn_scores.transpose(0, 2)
    else:
        attn_scores = None

    return x, attn_scores
def _set_input_buffer(self, incremental_state, new_buffer):
    return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
def _forward_given_embeddings(
    self,
    embed_out,
    prev_output_tokens,
    encoder_out,
    incremental_state=None,
    possible_translation_tokens=None,
    timestep=None,
):
    x = embed_out
    (encoder_x, src_tokens, encoder_padding_mask) = self._unpack_encoder_out(encoder_out)
    bsz, seqlen = prev_output_tokens.size()

    state_outputs = []
    if incremental_state is not None:
        prev_states = utils.get_incremental_state(self, incremental_state, "cached_state")
        if prev_states is None:
            prev_states = self._init_prev_states(encoder_out)

        # final 2 states of list are projected key and value
        saved_state = {"prev_key": prev_states[-2], "prev_value": prev_states[-1]}
        self.attention._set_input_buffer(incremental_state, saved_state)

    if incremental_state is not None:
        # first num_layers pairs of states are (prev_hidden, prev_cell)
        # for each layer
        h_prev = prev_states[0]
        c_prev = prev_states[1]
    else:
        h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
        c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

    x = self._concat_latent_code(x, encoder_out)
    x, (h_next, c_next) = self.initial_rnn_layer(x, (h_prev, c_prev))
    if incremental_state is not None:
        state_outputs.extend([h_next, c_next])

    x = F.dropout(x, p=self.dropout, training=self.training)

    if self.proj_encoder_layer is not None:
        encoder_x = self.proj_encoder_layer(encoder_x)

    attention_in = x
    if self.proj_layer is not None:
        attention_in = self.proj_layer(x)

    attention_out, attention_weights = self.attention(
        query=attention_in,
        key=encoder_x,
        value=encoder_x,
        key_padding_mask=encoder_padding_mask,
        incremental_state=incremental_state,
        static_kv=True,
        need_weights=(not self.training),
    )

    for i, layer in enumerate(self.extra_rnn_layers):
        residual = x
        rnn_input = torch.cat([x, attention_out], dim=2)
        rnn_input = self._concat_latent_code(rnn_input, encoder_out)

        if incremental_state is not None:
            # first num_layers pairs of states are (prev_hidden, prev_cell)
            # for each layer
            h_prev = prev_states[2 * i + 2]
            c_prev = prev_states[2 * i + 3]
        else:
            h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
            c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

        x, (h_next, c_next) = layer(rnn_input, (h_prev, c_prev))
        if incremental_state is not None:
            state_outputs.extend([h_next, c_next])
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = x + residual

    x = torch.cat([x, attention_out], dim=2)
    x = self._concat_latent_code(x, encoder_out)
    if self.bottleneck_layer is not None:
        x = self.bottleneck_layer(x)

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    if (
        self.vocab_reduction_module is not None
        and possible_translation_tokens is None
    ):
        decoder_input_tokens = prev_output_tokens.contiguous()
        possible_translation_tokens = self.vocab_reduction_module(
            src_tokens, decoder_input_tokens=decoder_input_tokens
        )

    output_weights = self.embed_out
    if possible_translation_tokens is not None:
        output_weights = output_weights.index_select(
            dim=0, index=possible_translation_tokens
        )

    logits = F.linear(x, output_weights)

    if incremental_state is not None:
        # encoder projections can be reused at each incremental step
        state_outputs.extend([prev_states[-2], prev_states[-1]])
        utils.set_incremental_state(self, incremental_state, "cached_state", state_outputs)

    return logits, attention_weights, possible_translation_tokens