def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
    bbsz = prev_output_tokens.size(0)
    vocab = len(self.dictionary)
    src_len = encoder_out.size(1)
    tgt_len = prev_output_tokens.size(1)

    # determine number of steps
    if incremental_state is not None:
        # cache step number
        step = utils.get_incremental_state(self, incremental_state, 'step')
        if step is None:
            step = 0
        utils.set_incremental_state(self, incremental_state, 'step', step + 1)
        steps = [step]
    else:
        steps = list(range(tgt_len))

    # define output in terms of raw probs
    if hasattr(self.args, 'probs'):
        assert self.args.probs.dim() == 3, \
            'expected probs to have size bsz*steps*vocab'
        probs = self.args.probs.index_select(1, torch.LongTensor(steps))
    else:
        probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
        for i, step in enumerate(steps):
            # args.beam_probs gives the probability for every vocab element,
            # starting with eos, then unknown, and then the rest of the vocab
            if step < len(self.args.beam_probs):
                probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
            else:
                probs[:, i, self.dictionary.eos()] = 1.0

    # random attention
    attn = torch.rand(bbsz, tgt_len, src_len)

    return probs, attn
def forward(
    self,
    input_tokens,
    encoder_out,
    incremental_state=None,
    possible_translation_tokens=None,
):
    if incremental_state is not None:
        input_tokens = input_tokens[:, -1:]
    bsz, seqlen = input_tokens.size()

    # get outputs from encoder
    (encoder_outs, final_hidden, final_cell, src_lengths, src_tokens) = encoder_out

    # embed tokens
    x = self.embed_tokens(input_tokens)
    x = F.dropout(x, p=self.dropout_in, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # initialize previous states (or get from cache during incremental generation)
    cached_state = utils.get_incremental_state(self, incremental_state, "cached_state")
    if cached_state is not None:
        prev_hiddens, prev_cells, input_feed = cached_state
    else:
        # first time step, initialize previous states
        prev_hiddens, prev_cells = self._init_prev_states(encoder_out)
        input_feed = self.initial_attn_context.expand(bsz, self.encoder_hidden_dim)

    attn_scores_per_step = []
    outs = []
    for j in range(seqlen):
        # input feeding: concatenate context vector from previous time step
        if self.attention is not None:
            step_input = torch.cat((x[j, :, :], input_feed), dim=1)
        else:
            step_input = x[j, :, :]
        previous_layer_input = step_input

        for i, rnn in enumerate(self.layers):
            # recurrent cell
            hidden, cell = rnn(step_input, (prev_hiddens[i], prev_cells[i]))

            # hidden state becomes the input to the next layer
            layer_output = F.dropout(hidden, p=self.dropout_out, training=self.training)

            if self.residual_level is not None and i >= self.residual_level:
                # TODO add an assert related to sizes here
                step_input = layer_output + previous_layer_input
            else:
                step_input = layer_output
            previous_layer_input = step_input

            # save state for next time step
            prev_hiddens[i] = hidden
            prev_cells[i] = cell

        if self.attention is not None:
            out, step_attn_scores = self.attention(hidden, encoder_outs, src_lengths)
            input_feed = out
        else:
            combined_output_and_context = hidden
            step_attn_scores = Variable(
                torch.ones(src_lengths.shape[0], src_lengths.max()).type_as(encoder_outs),
                requires_grad=False,
            ).t()
        attn_scores_per_step.append(step_attn_scores.unsqueeze(1))
        attn_scores = torch.cat(attn_scores_per_step, dim=1)
        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        attn_scores = attn_scores.transpose(0, 2)
        combined_output_and_context = torch.cat((hidden, out), dim=1)

        # save final output
        outs.append(combined_output_and_context)

    # cache previous states (no-op except during incremental generation)
    utils.set_incremental_state(
        self,
        incremental_state,
        "cached_state",
        (prev_hiddens, prev_cells, input_feed),
    )

    # collect outputs across time steps
    x = torch.cat(outs, dim=0).view(seqlen, bsz, self.combined_output_and_context_dim)

    # T x B x C -> B x T x C
    x = x.transpose(1, 0)

    # bottleneck layer
    if hasattr(self, "additional_fc"):
        x = self.additional_fc(x)
        x = F.dropout(x, p=self.dropout_out, training=self.training)

    output_projection_w = self.output_projection_w
    output_projection_b = self.output_projection_b
    decoder_input_tokens = input_tokens if self.training else None

    if self.vocab_reduction_module and possible_translation_tokens is None:
        possible_translation_tokens = self.vocab_reduction_module(
            src_tokens, decoder_input_tokens=decoder_input_tokens
        )
    if possible_translation_tokens is not None:
        output_projection_w = output_projection_w.index_select(
            dim=0, index=possible_translation_tokens
        )
        output_projection_b = output_projection_b.index_select(
            dim=0, index=possible_translation_tokens
        )

    # avoiding transpose of projection weights during ONNX tracing
    batch_time_hidden = torch.onnx.operators.shape_as_tensor(x)
    x_flat_shape = torch.cat((torch.LongTensor([-1]), batch_time_hidden[2].view(1)))
    x_flat = torch.onnx.operators.reshape_from_tensor_shape(x, x_flat_shape)
    projection_flat = torch.matmul(output_projection_w, x_flat.t()).t()
    logits_shape = torch.cat((batch_time_hidden[:2], torch.LongTensor([-1])))
    logits = (
        torch.onnx.operators.reshape_from_tensor_shape(projection_flat, logits_shape)
        + output_projection_b
    )

    return logits, attn_scores, possible_translation_tokens
def forward(self, input_token, target_token, timestep, *inputs):
    """
    Decoder step inputs correspond one-to-one to encoder outputs.
    """
    log_probs_per_model = []
    state_outputs = []

    next_state_input = len(self.models)

    # underlying assumption is each model has same vocab_reduction_module
    vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
    if vocab_reduction_module is not None:
        possible_translation_tokens = inputs[len(self.models)]
        next_state_input += 1
    else:
        possible_translation_tokens = None

    for i, model in enumerate(self.models):
        encoder_output = inputs[i]
        prev_hiddens = []
        prev_cells = []

        for _ in range(len(model.decoder.layers)):
            prev_hiddens.append(inputs[next_state_input])
            prev_cells.append(inputs[next_state_input + 1])
            next_state_input += 2
        prev_input_feed = inputs[next_state_input].view(1, -1)
        next_state_input += 1

        # no batching, we only care about "max" length
        src_length_int = int(encoder_output.size()[0])
        src_length = torch.LongTensor(np.array([src_length_int]))

        # notional, not actually used for decoder computation
        src_tokens = torch.LongTensor(np.array([[0] * src_length_int]))
        src_embeddings = encoder_output.new_zeros(encoder_output.shape)

        encoder_out = (
            encoder_output,
            prev_hiddens,
            prev_cells,
            src_length,
            src_tokens,
            src_embeddings,
        )

        # store cached states, use evaluation mode
        model.decoder._is_incremental_eval = True
        model.eval()

        # placeholder
        incremental_state = {}

        # cache previous state inputs
        utils.set_incremental_state(
            model.decoder,
            incremental_state,
            "cached_state",
            (prev_hiddens, prev_cells, prev_input_feed),
        )

        decoder_output = model.decoder(
            input_token.view(1, 1),
            encoder_out,
            incremental_state=incremental_state,
            possible_translation_tokens=possible_translation_tokens,
        )
        logits, _, _ = decoder_output

        log_probs = F.log_softmax(logits, dim=2)
        log_probs_per_model.append(log_probs)

        (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
            model.decoder, incremental_state, "cached_state"
        )
        for h, c in zip(next_hiddens, next_cells):
            state_outputs.extend([h, c])
        state_outputs.append(next_input_feed)

    average_log_probs = torch.mean(
        torch.cat(log_probs_per_model, dim=0), dim=0, keepdim=True
    )

    if possible_translation_tokens is not None:
        reduced_indices = torch.zeros(self.vocab_size).long().fill_(self.unk_token)
        # ONNX-exportable arange (ATen op)
        possible_translation_token_range = torch._dim_arange(
            like=possible_translation_tokens, dim=0
        )
        reduced_indices[possible_translation_tokens] = possible_translation_token_range
        reduced_index = reduced_indices.index_select(dim=0, index=target_token)
        score = average_log_probs.view((-1,)).index_select(dim=0, index=reduced_index)
    else:
        score = average_log_probs.view((-1,)).index_select(dim=0, index=target_token)

    word_reward = self.word_rewards.index_select(0, target_token)
    score += word_reward

    self.input_names = ["prev_token", "target_token", "timestep"]
    for i in range(len(self.models)):
        self.input_names.append(f"fixed_input_{i}")
    if possible_translation_tokens is not None:
        self.input_names.append("possible_translation_tokens")

    outputs = [score]
    self.output_names = ["score"]

    for i in range(len(self.models)):
        self.output_names.append(f"fixed_input_{i}")
        outputs.append(inputs[i])
    if possible_translation_tokens is not None:
        self.output_names.append("possible_translation_tokens")
        outputs.append(possible_translation_tokens)

    for i, state in enumerate(state_outputs):
        outputs.append(state)
        self.output_names.append(f"state_output_{i}")
        self.input_names.append(f"state_input_{i}")

    return tuple(outputs)
def extract_features( self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused, ): """ Similar to *forward* but only return features. Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - attention weights of shape `(batch, tgt_len, src_len)` """ if self.attention is not None: assert encoder_out is not None encoder_padding_mask = encoder_out['encoder_padding_mask'] encoder_out = encoder_out['encoder_out'] # get outputs from encoder encoder_outs = encoder_out[0] srclen = encoder_outs.size(0) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] bsz, seqlen = prev_output_tokens.size() # embed tokens x = self.embed_tokens(prev_output_tokens) x = F.dropout(x, p=self.dropout_in, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # initialize previous states (or get from cache during incremental generation) cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') if cached_state is not None: prev_hiddens, prev_cells, input_feed = cached_state else: num_layers = len(self.layers) prev_hiddens = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)] prev_cells = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)] input_feed = x.new_zeros(bsz, self.encoder_output_units) \ if self.attention is not None else None if self.attention is not None: attn_scores = x.new_zeros(srclen, seqlen, bsz) outs = [] for j in range(seqlen): # input feeding: concatenate context vector from previous time step input = torch.cat((x[j, :, :], input_feed), dim=1) \ if input_feed is not None else x[j, :, :] for i, rnn in enumerate(self.layers): # recurrent cell hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) if self.residual and i > 0: # residual connection starts from the 2nd layer prev_layer_hidden = input[:, :hidden.size(1)] # compute and apply attention using the 1st layer's hidden state if self.attention is not None: if i == 0: context, attn_scores[:, j, :], _ = self.attention( hidden, encoder_outs, encoder_padding_mask, ) # hidden state concatenated with context vector becomes the # input to the next layer input = torch.cat((hidden, context), dim=1) else: input = hidden input = F.dropout(input, p=self.dropout_out, training=self.training) if self.residual and i > 0: if self.attention is not None: hidden_sum = input[:, :hidden.size(1)] + prev_layer_hidden input = torch.cat((hidden_sum, input[:, hidden.size(1):]), dim=1) else: input = input + prev_layer_hidden # save state for next time step prev_hiddens[i] = hidden prev_cells[i] = cell # input feeding input_feed = context if self.attention is not None else None # save final output outs.append(input) # cache previous states (no-op except during incremental generation) utils.set_incremental_state( self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed), ) # collect outputs across time steps x = torch.cat(outs, dim=0).view(seqlen, bsz, -1) assert x.size(2) == self.hidden_size + self.encoder_output_units # T x B x C -> B x T x C x = x.transpose(1, 0) if hasattr(self, 'additional_fc') and self.adaptive_softmax is None: x = self.additional_fc(x) x = F.dropout(x, p=self.dropout_out, training=self.training) # srclen x tgtlen x bsz -> bsz x tgtlen x srclen if not self.training and self.attention is not None and self.need_attn: attn_scores = attn_scores.transpose(0, 2) else: attn_scores = None return x, attn_scores
def forward( self, prev_output_tokens: torch.Tensor, # Z_Tokens[Batch, SeqLength] encoder_out=None, incremental_state: Dict[str, Any] = None): assert incremental_state is not None, 'This model is for incremental decoding only' prev_output_tokens = prev_output_tokens[:, -1:] # Z_Tokens[Batch, Len=1] bsz = prev_output_tokens.size(0) if prev_output_tokens.device != self.tree.word_idx.device: self.tree.to_cuda(device=prev_output_tokens.device) # Move the batched state to the next state according to the automaton batch_space_mask = prev_output_tokens.squeeze(-1).eq( self.subword_space_idx) # B[Batch] cached_state = utils.get_incremental_state(self.lm_decoder, incremental_state, 'cached_state') if cached_state is None: # First step assert (prev_output_tokens == self.subword_eos_idx).all(), \ 'expecting the input to the first time step to be <eos>' w: torch.Tensor = prev_output_tokens.new_full( [bsz, 1], self.word_eos_idx) # Z[Batch, Len=1] lm_probs: torch.Tensor = self.lm_decoder.get_normalized_probs( self.lm_decoder(w, incremental_state=incremental_state), log_probs=False, sample=None) # R[Batch, 1, Vocab] cumsum_probs: torch.Tensor = lm_probs.cumsum( dim=-1) # R[Batch, 1, Vocab] nodes: torch.Tensor = prev_output_tokens.new_full( [bsz], self.tree.root_id) # Z_NodeId[Batch] all_children = self.tree.children[ nodes, :] # Z[Batch, PossibleChildren] else: # Not the first step cumsum_probs: torch.Tensor = utils.get_incremental_state( self, incremental_state, 'cumsum_probs') # R[Batch, 1, Vocab] nodes: torch.Tensor = utils.get_incremental_state( self, incremental_state, 'nodes') # Z_NodeId[Batch] assert nodes.size(0) == bsz w: torch.Tensor = self.tree.word_idx[nodes].unsqueeze( 1) # Z[Batch, Len=1] w[w < 0] = self.word_unk_idx old_cached_state = _clone_cached_state(cached_state) # recompute cumsum_probs from inter-word transition probabilities # only for those whose prev_output_token is <space> lm_probs: torch.Tensor = self.lm_decoder.get_normalized_probs( self.lm_decoder(w, incremental_state=incremental_state), log_probs=False, sample=None) # R[Batch, 1, Vocab] self.lm_decoder.masked_copy_incremental_state( incremental_state, old_cached_state, batch_space_mask) # restore those not masked cumsum_probs[batch_space_mask] = lm_probs.cumsum( dim=-1)[batch_space_mask] prev_all_children = self.tree.children[ nodes, :] # Z[Batch, PossibleChildren] prev_possible_tokens = self.tree.prev_subword_idx[ prev_all_children] # Z[Batch, PossibleChildren] # intra-word transition: go to child; oov transition: go to "None" node mask = prev_possible_tokens.eq( prev_output_tokens.expand_as(prev_possible_tokens)) nodes: torch.Tensor = (prev_all_children * mask.long()).sum( dim=1) # Z[Batch] # inter-word transition: go back to root nodes[batch_space_mask] = self.tree.root_id # Z[Batch] all_children = self.tree.children[ nodes, :] # Z[Batch, PossibleChildren] utils.set_incremental_state(self, incremental_state, 'cumsum_probs', cumsum_probs) utils.set_incremental_state(self, incremental_state, 'nodes', nodes) # Compute probabilities # initialize out_probs [Batch, 1, Vocab] if self.open_vocab: # set out_probs to oov_penalty * P(<unk>|h) (case 3 in Eqn. 
15) out_probs = self.oov_penalty * ( cumsum_probs[:, :, self.word_unk_idx] - cumsum_probs[:, :, self.word_unk_idx - 1] ).unsqueeze(-1).repeat(1, 1, self.subword_vocab_size) # set the probability of emitting <space> to 0 if prev_output_tokens # is <space> or <eos>, and that of emitting <eos> to 0 if # prev_output_tokens is not <space> batch_space_eos_mask = batch_space_mask | \ prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx) out_probs[batch_space_eos_mask, :, self.subword_space_idx] = self.zero out_probs[~batch_space_mask, :, self.subword_eos_idx] = self.zero # set transition probability to 1 for those whose node is out of the # tree, i.e. node is None (case 4 in Eqn. 15) batch_node_none_mask = nodes.eq(self.tree.none_id) # B[Batch] out_probs[batch_node_none_mask] = 1. else: # set out_probs to 0 out_probs = cumsum_probs.new_full( [bsz, 1, self.subword_vocab_size], self.zero) # compute parent probabilities for those whose node is not None left_ranges = self.tree.word_set_idx[nodes, 0] # Z[Batch] right_ranges = self.tree.word_set_idx[nodes, 1] # Z[Batch] batch_node_not_root_mask = nodes.ne(self.tree.none_id) & nodes.ne( self.tree.root_id) # B[Batch] sum_probs = torch.where( batch_node_not_root_mask, (cumsum_probs.squeeze(1).gather(-1, right_ranges.unsqueeze(-1)) - cumsum_probs.squeeze(1).gather( -1, left_ranges.unsqueeze(-1))).squeeze(-1), cumsum_probs.new([1.0])) # R[Batch] # compute transition probabilities to child nodes (case 2 in Eqn. 15) left_ranges_of_all_children = self.tree.word_set_idx[ all_children, 0] # Z[Batch, PossibleChildren] right_ranges_of_all_children = self.tree.word_set_idx[ all_children, 1] # Z[Batch, PossibleChildren] cumsum_probs_of_all_children = ( cumsum_probs.squeeze(1).gather(-1, right_ranges_of_all_children) - cumsum_probs.squeeze(1).gather(-1, left_ranges_of_all_children) ).unsqueeze(1) / sum_probs.unsqueeze(-1).unsqueeze( -1) # R[Batch, 1, PossibleChildren] cumsum_probs_of_all_children[sum_probs < self.zero, :, :] = self.zero next_possible_tokens = self.tree.prev_subword_idx[ all_children] # Z[Batch, PossibleChildren] out_probs.scatter_(-1, next_possible_tokens.unsqueeze(1), cumsum_probs_of_all_children) # assume self.subword_pad_idx is the padding index in self.tree.prev_subword_idx out_probs[:, :, self.subword_pad_idx] = self.zero # apply word-level probabilities for <space> (case 1 in Eqn. 15) word_idx = self.tree.word_idx[nodes] # Z[Batch] batch_node_word_end_mask = word_idx.ge(0) # B[Batch] # get rid of -1's (word idx of root or non-terminal states). 
It doesn't # matter what the "dummy" index it would be replaced with (as it will # always be masked out by batch_node_word_end_mask), as long as it is > 0 word_idx[word_idx < 0] = 1 word_probs = torch.where( sum_probs < self.zero, cumsum_probs.new([self.zero]), (cumsum_probs.squeeze(1).gather(-1, word_idx.unsqueeze(-1)) - cumsum_probs.squeeze(1).gather( -1, word_idx.unsqueeze(-1) - 1)).squeeze(-1) / sum_probs) # R[Batch] out_probs[batch_node_word_end_mask, 0, self.subword_space_idx] = \ word_probs[batch_node_word_end_mask] # take log of probs and clip it from below to avoid log(0) out_logprobs = torch.log( torch.max(out_probs, out_probs.new([self.zero]))) # assign log-probs of emitting word <eos> to that of emitting subword <eos> out_logprobs[batch_space_mask, :, self.subword_eos_idx] = \ torch.log(lm_probs)[batch_space_mask, :, self.word_eos_idx] utils.set_incremental_state(self, incremental_state, 'out_logprobs', out_logprobs) # note that here we return log-probs rather than logits, and the second # element is None, which is usually a tensor of attention weights in # attention-based models return out_logprobs, None
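# Illustrative sketch (toy numbers, not from the original code): storing word-level
# probabilities as a cumulative sum lets the decoder above read off the total mass of
# any contiguous block of word ids with two gathers; the exact left/right convention
# here is an assumption and depends on how the prefix tree stores its word ranges.
import torch

word_probs = torch.tensor([0.1, 0.2, 0.3, 0.25, 0.15])
cumsum_probs = word_probs.cumsum(dim=-1)
left, right = 0, 3                                     # block covering word ids 1..3
block_mass = cumsum_probs[right] - cumsum_probs[left]  # 0.2 + 0.3 + 0.25 = 0.75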
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
    bsz, seqlen = prev_output_tokens.size()

    # get outputs from encoder
    encoder_outs, _, _ = encoder_out
    srclen = encoder_outs.size(0)

    # embed tokens
    x = self.embed_tokens(prev_output_tokens)
    x = F.dropout(x, p=self.dropout_in, training=self.training)
    embed_dim = x.size(2)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # initialize previous states (or get from cache during incremental generation)
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is not None:
        prev_hiddens, prev_cells, input_feed = cached_state
    else:
        _, encoder_hiddens, encoder_cells = encoder_out
        num_layers = len(self.layers)
        prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
        prev_cells = [encoder_cells[i] for i in range(num_layers)]
        input_feed = Variable(x.data.new(bsz, embed_dim).zero_())

    attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
    outs = []
    for j in range(seqlen):
        # input feeding: concatenate context vector from previous time step
        input = torch.cat((x[j, :, :], input_feed), dim=1)

        for i, rnn in enumerate(self.layers):
            # recurrent cell
            hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

            # hidden state becomes the input to the next layer
            input = F.dropout(hidden, p=self.dropout_out, training=self.training)

            # save state for next time step
            prev_hiddens[i] = hidden
            prev_cells[i] = cell

        # apply attention using the last layer's hidden state
        out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
        out = F.dropout(out, p=self.dropout_out, training=self.training)

        # input feeding
        input_feed = out

        # save final output
        outs.append(out)

    # cache previous states (no-op except during incremental generation)
    utils.set_incremental_state(
        self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed))

    # collect outputs across time steps
    x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)

    # T x B x C -> B x T x C
    x = x.transpose(1, 0)

    # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
    attn_scores = attn_scores.transpose(0, 2)

    # project back to size of vocabulary
    if hasattr(self, 'additional_fc'):
        x = self.additional_fc(x)
        x = F.dropout(x, p=self.dropout_out, training=self.training)
    x = self.fc_out(x)

    return x, attn_scores
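# Illustrative sketch (hypothetical sizes, plain PyTorch, not from the original code):
# the heart of the loop above is "input feeding" -- the attention context produced at
# step j-1 is concatenated to the token embedding at step j before entering the LSTM.
import torch
import torch.nn as nn

embed_dim = hidden_dim = 8
bsz = 2
cell = nn.LSTMCell(embed_dim + hidden_dim, hidden_dim)

emb_j = torch.randn(bsz, embed_dim)         # embedding of the current target token
input_feed = torch.zeros(bsz, hidden_dim)   # attention context from the previous step
h = torch.zeros(bsz, hidden_dim)
c = torch.zeros(bsz, hidden_dim)

step_input = torch.cat((emb_j, input_feed), dim=1)
h, c = cell(step_input, (h, c))
# attention over the encoder states would produce the next input_feed here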
def forward(self, prev_output_tokens, shapes=None, tgt_tok_bounds=None, sort_order=None, encoder_out=None, src_lengths=None, incremental_state=None, **kwargs): if incremental_state is not None: for i in range(len(self.decoders)): es = utils.get_incremental_state(self, incremental_state, 'decoder-' + str(i + 1)) if es is None: utils.set_incremental_state(self, incremental_state, 'decoder-' + str(i + 1), {}) char_flag = ( not self.args.token_sequences) and self.args.char_sequences if len(self.decoders) > 1: g_shapes = shapes if not char_flag else None toks_prev_output_tokens = split_on_sep(prev_output_tokens[0], self.sequence_separator, shapes=g_shapes) g_shapes = shapes if char_flag else None char_prev_output_tokens = split_on_sep(prev_output_tokens[1], self.sequence_separator, shapes=g_shapes) assert len(toks_prev_output_tokens) == len(self.decoders) assert len(char_prev_output_tokens) == len(self.decoders) else: toks_prev_output_tokens = [prev_output_tokens[0]] char_prev_output_tokens = [prev_output_tokens[1]] outputs = [] decoder_input = [encoder_out] if self.training: self.decoder_hidden_states = [[] for i in range(len(self.decoders))] for i in range(self.first_decoder, self.last_decoder): incremental_state_i = utils.get_incremental_state( self, incremental_state, 'decoder-' + str(i + 1)) if not self.training and self.first_decoder > 0 and i == self.first_decoder: assert len(decoder_input) == 1 for d_idx in range(0, i): if self.args.model_type == 'lstm': new_decoder_input = { 'encoder_out': (torch.cat(self.decoder_hidden_states[d_idx], 0), None, None), 'encoder_padding_mask': None } decoder_input.append(new_decoder_input) else: new_decoder_input = EncoderOut( encoder_out=torch.cat( self.decoder_hidden_states[d_idx], 0), encoder_padding_mask=None, encoder_embedding=None, encoder_states=None) decoder_input.append(new_decoder_input) feats_only = False prev_output_tokens_i = [ toks_prev_output_tokens[i], char_prev_output_tokens[i] ] src_lengths_i = [src_lengths[0][0], src_lengths[1][0] ] if src_lengths is not None else [[], []] decoder_out = self.decoders[i]( prev_output_tokens_i, tgt_tok_bounds[i], sort_order, encoder_out=decoder_input, features_only=feats_only, incremental_state=incremental_state_i, src_lengths=src_lengths_i if src_lengths is not None else None, return_all_hiddens=False, ) outputs.append(decoder_out) hidden_state = decoder_out[1]['hidden'].transpose(0, 1) if not self.training: self.decoder_hidden_states[i].append(hidden_state) if self.args.model_type == 'transformer': new_decoder_input = EncoderOut( encoder_out=hidden_state, encoder_padding_mask=None, encoder_embedding=None, encoder_states=None, ) decoder_input.append(new_decoder_input) else: new_decoder_input = { 'encoder_out': (hidden_state, None, None), 'encoder_padding_mask': None } decoder_input.append(new_decoder_input) if not self.training and len(outputs) != len(self.decoders): assert len(outputs) == self.last_decoder - self.first_decoder output_clone = outputs[-1] for o_idx in range(0, len(self.decoders) - len(outputs)): outputs = [output_clone] + outputs x = [] attn_scores = [] for o in outputs: x.append(o[0]) attn_scores.append(o[1]) return x, attn_scores
def forward( self, prev_output_tokens, encoder_out, incremental_state=None, possible_translation_tokens=None, timestep=None, ): (encoder_x, src_tokens, encoder_padding_mask) = self._unpack_encoder_out(encoder_out) bsz, seqlen = prev_output_tokens.size() if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] x = self.embed_tokens(prev_output_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) state_outputs = [] if incremental_state is not None: prev_states = utils.get_incremental_state(self, incremental_state, "cached_state") if prev_states is None: prev_states = self._init_prev_states(encoder_out) # final 2 states of list are projected key and value saved_state = { "prev_key": prev_states[-2], "prev_value": prev_states[-1] } self.attention._set_input_buffer(incremental_state, saved_state) if incremental_state is not None: # first num_layers pairs of states are (prev_hidden, prev_cell) # for each layer h_prev = prev_states[0] c_prev = prev_states[1] else: h_prev = self._init_hidden(encoder_out, bsz).type_as(x) c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x) x = self._concat_latent_code(x, encoder_out) x, (h_next, c_next) = self.initial_rnn_layer(x, (h_prev, c_prev)) if incremental_state is not None: state_outputs.extend([h_next, c_next]) x = F.dropout(x, p=self.dropout, training=self.training) attention_in = x if self.proj_layer is not None: attention_in = self.proj_layer(x) attention_out, attention_weights = self.attention( query=attention_in, key=encoder_x, value=encoder_x, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training), ) for i, layer in enumerate(self.extra_rnn_layers): residual = x rnn_input = torch.cat([x, attention_out], dim=2) rnn_input = self._concat_latent_code(rnn_input, encoder_out) if incremental_state is not None: # first num_layers pairs of states are (prev_hidden, prev_cell) # for each layer h_prev = prev_states[2 * i + 2] c_prev = prev_states[2 * i + 3] else: h_prev = self._init_hidden(encoder_out, bsz).type_as(x) c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x) x, (h_next, c_next) = layer(rnn_input, (h_prev, c_prev)) if incremental_state is not None: state_outputs.extend([h_next, c_next]) x = F.dropout(x, p=self.dropout, training=self.training) x = x + residual x = torch.cat([x, attention_out], dim=2) x = self._concat_latent_code(x, encoder_out) x = self.bottleneck_layer(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if (self.vocab_reduction_module is not None and possible_translation_tokens is None): decoder_input_tokens = prev_output_tokens.contiguous() possible_translation_tokens = self.vocab_reduction_module( src_tokens, decoder_input_tokens=decoder_input_tokens) output_weights = self.embed_out if possible_translation_tokens is not None: output_weights = output_weights.index_select( dim=0, index=possible_translation_tokens) logits = F.linear(x, output_weights) if incremental_state is not None: # encoder projections can be reused at each incremental step state_outputs.extend([prev_states[-2], prev_states[-1]]) utils.set_incremental_state(self, incremental_state, "cached_state", state_outputs) return logits, attention_weights, possible_translation_tokens
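# Illustrative sketch (hypothetical shapes, not from the original code): vocabulary
# reduction restricts the output projection to a candidate set by slicing rows of the
# projection matrix, so logits (and the softmax) are only computed over those candidates.
import torch
import torch.nn.functional as F

full_vocab, hidden = 1000, 16
embed_out = torch.randn(full_vocab, hidden)               # full output projection
possible_translation_tokens = torch.tensor([3, 17, 42])   # reduced candidate set

x = torch.randn(2, 5, hidden)                             # B x T x C decoder states
reduced_weights = embed_out.index_select(0, possible_translation_tokens)
logits = F.linear(x, reduced_weights)                     # B x T x num_candidates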
def getCached(self, incremental_state, key):
    return utils.get_incremental_state(self, incremental_state, key)
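# Illustrative sketch (not from the original code): the getters and setters above rely
# on fairseq's utils.get_incremental_state/set_incremental_state, which scope each
# cached value to a (module instance, key) pair inside one shared dict. A minimal
# stand-in, assuming only that incremental_state is a plain Dict[str, Any]:
def _scoped_key(module, key):
    # fairseq scopes keys with a per-instance UUID; id(module) is a simplified stand-in
    return f"{module.__class__.__name__}.{id(module)}.{key}"

def get_incremental_state_sketch(module, incremental_state, key):
    if incremental_state is None:
        return None
    return incremental_state.get(_scoped_key(module, key))

def set_incremental_state_sketch(module, incremental_state, key, value):
    if incremental_state is not None:
        incremental_state[_scoped_key(module, key)] = value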
def forward(self, decoder_in_dict, encoder_out_dict, incremental_state=None, incr_doc_step=False, batch_idxs=None, new_incr_cached=None): prev_output_tokens = decoder_in_dict['prev_output_tokens'] encoder_padding_mask = encoder_out_dict[ 'encoder_padding_mask'] # [b x w] if encoder_padding_mask is not None: encoder_padding_mask = encoder_padding_mask.transpose(0, 1) srcbsz, srcdoclen, srcdim = encoder_out_dict['encoder_out'][0].size( ) # [b x w x d] # summarise whole input for h0 decoder use, verbose but clearer src_summary_h0 = encoder_out_dict['encoder_out'][0].mean(1) # [b x d] bsz, doclen, sentlen = prev_output_tokens.size( ) # these sizes are target ones start_doc = 0 if incremental_state is not None: doclen = 1 # get initial input embedding for document RNN decoder x = prev_output_tokens.data.new(bsz).fill_(self.sod_idx) x = self.decoder.embed_tokens(x) ## Decode sentence states ## # initialize previous states (or get from cache during incremental generation) cached_state_rnn = utils.get_incremental_state(self, incremental_state, 'cached_state_rnn') if incr_doc_step and cached_state_rnn is not None: # doing the fist step of the ith (i>1) sentence in incremental generation prev_hiddens, prev_cells, input = cached_state_rnn outs = [input] elif incremental_state is not None \ and new_incr_cached is not None: # doing subsequents steps of a sentence in incremental generation bidxs, old_bsz, reorder_state = batch_idxs if reorder_state is not None: # need to do this when some hypotheses have been finished when generating # reducing decoding to lower nb of hypotheses new_incr_cached = new_incr_cached.index_select( 0, reorder_state) outs = [new_incr_cached] else: outs = [new_incr_cached] else: # first state of first sentence in incremental generation or # or first coming here to generate the whole sentence in training/scoring # previous is h0 with encoder output summary outs = [] encoder_hiddens_cells = src_summary_h0 # [b x d] prev_hiddens = [] prev_cells = [] for i in range(len(self.layers)): prev_hiddens.append(encoder_hiddens_cells) prev_cells.append(encoder_hiddens_cells) input = x # attn of document decoder over input aggregated units (e.g. 
encoded sequence of paragraphs) attn_scores = x.data.new(srcdoclen, doclen, bsz).zero_() if (incremental_state is not None and incr_doc_step) \ or incremental_state is None: for j in range(start_doc, doclen): for i, rnn in enumerate(self.layers): # recurrent cell hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) # hidden state becomes the input to the next layer input = hidden # save state for next time step prev_hiddens[i] = hidden prev_cells[i] = cell # apply attention using the last layer's hidden state (sentence vector) if self.wordAttention is not None: # inputs to attention are of the form # input: bsz x input_embed_dim # source_hids: srclen x bsz x output_embed_dim # either attend to the input representation by the cnn encoder or to its combination with the input embeddings if self.hidemb: attn_h, attn_scores_out = self.wordAttention(hidden, \ encoder_out_dict['encoder_out'][1].transpose(0, 1),\ encoder_padding_mask) else: attn_h, attn_scores_out = self.wordAttention(hidden, \ encoder_out_dict['encoder_out'][0].transpose(0, 1), \ encoder_padding_mask) out = attn_h # [b x d] else: out = hidden # input to next time step input = out new_incr_cached = out.clone() # save final output if incremental_state is not None: outs.append(out) else: outs.append(out.unsqueeze(1)) ## Decode sentences ## # When training/validation, make all sentence s_t decoding steps in parallel here sent_states = None if incremental_state is not None: dec_outs = x.data.new(bsz, doclen, sentlen, len(self.decoder.dictionary)).zero_() # decode by sentence s_j for j in range(doclen): sp = self.embed_sent_positions( decoder_in_dict['sentence_position']) dec_out, word_atte_scores = self.decoder( prev_output_tokens[:, j, :], outs[j], sp, encoder_out_dict, incremental_state, firstfeed=self.firstfeed, normpos=self.normpos) # prev_output_tokens is [ b x s x w ], at each time step decode sentence j [b x w] # dec_out is [b x w x vocabulary] if j == 0: dec_outs = dec_out else: dec_outs = torch.cat((dec_outs, dec_out), 1) # dec_outs is [bxs x w x vocabulary], dim=0 # dec_outs is [b x s*w x vocabulary], dim=1 else: # decode everything in parallel sent_states = torch.cat(outs, dim=1).view(bsz * doclen, -1) ys = prev_output_tokens.view(bsz * doclen, -1) sp = make_sent_positions(prev_output_tokens, self.padding_idx).view(bsz * doclen) sp = self.embed_sent_positions(sp) # Replicate encoder_out_dict for the new nb of batches to do all in parallel ebsz, esrclen, edim = encoder_out_dict['encoder_out'][0].size() new_enc_out_dict = {} #repeat input for each target new_enc_out_dict['encoder_out'] = ( encoder_out_dict['encoder_out'][0].view( ebsz, 1, esrclen, edim).expand(ebsz, doclen, esrclen, edim).contiguous().view( ebsz * doclen, esrclen, edim), encoder_out_dict['encoder_out'][1].view( ebsz, 1, esrclen, edim).expand(ebsz, doclen, esrclen, edim).contiguous().view( ebsz * doclen, esrclen, edim)) new_enc_out_dict['encoder_padding_mask'] = None if encoder_out_dict['encoder_padding_mask'] is not None: new_enc_out_dict['encoder_padding_mask'] = encoder_out_dict['encoder_padding_mask']\ .view(ebsz, 1, esrclen).expand(ebsz, doclen, esrclen)\ .contiguous().view(ebsz*doclen, -1) #decode all target sentences of all documents in parallel dec_out, word_atte_scores = self.decoder(ys, sent_states, sp, new_enc_out_dict, firstfeed=self.firstfeed, normpos=self.normpos) dec_outs = dec_out.view(bsz, doclen * sentlen, len(self.decoder.dictionary)) if incremental_state is not None and incr_doc_step: # only if we moved to the next document sentence 
# cache previous states (no-op except during incremental generation) utils.set_incremental_state(self, incremental_state, 'cached_state_rnn', (prev_hiddens, prev_cells, out)) # srclen x tgtlen x bsz -> bsz x tgtlen x srclen attn_scores = attn_scores.transpose(0, 2) tkeys = None if sent_states is not None: tkeys = self.state2key(sent_states) else: tkeys = self.state2key(outs[j]) return dec_outs, (attn_scores, word_atte_scores), new_incr_cached, tkeys
def extract_features(self, prev_output_tokens, encoder_out, incremental_state=None): """ Similar to *forward* but only return features. """ encoder_sentemb = encoder_out['sentemb'] encoder_padding_mask = encoder_out['encoder_padding_mask'] lang = encoder_out['decoder_lang'] encoder_out = encoder_out['encoder_out'] if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] bsz, seqlen = prev_output_tokens.size() # get outputs from encoder encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3] srclen = encoder_outs.size(0) # embed tokens x = self.embed_tokens(prev_output_tokens) x = F.dropout(x, p=self.dropout_in, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # embed language lang_tensor = torch.LongTensor([self.lang_dictionary[lang]] * bsz).to( device=prev_output_tokens.device) l = self.embed_langs(lang_tensor) # initialize previous states (or get from cache during incremental generation) cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') if cached_state is not None: prev_hiddens, prev_cells, input_feed = cached_state else: num_layers = len(self.layers) prev_hiddens = [encoder_sentemb for i in range(num_layers)] prev_cells = [encoder_sentemb for i in range(num_layers)] prev_hiddens = [self.encoder_hidden_proj(x) for x in prev_hiddens] prev_cells = [self.encoder_cell_proj(x) for x in prev_cells] input_feed = x.new_zeros(bsz, self.hidden_size) outs = [] for j in range(seqlen): # input feeding: concatenate context vector from previous time step input = torch.cat((x[j, :, :], encoder_sentemb, input_feed, l), dim=1) for i, rnn in enumerate(self.layers): # recurrent cell hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) # hidden state becomes the input to the next layer input = F.dropout(hidden, p=self.dropout_out, training=self.training) # save state for next time step prev_hiddens[i] = hidden prev_cells[i] = cell out = hidden out = F.dropout(out, p=self.dropout_out, training=self.training) # input feeding input_feed = out # save final output outs.append(out) # cache previous states (no-op except during incremental generation) utils.set_incremental_state( self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed), ) # collect outputs across time steps x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) # T x B x C -> B x T x C x = x.transpose(1, 0) if hasattr(self, 'additional_fc') and self.adaptive_softmax is None: x = self.additional_fc(x) x = F.dropout(x, p=self.dropout_out, training=self.training) return x, None
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs): encoder_padding_mask = encoder_out["encoder_padding_mask"] encoder_outs = encoder_out["encoder_out"] if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] bsz, seqlen = prev_output_tokens.size() srclen = encoder_outs.size(0) # embed tokens embeddings = self.embed_tokens(prev_output_tokens) x = embeddings if self.dropout is not None: x = self.dropout(x) # B x T x C -> T x B x C x = x.transpose(0, 1) # initialize previous states (or get from cache during incremental # generation) cached_state = utils.get_incremental_state(self, incremental_state, "cached_state") if cached_state is not None: prev_hiddens, prev_cells = cached_state else: prev_hiddens = [encoder_out["encoder_out"].mean(dim=0) ] * self.num_layers prev_cells = [x.new_zeros(bsz, self.hidden_size)] * self.num_layers attn_scores = x.new_zeros(bsz, srclen) attention_outs = [] outs = [] for j in range(seqlen): input = x[j, :, :] attention_out = None for i, layer in enumerate(self.layers): # the previous state is one layer below except for the bottom # layer where the previous state is the state emitted by the # top layer hidden, cell = layer( input, ( prev_hiddens[(i - 1) % self.num_layers], prev_cells[(i - 1) % self.num_layers], ), ) if self.dropout is not None: hidden = self.dropout(hidden) prev_hiddens[i] = hidden prev_cells[i] = cell if attention_out is None: attention_out, attn_scores = self.attention( hidden, encoder_outs, encoder_padding_mask) if self.dropout is not None: attention_out = self.dropout(attention_out) attention_outs.append(attention_out) input = attention_out # collect the output of the top layer outs.append(hidden) # cache previous states (no-op except during incremental generation) utils.set_incremental_state(self, incremental_state, "cached_state", (prev_hiddens, prev_cells)) # collect outputs across time steps x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) attention_outs_concat = torch.cat(attention_outs, dim=0).view(seqlen, bsz, self.context_dim) # T x B x C -> B x T x C x = x.transpose(0, 1) attention_outs_concat = attention_outs_concat.transpose(0, 1) # concat LSTM output, attention output and embedding # before output projection x = torch.cat((x, attention_outs_concat, embeddings), dim=2) x = self.deep_output_layer(x) x = torch.tanh(x) if self.dropout is not None: x = self.dropout(x) # project back to size of vocabulary x = self.output_projection(x) # to return the full attn_scores tensor, we need to fix the decoder # to account for subsampling input frames # return x, attn_scores return x, None
def forward(self, input_tokens, prev_scores, timestep, *inputs):
    """
    Decoder step inputs correspond one-to-one to encoder outputs.
    HOWEVER: after the first step, encoder outputs (i.e., the first
    len(self.models) elements of inputs) must be tiled k (beam size)
    times on the batch dimension (axis 1).
    """
    log_probs_per_model = []
    attn_weights_per_model = []
    state_outputs = []

    # from flat to (batch x 1)
    input_tokens = input_tokens.unsqueeze(1)

    next_state_input = len(self.models)

    # size of "batch" dimension of input as tensor
    batch_size = torch.onnx.operators.shape_as_tensor(input_tokens)[0]

    # underlying assumption is each model has same vocab_reduction_module
    vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
    if vocab_reduction_module is not None:
        possible_translation_tokens = inputs[len(self.models)]
        next_state_input += 1
    else:
        possible_translation_tokens = None

    for i, model in enumerate(self.models):
        encoder_output = inputs[i]
        prev_hiddens = []
        prev_cells = []

        for _ in range(len(model.decoder.layers)):
            prev_hiddens.append(inputs[next_state_input])
            prev_cells.append(inputs[next_state_input + 1])
            next_state_input += 2

        # ensure previous attention context has batch dimension
        input_feed_shape = torch.cat((batch_size.view(1), torch.LongTensor([-1])))
        prev_input_feed = torch.onnx.operators.reshape_from_tensor_shape(
            inputs[next_state_input],
            input_feed_shape,
        )
        next_state_input += 1

        # no batching, we only care about "max" length
        src_length_int = encoder_output.size()[0]
        src_length = torch.LongTensor(np.array([src_length_int]))

        # notional, not actually used for decoder computation
        src_tokens = torch.LongTensor(np.array([[0] * src_length_int]))

        encoder_out = (
            encoder_output,
            prev_hiddens,
            prev_cells,
            src_length,
            src_tokens,
        )

        # store cached states, use evaluation mode
        model.decoder._is_incremental_eval = True
        model.eval()

        # placeholder
        incremental_state = {}

        # cache previous state inputs
        utils.set_incremental_state(
            model.decoder,
            incremental_state,
            'cached_state',
            (prev_hiddens, prev_cells, prev_input_feed),
        )

        decoder_output = model.decoder(
            input_tokens,
            encoder_out,
            incremental_state=incremental_state,
            possible_translation_tokens=possible_translation_tokens,
        )
        logits, attn_scores, _ = decoder_output

        log_probs = F.log_softmax(logits, dim=2)

        log_probs_per_model.append(log_probs)
        attn_weights_per_model.append(attn_scores)

        (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
            model.decoder,
            incremental_state,
            'cached_state',
        )

        for h, c in zip(next_hiddens, next_cells):
            state_outputs.extend([h, c])
        state_outputs.append(next_input_feed)

    average_log_probs = torch.mean(
        torch.cat(log_probs_per_model, dim=1),
        dim=1,
        keepdim=True,
    )
    average_attn_weights = torch.mean(
        torch.cat(attn_weights_per_model, dim=1),
        dim=1,
        keepdim=True,
    )

    best_scores_k_by_k, best_tokens_k_by_k = torch.topk(
        average_log_probs.squeeze(1),
        k=self.beam_size,
    )

    prev_scores_k_by_k = prev_scores.view(-1, 1).expand(-1, self.beam_size)
    total_scores_k_by_k = best_scores_k_by_k + prev_scores_k_by_k

    # flatten to take top k over all (beam x beam) hypos
    total_scores_flat = total_scores_k_by_k.view(-1)
    best_tokens_flat = best_tokens_k_by_k.view(-1)

    best_scores, best_indices = torch.topk(
        total_scores_flat,
        k=self.beam_size,
    )

    best_tokens = best_tokens_flat.index_select(
        dim=0,
        index=best_indices,
    ).view(-1)

    # integer division to determine which input produced each successor
    prev_hypos = best_indices / self.beam_size

    attention_weights = average_attn_weights.index_select(
        dim=0,
        index=prev_hypos,
    )

    if possible_translation_tokens is not None:
        best_tokens = possible_translation_tokens.index_select(
            dim=0,
            index=best_tokens,
        )

    word_rewards_for_best_tokens = self.word_rewards.index_select(
        0,
        best_tokens,
    )
    best_scores += word_rewards_for_best_tokens

    self.input_names = ['prev_tokens', 'prev_scores', 'timestep']
    for i in range(len(self.models)):
        self.input_names.append('fixed_input_{}'.format(i))
    if possible_translation_tokens is not None:
        self.input_names.append('possible_translation_tokens')

    # 'attention_weights_average' output shape: (src_length x beam_size)
    attention_weights = attention_weights.squeeze(1)

    outputs = [
        best_tokens,
        best_scores,
        prev_hypos,
        attention_weights,
    ]
    self.output_names = [
        'best_tokens_indices',
        'best_scores',
        'prev_hypos_indices',
        'attention_weights_average',
    ]
    for i in range(len(self.models)):
        self.output_names.append('fixed_input_{}'.format(i))
        if self.tile_internal:
            outputs.append(inputs[i].repeat(1, self.beam_size, 1))
        else:
            outputs.append(inputs[i])

    if possible_translation_tokens is not None:
        self.output_names.append('possible_translation_tokens')
        outputs.append(possible_translation_tokens)

    for i, state in enumerate(state_outputs):
        next_state = state.index_select(
            dim=0,
            index=prev_hypos,
        )
        outputs.append(next_state)
        self.output_names.append('state_output_{}'.format(i))
        self.input_names.append('state_input_{}'.format(i))

    return tuple(outputs)
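# Illustrative sketch (hypothetical scores, not from the original code): the beam
# expansion above takes the top-k successors of each live hypothesis, adds the running
# scores, flattens the (beam x beam) grid, and takes a global top-k; dividing the flat
# indices by the beam size recovers which parent each survivor came from.
import torch

beam_size = 3
best_scores_k_by_k = torch.randn(beam_size, beam_size)  # top-k successor scores per hypothesis
prev_scores = torch.randn(beam_size, 1)                 # running scores of live hypotheses
total_scores_flat = (best_scores_k_by_k + prev_scores).view(-1)
best_scores, best_indices = torch.topk(total_scores_flat, k=beam_size)
prev_hypos = torch.div(best_indices, beam_size, rounding_mode="floor")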
def _get_hidden_state(self, incremental_state):
    return utils.get_incremental_state(self, incremental_state, 'hidden_state')
def reorder_incremental_state(self, incremental_state, new_order):
    super().reorder_incremental_state(incremental_state, new_order)
    encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
    if encoder_out is not None:
        encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out)
        utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)
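# Illustrative sketch (hypothetical tensors, not from the original code): beam search
# reorders surviving hypotheses at every step, so any cache indexed by hypothesis on
# dim 0 has to be permuted with the same new_order, which is what the override above
# does for the cached 'encoder_out' tuple.
import torch

cached = torch.arange(6).view(3, 2)            # per-hypothesis cache, batch dim 0
new_order = torch.tensor([2, 0, 1])            # permutation chosen by the beam
reordered = cached.index_select(0, new_order)  # rows now follow their hypotheses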
def forward( self, prev_output_tokens, encoder_out=None, incremental_state=None, possible_translation_tokens=None, timestep=None, ): (encoder_x, src_tokens, encoder_padding_mask) = encoder_out # embed positions positions = (self.embed_positions(prev_output_tokens, incremental_state=incremental_state) if self.embed_positions is not None else None) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] if self.onnx_trace: assert type(incremental_state) is list assert timestep is not None state_list = incremental_state incremental_state = {} state_index = 0 for layer in self.layers: utils.set_incremental_state( layer.avg_attn, incremental_state, "prev_vec", state_list[state_index], ) utils.set_incremental_state( layer.avg_attn, incremental_state, "prev_sum", state_list[state_index + 1], ) state_index += 2 utils.set_incremental_state(layer.avg_attn, incremental_state, "prev_pos", timestep.float()) if layer.encoder_attn is not None: utils.set_incremental_state( layer.encoder_attn, incremental_state, "prev_key", state_list[state_index], ) utils.set_incremental_state( layer.encoder_attn, incremental_state, "prev_value", state_list[state_index + 1], ) state_index += 2 # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) attn = None inner_states = [x] # decoder layers for layer in self.layers: x, attn = layer( x, encoder_x, encoder_padding_mask, incremental_state, self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None, ) inner_states.append(x) if self.normalize: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) # project back to size of vocabulary if self.share_input_output_embed: output_weights = self.embed_tokens.weight else: output_weights = self.embed_out if (self.vocab_reduction_module is not None and possible_translation_tokens is None): decoder_input_tokens = prev_output_tokens.contiguous() possible_translation_tokens = self.vocab_reduction_module( src_tokens, decoder_input_tokens=decoder_input_tokens) if possible_translation_tokens is not None: output_weights = output_weights.index_select( dim=0, index=possible_translation_tokens) if self.adaptive_softmax is None: logits = F.linear(x, output_weights) else: assert ( possible_translation_tokens is None ), "vocabulary reduction and adaptive softmax are incompatible!" logits = x if self.onnx_trace: state_outputs = [] for layer in self.layers: prev_vec = utils.get_incremental_state(layer.avg_attn, incremental_state, "prev_vec") prev_sum = utils.get_incremental_state(layer.avg_attn, incremental_state, "prev_sum") state_outputs.extend([prev_vec, prev_sum]) if layer.encoder_attn is not None: prev_key = utils.get_incremental_state( layer.encoder_attn, incremental_state, "prev_key") prev_value = utils.get_incremental_state( layer.encoder_attn, incremental_state, "prev_value") state_outputs.extend([prev_key, prev_value]) return logits, attn, possible_translation_tokens, state_outputs return logits, attn, possible_translation_tokens
def forward(self, prev_output_tokens, encoder_out, incremental_state=None): #def forward(self, prev_output_tokens, encoder_out, incremental_state=None): encoder_padding_mask = encoder_out['encoder_padding_mask'] encoder_out = encoder_out['encoder_out'] if incremental_state is not None: print(prev_output_tokens.size()) # prev_output_tokens = prev_output_tokens[:, -1:] prev_output_tokens = prev_output_tokens[:, -1:, :] # bsz, one_input_size = prev_output_tokens.size() # self.seq_len = 360 # seqlen = self.seq_len bsz, seqlen, segment_units = prev_output_tokens.size() # get outputs from encoder encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3] srclen = encoder_outs.size(0) # embed tokens # x = self.embed_tokens(prev_output_tokens) x = prev_output_tokens.view(-1, self.seq_len, self.input_size).float() # print(x.size()) # x = F.dropout(x, p=self.dropout_in, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # initialize previous states (or get from cache during incremental generation) cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') if cached_state is not None: prev_hiddens, prev_cells, input_feed = cached_state else: num_layers = len(self.layers) prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)] prev_cells = [encoder_cells[i] for i in range(num_layers)] if self.encoder_hidden_proj is not None: prev_hiddens = [ self.encoder_hidden_proj(x) for x in prev_hiddens ] prev_cells = [self.encoder_cell_proj(x) for x in prev_cells] input_feed = x.new_ones(bsz, self.hidden_size) * encoder_outs[ -1, :bsz, :self.hidden_size] #[0.5,0.1,1.0,0.0,0.0]#0.5 input_feed = nn.functional.relu(input_feed) attn_scores = x.new_zeros( srclen, seqlen, bsz ) #x.new_zeros(segment_units, seqlen, bsz) #x.new_zeros(srclen, seqlen, bsz) outs = [] boundry_param_list = [] segment_param_list = [] for j in range(seqlen): # from fairseq import pdb; pdb.set_trace() # input feeding: concatenate context vector from previous time step input_d = F.dropout(x[j, :, :], p=0.5, training=self.training) input_mask = input_d > 1e-6 #0.#-1e-6 input_in = (x[j, :, :] * input_mask.float()) + ( (1 - input_mask.float()) * input_feed) #input = torch.clamp(input, min=-1.0, max=1.0) #import pdb; pdb.set_trace() self.print_count += 1 if self.print_count % 1000 == 0: #random.random() > 0.9999: #from fairseq import pdb; pdb.set_trace() means = (input_in * (self.max_vals + 1e-6)).view( -1, 18, 5).mean(dim=1).cpu().detach().numpy() print("\n\ninput means\t", means) wandb.log({"input0": wandb.Histogram(means[:, 0])}) wandb.log({"input1": wandb.Histogram(means[:, 1])}) wandb.log({"input2": wandb.Histogram(means[:, 2])}) wandb.log({"input3": wandb.Histogram(means[:, 3])}) #wandb.log({"input4": wandb.Histogram(means[4])}) mean_x = x[j, :, :].view(-1, 18, 5).mean(dim=1) print("x[j, :, :] means\t", mean_x.cpu().detach().numpy()) mean_feed = input_feed.view(-1, 18, 5).mean(dim=1) print("input_feed means\t", mean_feed.cpu().detach().numpy()) # if random.random()>0.0: # input = x[j, :, :]#torch.cat((x[j, :, :], input_feed), dim=1) # else: # input = input_feed input = input_in for i, rnn in enumerate(self.layers): # recurrent cell hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) # hidden state becomes the input to the next layer #input = F.dropout(hidden, p=self.dropout_out, training=self.training) # save state for next time step prev_hiddens[i] = hidden prev_cells[i] = cell # apply attention using the last layer's hidden state if self.attention is not None: out, attn_scores[:, 
j, :] = self.attention( hidden, encoder_outs, encoder_padding_mask) else: out = hidden # from fairseq import pdb; pdb.set_trace() ntf_input = self.ntf_projection(out) boundry_params, segment_params = torch.split( ntf_input, [3, 8 * self.num_segments], dim=1) segment_params = segment_params.view((-1, 8, self.num_segments)) boundry_param_list.append(boundry_params) segment_param_list.append(segment_params) # boundry_params = torch.Tensor([200.0,10000.0,200.0]).to(self.device)*torch.sigmoid(boundry_params) boundry_params = torch.Tensor([150.0, 10000.0, 100.0]).to( self.device) * torch.sigmoid(boundry_params) segment_params = torch.cat([ torch.sigmoid(segment_params[:, :4, :]), torch.tanh(segment_params[:, 4:, :]) ], dim=1) # vf, a, rhocr, g, omegar, omegas, epsq, epsv segment_params = segment_params * torch.Tensor([[150.0], [ 2.0 ], [100.0], [5.0], [10.0], [10.0], [10.0], [10.0]]).to(self.device) segment_params = segment_params.permute(0, 2, 1) unscaled_input = input_in * self.max_vals # print("boundry_params",boundry_params[0,::5].mean().item(),boundry_params.size()) # print("segment_params",segment_params[0,::5,0].mean().item(),segment_params.size()) # print(unscaled_input) model_steps = [] num_steps = 3 #18 for _ in range(num_steps): out1 = self.ntf_module(unscaled_input, segment_params, boundry_params) model_steps.append(out1) unscaled_input = out1 out = torch.stack(model_steps, dim=0).mean(dim=0) avg_sum_max_vals = self.max_vals # summing everything above but speed and occ need to be avg # avg_sum_max_vals[1::5] *= num_steps #mean occupancy # avg_sum_max_vals[2::5] *= num_steps #mean speed out = out / (avg_sum_max_vals + 1e-6) # print(out.mean().item()) # out = out / (self.max_vals+1e-6) #from fairseq import pdb; pdb.set_trace() #out = F.dropout(out, p=self.dropout_out, training=self.training) # input feeding input_feed = out #.view(-1,360,90) # save final output outs.append(out) # cache previous states (no-op except during incremental generation) utils.set_incremental_state( self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed), ) # collect outputs across time steps #print(torch.stack(outs, dim=0).size()) # from fairseq import pdb; pdb.set_trace(); x = torch.stack(outs, dim=1) #.view(seqlen, bsz, self.hidden_size) self.all_boundry_params = torch.stack(boundry_param_list, dim=1) self.all_segment_params = torch.stack(segment_param_list, dim=1) # T x B x C -> B x T x C #x = x.transpose(1, 0) # srclen x tgtlen x bsz -> bsz x tgtlen x srclen if not self.training and self.need_attn: attn_scores = attn_scores.transpose(0, 2) else: attn_scores = None # project back to size of vocabulary if self.adaptive_softmax is None: if hasattr(self, 'additional_fc'): x = self.additional_fc(x) x = F.dropout(x, p=self.dropout_out, training=self.training) if self.share_input_output_embed: x = F.linear(x, self.embed_tokens.weight) # else: # x = self.fc_out(x) # import fairseq.pdb as pdb; pdb.set_trace()#[:,-1,:] #x = x.contiguous().view(bsz,-1)#self.output_size)#self.fc_out(x) return x, attn_scores
def _get_input_buffer(
    self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
):
    return utils.get_incremental_state(self, incremental_state, "input_buffer")
def _get_cached_bert(self, incremental_state):
    return utils.get_incremental_state(
        self,
        incremental_state,
        'cached_bert',
    )
def extract_features(self, prev_output_tokens, encoder_out, incremental_state=None): """ Similar to *forward* but only return features. """ if encoder_out is not None: encoder_padding_mask = encoder_out['encoder_padding_mask'] encoder_out = encoder_out['encoder_out'] else: encoder_padding_mask = None encoder_out = None if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] bsz, seqlen = prev_output_tokens.size() # get outputs from encoder if encoder_out is not None: encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3] srclen = encoder_outs.size(0) else: srclen = None # embed tokens x = self.embed_tokens(prev_output_tokens) x = F.dropout(x, p=self.dropout_in, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # initialize previous states (or get from cache during incremental generation) cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') if cached_state is not None: prev_hiddens, prev_cells, input_feed = cached_state elif encoder_out is not None: # setup recurrent cells num_layers = len(self.layers) prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)] prev_cells = [encoder_cells[i] for i in range(num_layers)] if self.encoder_hidden_proj is not None: prev_hiddens = [ self.encoder_hidden_proj(x) for x in prev_hiddens ] prev_cells = [self.encoder_cell_proj(x) for x in prev_cells] input_feed = x.new_zeros(bsz, self.hidden_size) else: # setup zero cells, since there is no encoder num_layers = len(self.layers) zero_state = x.new_zeros(bsz, self.hidden_size) prev_hiddens = [zero_state for i in range(num_layers)] prev_cells = [zero_state for i in range(num_layers)] input_feed = None assert srclen is not None or self.attention is None, \ "attention is not supported if there are no encoder outputs" attn_scores = x.new_zeros(srclen, seqlen, bsz) if self.attention is not None else None outs = [] for j in range(seqlen): # input feeding: concatenate context vector from previous time step if input_feed is not None: input = torch.cat((x[j, :, :], input_feed), dim=1) else: input = x[j] for i, rnn in enumerate(self.layers): # recurrent cell hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) # hidden state becomes the input to the next layer input = F.dropout(hidden, p=self.dropout_out, training=self.training) # save state for next time step prev_hiddens[i] = hidden prev_cells[i] = cell # apply attention using the last layer's hidden state if self.attention is not None: out, attn_scores[:, j, :] = self.attention( hidden, encoder_outs, encoder_padding_mask) else: out = hidden out = F.dropout(out, p=self.dropout_out, training=self.training) # input feeding if input_feed is not None: input_feed = out # save final output outs.append(out) # cache previous states (no-op except during incremental generation) utils.set_incremental_state( self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed), ) # collect outputs across time steps x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) # T x B x C -> B x T x C x = x.transpose(1, 0) if hasattr(self, 'additional_fc') and self.adaptive_softmax is None: x = self.additional_fc(x) x = F.dropout(x, p=self.dropout_out, training=self.training) # srclen x tgtlen x bsz -> bsz x tgtlen x srclen if not self.training and self.need_attn and self.attention is not None: attn_scores = attn_scores.transpose(0, 2) else: attn_scores = None return x, attn_scores
def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): if encoder_out_dict is not None: encoder_out = encoder_out_dict['encoder_out'] encoder_padding_mask = encoder_out_dict['encoder_padding_mask'] if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] bsz, seqlen = prev_output_tokens.size() # get outputs from encoder encoder_outs, _, _ = encoder_out[:3] srclen = encoder_outs.size(0) if bsz != encoder_outs.size(1): prev_output_tokens = prev_output_tokens.t() bsz, seqlen = seqlen, bsz # embed tokens x = self.embed_tokens(prev_output_tokens) # B x T x C -> T x B x C x = x.transpose(0, 1) x_in = F.dropout(x, p=self.dropout_in, training=self.training) # initialize previous states (or get from cache during incremental generation) cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') if cached_state is not None: prev_hiddens, prev_cells = cached_state else: _, encoder_hiddens, encoder_cells = encoder_out[:3] if self.initial_state == 'same': prev_hiddens = encoder_hiddens prev_cells = encoder_cells elif self.initial_state == 'linear': prev_hiddens = self.tanh(self.proj_hidden(encoder_hiddens)) prev_cells = self.tanh(self.proj_cell(encoder_cells)) else: raise NotImplementedError() attn_scores = x.data.new(srclen, seqlen, bsz).zero_() outs = [] ctxs = [] if hasattr(self, 'ctx_proj'): encoder_ctx = self.ctx_proj(encoder_outs) else: encoder_ctx = encoder_outs for j in range(seqlen): input = x_in[j] for i, rnn in enumerate(self.layers): # recurrent cell hidden, cell = rnn(input, (prev_hiddens, prev_cells)) # apply attention using the last layer's hidden state if self.attention is not None and i == 0: # attention output becomes the input to the next layer attn_input = F.dropout(hidden, p=self.dropout_out, training=self.training) ctx, attn_scores[:, j, :] = self.attention(attn_input, encoder_ctx, encoder_padding_mask) ctxs.append(ctx) input = F.dropout(ctx, p=self.dropout_out, training=self.training) else: out = hidden # save state for next time step prev_hiddens = hidden prev_cells = cell # save final output outs.append(out) # cache previous states (no-op except during incremental generation) utils.set_incremental_state( self, incremental_state, 'cached_state', (prev_hiddens, prev_cells)) # collect outputs across time steps outs = torch.stack(outs) ctxs = torch.stack(ctxs) out = torch.cat([outs, ctxs, x], dim=2) out = F.dropout(out, p=self.dropout_out, training=self.training) # T x B x C -> B x T x C out = out.transpose(1, 0) # srclen x tgtlen x bsz -> bsz x tgtlen x srclen if not self.training and self.need_attn: attn_scores = attn_scores.transpose(0, 2) else: attn_scores = None # project back to size of vocabulary out = self.tanh(self.additional_fc(out)) out = F.dropout(out, p=self.dropout_out, training=self.training) out = self.fc_out(out) return out, attn_scores
def _get_input_buffer(self, incremental_state, incremental_clone_id=""): return (utils.get_incremental_state( self, incremental_state, "attn_state" + incremental_clone_id) or {})
def _get_monotonic_buffer(self, incremental_state): return utils.get_incremental_state( self, incremental_state, 'monotonic', ) or {}
def get_pointer(self, incremental_state): return utils.get_incremental_state( self, incremental_state, 'monotonic', ) or {}
def forward(self, input_token, timestep, *inputs): """ Decoder step inputs correspond one-to-one to encoder outputs. """ log_probs_per_model = [] attn_weights_per_model = [] state_outputs = [] next_state_input = len(self.models) # underlying assumption is each model has same vocab_reduction_module vocab_reduction_module = self.models[0].decoder.vocab_reduction_module if vocab_reduction_module is not None: possible_translation_tokens = inputs[len(self.models)] next_state_input += 1 else: possible_translation_tokens = None for i, model in enumerate(self.models): encoder_output = inputs[i] prev_hiddens = [] prev_cells = [] for _ in range(len(model.decoder.layers)): prev_hiddens.append(inputs[next_state_input]) prev_cells.append(inputs[next_state_input + 1]) next_state_input += 2 prev_input_feed = inputs[next_state_input].view((1, -1)) next_state_input += 1 # no batching, we only care about care about "max" length src_length_int = int(encoder_output.size()[0]) src_length = torch.LongTensor(np.array([src_length_int])) # notional, not actually used for decoder computation src_tokens = torch.LongTensor(np.array([[0] * src_length_int])) encoder_out = ( encoder_output, prev_hiddens, prev_cells, src_length, src_tokens, ) # store cached states, use evaluation mode model.decoder._is_incremental_eval = True model.eval() # placeholder incremental_state = {} # cache previous state inputs utils.set_incremental_state( model.decoder, incremental_state, "cached_state", (prev_hiddens, prev_cells, prev_input_feed), ) decoder_output = model.decoder( input_token, encoder_out, incremental_state=incremental_state, possible_translation_tokens=possible_translation_tokens, ) logits, attn_scores, _ = decoder_output log_probs = F.log_softmax(logits, dim=2) log_probs_per_model.append(log_probs) attn_weights_per_model.append(attn_scores) (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state( model.decoder, incremental_state, "cached_state") for h, c in zip(next_hiddens, next_cells): state_outputs.extend([h, c]) state_outputs.append(next_input_feed) average_log_probs = torch.mean(torch.cat(log_probs_per_model, dim=0), dim=0, keepdim=True) average_attn_weights = torch.mean(torch.cat(attn_weights_per_model, dim=0), dim=0, keepdim=True) best_scores, best_tokens = torch.topk(average_log_probs.view(1, -1), k=self.beam_size) if possible_translation_tokens is not None: best_tokens = possible_translation_tokens.index_select( dim=0, index=best_tokens.view(-1)).view(1, -1) word_rewards_for_best_tokens = self.word_rewards.index_select( 0, best_tokens.view(-1)) best_scores += word_rewards_for_best_tokens self.input_names = ["prev_token", "timestep"] for i in range(len(self.models)): self.input_names.append(f"fixed_input_{i}") if possible_translation_tokens is not None: self.input_names.append("possible_translation_tokens") outputs = [best_tokens, best_scores, average_attn_weights] self.output_names = [ "best_tokens_indices", "best_scores", "attention_weights_average", ] for i, state in enumerate(state_outputs): outputs.append(state) self.output_names.append(f"state_output_{i}") self.input_names.append(f"state_input_{i}") return tuple(outputs)
def forward(self, input_tokens, prev_scores, timestep, *inputs): """ Decoder step inputs correspond one-to-one to encoder outputs. HOWEVER: after the first step, encoder outputs (i.e, the first len(self.models) elements of inputs) must be tiled k (beam size) times on the batch dimension (axis 1). """ log_probs_per_model = [] attn_weights_per_model = [] state_outputs = [] beam_axis_per_state = [] # from flat to (batch x 1) input_tokens = input_tokens.unsqueeze(1) next_state_input = len(self.models) # size of "batch" dimension of input as tensor batch_size = torch.onnx.operators.shape_as_tensor(input_tokens)[0] # underlying assumption is each model has same vocab_reduction_module vocab_reduction_module = self.models[0].decoder.vocab_reduction_module if vocab_reduction_module is not None: possible_translation_tokens = inputs[len(self.models)] next_state_input += 1 else: possible_translation_tokens = None for i, model in enumerate(self.models): if (isinstance(model, rnn.RNNModel) or isinstance(model, char_source_model.CharSourceModel) or isinstance(model, word_prediction_model.WordPredictionModel)): encoder_output = inputs[i] prev_hiddens = [] prev_cells = [] for _ in range(len(model.decoder.layers)): prev_hiddens.append(inputs[next_state_input]) prev_cells.append(inputs[next_state_input + 1]) next_state_input += 2 # ensure previous attention context has batch dimension input_feed_shape = torch.cat( (batch_size.view(1), torch.LongTensor([-1]))) prev_input_feed = torch.onnx.operators.reshape_from_tensor_shape( inputs[next_state_input], input_feed_shape) next_state_input += 1 # no batching, we only care about care about "max" length src_length_int = int(encoder_output.size()[0]) src_length = torch.LongTensor(np.array([src_length_int])) # notional, not actually used for decoder computation src_tokens = torch.LongTensor(np.array([[0] * src_length_int])) src_embeddings = encoder_output.new_zeros(encoder_output.shape) encoder_out = ( encoder_output, prev_hiddens, prev_cells, src_length, src_tokens, src_embeddings, ) # store cached states, use evaluation mode model.decoder._is_incremental_eval = True model.eval() # placeholder incremental_state = {} # cache previous state inputs utils.set_incremental_state( model.decoder, incremental_state, "cached_state", (prev_hiddens, prev_cells, prev_input_feed), ) decoder_output = model.decoder( input_tokens, encoder_out, incremental_state=incremental_state, possible_translation_tokens=possible_translation_tokens, ) logits, attn_scores, _ = decoder_output log_probs = F.log_softmax(logits, dim=2) log_probs_per_model.append(log_probs) attn_weights_per_model.append(attn_scores) ( next_hiddens, next_cells, next_input_feed, ) = utils.get_incremental_state(model.decoder, incremental_state, "cached_state") for h, c in zip(next_hiddens, next_cells): state_outputs.extend([h, c]) beam_axis_per_state.extend([0, 0]) state_outputs.append(next_input_feed) beam_axis_per_state.append(0) elif isinstance(model, transformer.TransformerModel): encoder_output = inputs[i] # store cached states, use evaluation mode model.decoder._is_incremental_eval = True model.eval() # placeholder incremental_state = {} state_inputs = [] for _ in model.decoder.layers: # (prev_key, prev_value) for self- and encoder-attention state_inputs.extend( inputs[next_state_input:next_state_input + 4]) next_state_input += 4 encoder_out = (encoder_output, None, None) decoder_output = model.decoder( input_tokens, encoder_out, incremental_state=state_inputs, possible_translation_tokens=possible_translation_tokens, 
timestep=timestep, ) logits, attn_scores, _, attention_states = decoder_output log_probs = F.log_softmax(logits, dim=2) log_probs_per_model.append(log_probs) attn_weights_per_model.append(attn_scores) state_outputs.extend(attention_states) beam_axis_per_state.extend([0 for _ in attention_states]) else: raise RuntimeError(f"Not a supported model: {type(model)}") average_log_probs = torch.mean(torch.cat(log_probs_per_model, dim=1), dim=1, keepdim=True) average_attn_weights = torch.mean(torch.cat(attn_weights_per_model, dim=1), dim=1, keepdim=True) best_scores_k_by_k, best_tokens_k_by_k = torch.topk( average_log_probs.squeeze(1), k=self.beam_size) prev_scores_k_by_k = prev_scores.view(-1, 1).expand(-1, self.beam_size) total_scores_k_by_k = best_scores_k_by_k + prev_scores_k_by_k # flatten to take top k over all (beam x beam) hypos total_scores_flat = total_scores_k_by_k.view(-1) best_tokens_flat = best_tokens_k_by_k.view(-1) best_scores, best_indices = torch.topk(total_scores_flat, k=self.beam_size) best_tokens = best_tokens_flat.index_select( dim=0, index=best_indices).view(-1) # integer division to determine which input produced each successor prev_hypos = best_indices / self.beam_size attention_weights = average_attn_weights.index_select(dim=0, index=prev_hypos) if possible_translation_tokens is not None: best_tokens = possible_translation_tokens.index_select( dim=0, index=best_tokens) word_rewards_for_best_tokens = self.word_rewards.index_select( 0, best_tokens) best_scores += word_rewards_for_best_tokens self.input_names = ["prev_tokens", "prev_scores", "timestep"] for i in range(len(self.models)): self.input_names.append(f"fixed_input_{i}") if possible_translation_tokens is not None: self.input_names.append("possible_translation_tokens") # 'attention_weights_average' output shape: (src_length x beam_size) attention_weights = attention_weights.squeeze(1) outputs = [best_tokens, best_scores, prev_hypos, attention_weights] self.output_names = [ "best_tokens_indices", "best_scores", "prev_hypos_indices", "attention_weights_average", ] for i in range(len(self.models)): self.output_names.append(f"fixed_input_{i}") if self.tile_internal: outputs.append(inputs[i].repeat(1, self.beam_size, 1)) else: outputs.append(inputs[i]) if possible_translation_tokens is not None: self.output_names.append("possible_translation_tokens") outputs.append(possible_translation_tokens) for i, state in enumerate(state_outputs): beam_axis = beam_axis_per_state[i] next_state = state.index_select(dim=beam_axis, index=prev_hypos) outputs.append(next_state) self.output_names.append(f"state_output_{i}") self.input_names.append(f"state_input_{i}") return tuple(outputs)
def forward_unprojected(self, input_tokens, encoder_out, incremental_state=None): if incremental_state is not None: input_tokens = input_tokens[:, -1:] bsz, seqlen = input_tokens.size() # get outputs from encoder (encoder_outs, final_hidden, final_cell, src_lengths, src_tokens) = encoder_out # embed tokens x = self.embed_tokens(input_tokens) x = F.dropout(x, p=self.dropout_in, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # initialize previous states (or get from cache during incremental generation) cached_state = utils.get_incremental_state( self, incremental_state, "cached_state" ) input_feed = None if cached_state is not None: prev_hiddens, prev_cells, input_feed = cached_state else: # first time step, initialize previous states prev_hiddens, prev_cells = self._init_prev_states(encoder_out) if self.attention.context_dim: input_feed = self.initial_attn_context.expand( bsz, self.attention.context_dim ) attn_scores_per_step = [] outs = [] for j in range(seqlen): # input feeding: concatenate context vector from previous time step step_input = maybe_cat((x[j, :, :], input_feed), dim=1) previous_layer_input = step_input for i, rnn in enumerate(self.layers): # recurrent cell hidden, cell = rnn(step_input, (prev_hiddens[i], prev_cells[i])) # hidden state becomes the input to the next layer layer_output = F.dropout( hidden, p=self.dropout_out, training=self.training ) if self.residual_level is not None and i >= self.residual_level: # TODO add an assert related to sizes here step_input = layer_output + previous_layer_input else: step_input = layer_output previous_layer_input = step_input # save state for next time step prev_hiddens[i] = hidden prev_cells[i] = cell out, step_attn_scores = self.attention(hidden, encoder_outs, src_lengths) input_feed = out attn_scores_per_step.append(step_attn_scores.unsqueeze(1)) attn_scores = torch.cat(attn_scores_per_step, dim=1) # srclen x tgtlen x bsz -> bsz x tgtlen x srclen attn_scores = attn_scores.transpose(0, 2) combined_output_and_context = maybe_cat((hidden, out), dim=1) # save final output outs.append(combined_output_and_context) # cache previous states (no-op except during incremental generation) utils.set_incremental_state( self, incremental_state, "cached_state", (prev_hiddens, prev_cells, input_feed), ) # collect outputs across time steps x = torch.cat(outs, dim=0).view( seqlen, bsz, self.combined_output_and_context_dim ) # T x B x C -> B x T x C x = x.transpose(1, 0) # bottleneck layer if hasattr(self, "additional_fc"): x = self.additional_fc(x) x = F.dropout(x, p=self.dropout_out, training=self.training) return x, attn_scores
def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): encoder_out = encoder_out_dict['encoder_out'] encoder_padding_mask = encoder_out_dict['encoder_padding_mask'] if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] bsz, seqlen = prev_output_tokens.size() # get outputs from encoder encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3] srclen = encoder_outs.size(0) # embed tokens x = self.embed_tokens(prev_output_tokens) x = F.dropout(x, p=self.dropout_in, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # initialize previous states (or get from cache during incremental generation) cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') if cached_state is not None: prev_hiddens, prev_cells, input_feed = cached_state else: num_layers = len(self.layers) prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)] prev_cells = [encoder_cells[i] for i in range(num_layers)] if self.encoder_hidden_proj is not None: prev_hiddens = [ self.encoder_hidden_proj(x) for x in prev_hiddens ] prev_cells = [self.encoder_cell_proj(x) for x in prev_cells] input_feed = x.new_zeros(bsz, self.hidden_size) attn_scores = x.new_zeros(srclen, seqlen, bsz) outs = [] for j in range(seqlen): # input feeding: concatenate context vector from previous time step input = torch.cat((x[j, :, :], input_feed), dim=1) for i, rnn in enumerate(self.layers): # recurrent cell hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) # hidden state becomes the input to the next layer input = F.dropout(hidden, p=self.dropout_out, training=self.training) # save state for next time step prev_hiddens[i] = hidden prev_cells[i] = cell # apply attention using the last layer's hidden state if self.attention is not None: out, attn_scores[:, j, :] = self.attention( hidden, encoder_outs, encoder_padding_mask) else: out = hidden out = F.dropout(out, p=self.dropout_out, training=self.training) # input feeding input_feed = out # save final output outs.append(out) # cache previous states (no-op except during incremental generation) utils.set_incremental_state( self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed), ) # collect outputs across time steps x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) # T x B x C -> B x T x C x = x.transpose(1, 0) # srclen x tgtlen x bsz -> bsz x tgtlen x srclen if not self.training and self.need_attn: attn_scores = attn_scores.transpose(0, 2) else: attn_scores = None # project back to size of vocabulary if self.adaptive_softmax is None: if hasattr(self, 'additional_fc'): x = self.additional_fc(x) x = F.dropout(x, p=self.dropout_out, training=self.training) if self.share_input_output_embed: x = F.linear(x, self.embed_tokens.weight) else: x = self.fc_out(x) return x, attn_scores
def forward(self, prev_output_tokens, encoder_out, lang, incremental_state=None): encoder_sentemb = encoder_out['sentemb'] encoder_padding_mask = encoder_out['encoder_padding_mask'] encoder_out = encoder_out['encoder_out'] if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] bsz, seqlen = prev_output_tokens.size() # get outputs from encoder encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3] # embed tokens x = self.embed_tokens(prev_output_tokens) x = F.dropout(x, p=self.dropout_in, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # embed language lang_tensor = torch.cuda.LongTensor([self.lang_dictionary[lang]] * bsz) l = self.embed_langs(lang_tensor) # B x T x C -> T x B x C #l = l.transpose(0, 1) # initialize previous states (or get from cache during incremental generation) cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') if cached_state is not None: prev_hiddens, prev_cells, input_feed = cached_state print(len(prev_cells[0])) else: num_layers = len(self.layers) # Hiddens and cells are initialized with a linear transformation of the embedding produced by the encoder prev_hiddens = [encoder_sentemb for i in range(num_layers)] prev_cells = [encoder_sentemb for i in range(num_layers)] prev_hiddens = [self.encoder_hidden_proj(x) for x in prev_hiddens] prev_cells = [self.encoder_cell_proj(x) for x in prev_cells] input_feed = x.new_zeros(bsz, self.hidden_size) outs = [] for j in range(seqlen): # input feeding: concatenate context vector from previous time step input = torch.cat((x[j, :, :], encoder_sentemb, input_feed, l), dim=1) for i, rnn in enumerate(self.layers): # recurrent cell hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) # hidden state becomes the input to the next layer input = F.dropout(hidden, p=self.dropout_out, training=self.training) # save state for next time step prev_hiddens[i] = hidden prev_cells[i] = cell out = hidden out = F.dropout(out, p=self.dropout_out, training=self.training) # input feeding input_feed = out # save final output outs.append(out) # cache previous states (no-op except during incremental generation) utils.set_incremental_state( self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed), ) # collect outputs across time steps x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) # T x B x C -> B x T x C x = x.transpose(1, 0) # project back to size of vocabulary if self.adaptive_softmax is None: if hasattr(self, 'additional_fc'): x = self.additional_fc(x) x = F.dropout(x, p=self.dropout_out, training=self.training) if self.share_input_output_embed: x = F.linear(x, self.embed_tokens.weight) else: x = self.fc_out(x) return x, None
def _get_input_buffer(self, incremental_state): return utils.get_incremental_state(self, incremental_state, 'input_buffer')
def _get_input_buffer(self, incremental_state): return utils.get_incremental_state( self, incremental_state, 'attn_state', ) or {}