def _set_input_buffer(self, incremental_state, buffer):
    """Store ``buffer`` in ``incremental_state`` under this module's ``'f'`` key.

    The value returned by ``utils.set_incremental_state`` is discarded, so
    this method returns None.
    """
    # NOTE(review): the key 'f' is unusually terse -- confirm it matches the
    # key used by the corresponding getter.
    state_key = 'f'
    utils.set_incremental_state(self, incremental_state, state_key, buffer)
def _set_input_buffer(self, incremental_state, buffer):
    """Store ``buffer`` in ``incremental_state`` under the ``'attn_state'`` key.

    Delegates to the module-level ``set_incremental_state`` helper and
    returns None.
    """
    state_key = 'attn_state'
    set_incremental_state(self, incremental_state, state_key, buffer)
def forward(self, x, encoder_out=None, encoder_padding_mask=None, incremental_state=None, **kwargs):
    """Apply a pre-norm LSTM sub-layer, then attention over ``encoder_out``.

    Args:
        x: decoder input. It is sliced as ``x[-1:, :, :]`` in incremental mode
            and fed to ``self.lstm``; presumably time-major (T, B, C) --
            TODO confirm.
        encoder_out: used as attention key and value. When no LSTM state is
            cached, its mean over dim 0 seeds both h0 and c0. Although it
            defaults to None, it is dereferenced on that path.
        encoder_padding_mask: forwarded to the attention as
            ``key_padding_mask``.
        incremental_state: optional incremental-decoding cache. When given,
            only the last step of ``x`` is processed and the LSTM (h, c) is
            stored under ``'cached_state'``. It is also passed to the
            attention, and ``'enc_dec_attn_constraint_mask'`` is read from it.
        **kwargs: optional ``layer_norm_training``. When not None, it is
            assigned to ``.training`` of both layer norms. The assignment
            persists after this call returns.

    Returns:
        ``(x, attn_logits)``, where ``attn_logits`` is ``attn[1]`` from the
        attention module's second return value.
    """
    layer_norm_training = kwargs.get('layer_norm_training', None)
    if layer_norm_training is not None:
        # Overrides train/eval mode of the layer norms (mutates the submodules).
        self.layer_norm1.training = layer_norm_training
        self.layer_norm2.training = layer_norm_training
    # Re-pack the LSTM weights (PyTorch RNNBase.flatten_parameters).
    self.lstm.flatten_parameters()
    if incremental_state is not None:
        # Incremental decoding: only the newest step along dim 0 is processed.
        x = x[-1:, :, :]
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is not None:
        prev_hiddens, prev_cells = cached_state
    else:
        # No cached state: seed h0 and c0 with the mean of encoder_out over dim 0.
        prev_hiddens = encoder_out.mean(dim=0, keepdim=True)
        prev_cells = encoder_out.mean(dim=0, keepdim=True)
    # Sub-layer 1: layer norm -> LSTM -> residual add.
    residual = x
    x = self.layer_norm1(x)
    x, hidden = self.lstm(x, (prev_hiddens, prev_cells))
    hiddens, cells = hidden
    x = residual + x
    # Sub-layer 2: layer norm -> attention over encoder_out -> dropout -> residual add.
    # NOTE(review): `residual` is not refreshed here, so the residual add at the
    # end of this method adds the layer's ORIGINAL input, and the LSTM output
    # reaches the result only through the attention query. Confirm this is
    # intended; a conventional residual would set `residual = x` at this point.
    x = self.layer_norm2(x)
    x, attn = self.attention(
        query=x,
        key=encoder_out,
        value=encoder_out,
        key_padding_mask=encoder_padding_mask,
        incremental_state=incremental_state,
        static_kv=True,
        enc_dec_attn_constraint_mask=utils.get_incremental_state(
            self, incremental_state, 'enc_dec_attn_constraint_mask'))
    x = F.dropout(x, self.dropout, training=self.training)
    if incremental_state is not None:
        # Cache only the latest LSTM (h, c). A previous, now-removed variant
        # concatenated the state history along dim 0 instead.
        prev_hiddens = hiddens
        prev_cells = cells
        utils.set_incremental_state(
            self, incremental_state, 'cached_state', (prev_hiddens, prev_cells),
        )
    x = residual + x
    # Second element of the attention output; presumably the attention logits,
    # per the name -- confirm against self.attention.
    attn_logits = attn[1]
    return x, attn_logits
def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
    """Reset the cached LSTM state to the mean of ``encoder_out`` over dim 0.

    Does nothing when ``incremental_state`` is None. ``input`` and
    ``encoder_padding_mask`` are accepted but unused.
    """
    if incremental_state is None:
        return
    # Two separate reductions, so h0 and c0 are distinct tensors.
    initial_h = encoder_out.mean(dim=0, keepdim=True)
    initial_c = encoder_out.mean(dim=0, keepdim=True)
    utils.set_incremental_state(self, incremental_state, 'cached_state', (initial_h, initial_c))
def reorder_incremental_state(self, incremental_state, new_order):
    """Reorder the cached decoder state to follow ``new_order``.

    Every tensor stored under ``'cached_state'`` is re-indexed with
    ``index_select(0, new_order)``. This includes tensors nested inside lists
    or tuples, such as per-layer hidden/cell lists. The reordered tuple is
    written back. Does nothing when no state is cached.

    Args:
        incremental_state: incremental-decoding cache.
        new_order: index tensor selecting entries along dim 0 (presumably the
            batch/beam dimension -- confirm against the cached tensors' layout).
    """
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is None:
        return

    def reorder_state(state):
        # Recurse through containers; leaves are tensors indexed along dim 0.
        if isinstance(state, list):
            return [reorder_state(state_i) for state_i in state]
        if isinstance(state, tuple):
            return tuple(reorder_state(state_i) for state_i in state)
        return state.index_select(0, new_order)

    # new_order is used as-is. Since PyTorch 0.4, tensors and Variables are
    # merged, so legacy Variable wrapping is a no-op.
    new_state = tuple(reorder_state(state_i) for state_i in cached_state)
    utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
    """Decode ``prev_output_tokens`` with a stacked, input-feeding LSTM.

    At each time step the current token embedding is concatenated with the
    previous step's attention output. The result runs through every layer in
    ``self.layers``, and the top layer's hidden state attends over the
    encoder outputs. The dropped-out attention output serves as both the
    step's output and the next step's input feed.

    Args:
        prev_output_tokens: ``(bsz, seqlen)`` previous target tokens. When
            ``incremental_state`` is given, only the last column is decoded.
        encoder_out: 3-tuple ``(encoder_outs, encoder_hiddens, encoder_cells)``.
        incremental_state: optional cache. The per-layer hidden/cell lists and
            the input feed are read from and written to ``'cached_state'``.

    Returns:
        ``(x, attn_scores)``. ``x`` is ``self.fc_out`` applied to the
        ``(bsz, seqlen, embed_dim)`` step outputs, and ``attn_scores`` has
        shape ``(bsz, seqlen, srclen)``.
    """
    if incremental_state is not None:
        # Earlier steps live in the cache; only the newest token is decoded.
        prev_output_tokens = prev_output_tokens[:, -1:]
    bsz, seqlen = prev_output_tokens.size()

    encoder_outs, _, _ = encoder_out
    srclen = encoder_outs.size(0)

    embedded = F.dropout(self.embed_tokens(prev_output_tokens),
                         p=self.dropout_in, training=self.training)
    embed_dim = embedded.size(2)
    steps = embedded.transpose(0, 1)  # time-major: (seqlen, bsz, embed_dim)

    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is None:
        # Fresh decode: start each layer from the encoder's per-layer
        # hidden/cell states, with a zero input feed for step 0.
        _, encoder_hiddens, encoder_cells = encoder_out
        layer_ids = range(len(self.layers))
        prev_hiddens = [encoder_hiddens[k] for k in layer_ids]
        prev_cells = [encoder_cells[k] for k in layer_ids]
        input_feed = Variable(steps.data.new(bsz, embed_dim).zero_())
    else:
        prev_hiddens, prev_cells, input_feed = cached_state

    attn_scores = Variable(steps.data.new(srclen, seqlen, bsz).zero_())
    step_outputs = []
    for t in range(seqlen):
        layer_input = torch.cat((steps[t, :, :], input_feed), dim=1)
        for k, rnn in enumerate(self.layers):
            hidden, cell = rnn(layer_input, (prev_hiddens[k], prev_cells[k]))
            prev_hiddens[k], prev_cells[k] = hidden, cell
            # Each layer's dropped-out hidden state feeds the next layer.
            layer_input = F.dropout(hidden, p=self.dropout_out, training=self.training)
        out, attn_scores[:, t, :] = self.attention(hidden, encoder_outs)
        input_feed = F.dropout(out, p=self.dropout_out, training=self.training)
        step_outputs.append(input_feed)

    # Persist state for the next call (no-op unless decoding incrementally).
    utils.set_incremental_state(self, incremental_state, 'cached_state',
                                (prev_hiddens, prev_cells, input_feed))

    # (seqlen * bsz, C) -> (seqlen, bsz, C) -> (bsz, seqlen, C)
    x = torch.cat(step_outputs, dim=0).view(seqlen, bsz, embed_dim).transpose(1, 0)
    # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
    attn_scores = attn_scores.transpose(0, 2)
    return self.fc_out(x), attn_scores
def _set_input_buffer(self, incremental_state, new_buffer):
    """Store ``new_buffer`` under this module's ``'input_buffer'`` key.

    Returns:
        Whatever ``utils.set_incremental_state`` returns.
    """
    buffer_key = 'input_buffer'
    return utils.set_incremental_state(self, incremental_state, buffer_key, new_buffer)
def set_buffer(self, name, tensor, incremental_state):
    """Store ``tensor`` in ``incremental_state`` under the caller-chosen key ``name``.

    Returns:
        Whatever ``utils.set_incremental_state`` returns.
    """
    stored = utils.set_incremental_state(self, incremental_state, name, tensor)
    return stored
def forward(
        self,
        phase,
        epoch,
        fixed_max_len,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
):
    """Run the input-feeding LSTM decoder in one of three modes.

    Every phase follows the same pipeline:
    - Embed ``prev_output_tokens`` and apply input dropout.
    - Initialise the per-layer LSTM state from the incremental cache or
      from ``encoder_out``.
    - Step through time, concatenating the current input embedding with
      the previous step's attention output ("input feeding").

    Args:
        phase: how the next step's input is chosen.
            ``'MLE'``: scheduled sampling.
                - After each step, the greedy (argmax) prediction is embedded.
                - ``self.calculate_p`` returns a probability ``p``.
                - With probability ``p`` (via ``random.random()``), that
                  embedding replaces the next input embedding.
                - Logits are projected per step.
            ``'PG'``: the next input embedding is overwritten by the last
                layer's dropped-out hidden state. This requires the hidden
                size to equal the embedding size.
            ``'test'``: plain decoding over every position of
                ``prev_output_tokens``.
        epoch: passed through to ``self.calculate_p`` (MLE only).
        fixed_max_len: number of decoding steps for ``'MLE'`` and ``'PG'``.
            It indexes both the input embeddings and ``attn_scores``, so it
            must not exceed ``prev_output_tokens.size(1)``.
        prev_output_tokens: ``(bsz, seqlen)`` previous target tokens. When
            ``incremental_state`` is given, only the last column is used.
        encoder_out: 3-tuple ``(encoder_outs, encoder_hiddens, encoder_cells)``.
            ``encoder_outs`` is attended over, and its ``size(0)`` is the
            source length.
        incremental_state: optional cache. The decoder state is read from and
            written back to its ``'cached_state'`` entry.

    Returns:
        ``(x, attn_scores, p)``:
        - ``x``: ``self.fc_out`` outputs laid out as ``(bsz, steps, out_dim)``.
        - ``attn_scores``: ``(bsz, seqlen, srclen)``.
        - ``p``: the last scheduled-sampling probability (``'MLE'``), or ``0``.

    Raises:
        ValueError: if ``phase`` is not ``'MLE'``, ``'PG'`` or ``'test'``.
    """
    if phase not in ('MLE', 'PG', 'test'):
        # Previously an unknown phase fell through and silently returned None.
        raise ValueError(
            "phase must be 'MLE', 'PG' or 'test', got {!r}".format(phase))

    if incremental_state is not None:
        # Earlier steps live in the cache; only the newest token is decoded.
        prev_output_tokens = prev_output_tokens[:, -1:]
    bsz, seqlen = prev_output_tokens.size()

    # get outputs from encoder
    encoder_outs, _, _ = encoder_out
    srclen = encoder_outs.size(0)

    x = self.embed_tokens(prev_output_tokens)  # (bsz, seqlen, embed_dim)
    x = F.dropout(x, p=self.dropout_in, training=self.training)
    embed_dim = x.size(2)
    x = x.transpose(0, 1)  # (seqlen, bsz, embed_dim)
    if phase == 'MLE':
        # MLE decodes from a detached copy of the embeddings. Scheduled
        # sampling overwrites positions of it in place below. Ground-truth
        # positions therefore send no gradient back into embed_tokens.
        x = x.detach()

    # initialize previous states (or get from cache during incremental generation)
    cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
    if cached_state is not None:
        prev_hiddens, prev_cells, input_feed = cached_state
    else:
        _, encoder_hiddens, encoder_cells = encoder_out
        num_layers = len(self.layers)
        prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
        prev_cells = [encoder_cells[i] for i in range(num_layers)]
        # input_feed carries the previous step's attention output (decoder
        # hidden state attended over the encoder outputs); zeros at step 0.
        input_feed = x.data.new(bsz, embed_dim).zero_()

    attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
    num_steps = seqlen if phase == 'test' else fixed_max_len
    outs = []
    # Scheduled-sampling probability. It stays 0 unless MLE computes one, so
    # it is defined even when MLE runs a single step (previously an
    # UnboundLocalError).
    p = 0
    for j in range(num_steps):
        # input feeding: concatenate the current input embedding with the
        # attention output of the previous time step
        layer_input = torch.cat((x[j, :, :], input_feed), dim=1)
        for i, rnn in enumerate(self.layers):
            # recurrent cell
            hidden, cell = rnn(layer_input, (prev_hiddens[i], prev_cells[i]))
            # hidden state becomes the input to the next layer
            layer_input = F.dropout(hidden, p=self.dropout_out, training=self.training)
            # save state for next time step
            prev_hiddens[i] = hidden
            prev_cells[i] = cell

        # apply attention using the last layer's hidden state
        out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
        out = F.dropout(out, p=self.dropout_out, training=self.training)
        input_feed = out

        if phase == 'MLE':
            logits = self.fc_out(out.unsqueeze(1))  # (bsz, 1, num_vocab)
            outs.append(logits)
            if j < fixed_max_len - 1:
                predicted = torch.argmax(logits, dim=-1)  # (bsz, 1)
                predicted_embed = self.embed_tokens(predicted).squeeze(1)  # (bsz, embed_dim)
                p = self.calculate_p(epoch, x[j + 1, :, :], predicted_embed)
                is_teacher = random.random() > p
                if not is_teacher:
                    # Feed the model's own prediction as the next input.
                    x[j + 1, :, :] = predicted_embed[:, :]
        else:
            outs.append(out)
            if phase == 'PG' and j < fixed_max_len - 1:
                # PG: the last layer's dropped-out hidden state becomes the
                # next input embedding.
                x[j + 1, :, :] = layer_input[:, :]

    # cache previous states (no-op except during incremental generation)
    utils.set_incremental_state(self, incremental_state, 'cached_state',
                                (prev_hiddens, prev_cells, input_feed))

    # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
    attn_scores = attn_scores.transpose(0, 2)

    if phase == 'MLE':
        # Per-step logits are already (bsz, 1, num_vocab): join along time.
        # No reshape, so a shorter fixed_max_len cannot scramble the vocabulary
        # axis across steps.
        x = torch.cat(outs, dim=1)
    else:
        # collect outputs across time steps: T x B x C -> B x T x C
        x = torch.stack(outs, dim=0).transpose(1, 0)
        x = self.fc_out(x)
    return x, attn_scores, p