def forward(self, input_embeds_list):
    """
    Args:
        input_embeds_list (list[SequenceBatchElement]): each element has
            values of shape (batch_size, input_dim)

    Returns:
        hidden_states_list (list[SequenceBatchElement]): each element has
            values of shape (batch_size, hidden_dim)
    """
    batch_size = input_embeds_list[0].values.size()[0]
    h = tile_state(self.h0, batch_size)  # (batch_size, hidden_dim)
    c = tile_state(self.c0, batch_size)  # (batch_size, hidden_dim)

    hidden_states_list = []
    for t, x in enumerate(input_embeds_list):
        # x.values has shape (batch_size, input_dim)
        # x.mask has shape (batch_size, 1)
        h_new, c_new = self.rnn_cell(x.values, (h, c))
        # only advance the state where the mask is on
        h = gated_update(h, h_new, x.mask)
        c = gated_update(c, c_new, x.mask)
        hidden_states_list.append(SequenceBatchElement(self.dropout(h), x.mask))
    return hidden_states_list
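# A minimal sketch of the two helpers the encoder above assumes. gated_update
# is pinned down by test_gated_update further below; tile_state is a guess
# from usage here, assuming self.h0 has shape (1, hidden_dim).

def gated_update(old, new, update):
    """Where update == 1, take `new`; where update == 0, keep `old`.

    update has shape (batch_size, 1) and broadcasts across the hidden dim,
    so finished or padded sequences in the batch keep their state unchanged.
    """
    return update * new + (1 - update) * old

def tile_state(state, batch_size):
    """Replicate a (1, hidden_dim) initial state to (batch_size, hidden_dim)."""
    return state.repeat(batch_size, 1)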
def forward(self, rnn_state, rnn_input, advance):
    x = torch.cat([rnn_input.x, rnn_input.agenda], 1)

    hs, cs = [], []
    for layer in range(self.num_layers):
        rnn_cell = self.rnn_cells[layer]
        # collect the h, c belonging to the previous time step at the
        # corresponding depth
        h_prev_t, c_prev_t = rnn_state.hs[layer], rnn_state.cs[layer]

        # forward pass and masking
        h, c = rnn_cell(x, (h_prev_t, c_prev_t))
        h = gated_update(h_prev_t, h, advance)
        c = gated_update(c_prev_t, c, advance)
        hs.append(h)
        cs.append(c)

        if layer == 0:
            x = h  # no skip connection on the first layer
        else:
            x = x + h

    query = self.linear(x)
    word_vocab = self.token_embedder.vocab
    word_embeds = self.token_embedder.embeds
    vocab_logits = torch.mm(query, word_embeds.t())  # (batch_size, vocab_size)
    vocab_probs = self.softmax(vocab_logits)

    rnn_state = MultilayeredRNNState(hs, cs)
    return DecoderCellOutput(rnn_state, vocab=word_vocab, vocab_probs=vocab_probs)
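# Standalone sketch of the tied output layer used above: the query produced by
# self.linear is matched against every row of the token embedding matrix, so
# the projection must map into the embedding dimension. All sizes below are
# made up for illustration.
import torch
import torch.nn.functional as F

batch_size, vocab_size, embed_dim = 4, 100, 8
query = torch.randn(batch_size, embed_dim)        # stands in for self.linear(x)
word_embeds = torch.randn(vocab_size, embed_dim)  # stands in for token_embedder.embeds
vocab_logits = torch.mm(query, word_embeds.t())   # (batch_size, vocab_size)
vocab_probs = F.softmax(vocab_logits, dim=1)      # each row sums to 1
assert vocab_probs.size() == (batch_size, vocab_size)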
def forward(self, rnn_state, decoder_cell_input, advance):
    dci = decoder_cell_input
    mask = advance

    # x_augment will be concatenated to x at every layer: we are conditioning
    # on the attention from the previous time step and the agenda from the
    # encoder
    x_augment = torch.cat([
        rnn_state.source_attn.context,
        rnn_state.insert_attn.context,
        rnn_state.delete_attn.context,
        dci.agenda,
    ], 1)

    hs, cs = [], []
    x = dci.x  # input word vector
    for layer in range(self.num_layers):
        rnn_cell = self.rnn_cells[layer]
        old_h, old_c = rnn_state.hs[layer], rnn_state.cs[layer]
        rnn_input = torch.cat([x, x_augment], 1)
        h, c = rnn_cell(rnn_input, (old_h, old_c))
        h = gated_update(old_h, h, mask)
        c = gated_update(old_c, c, mask)
        hs.append(h)
        cs.append(c)

        if layer == 0:
            x = h  # no skip connection on the first layer
        else:
            x = x + h

    # compute attention using the bottom layer
    source_attn = self.source_attention(dci.source_embeds, hs[0])
    insert_attn = self.insert_attention(dci.insert_embeds, hs[0])
    delete_attn = self.delete_attention(dci.delete_embeds, hs[0])

    if not self.no_insert_delete_attn:
        # z has shape (batch_size, decoder_dim + encoder_dim + input_dim + input_dim)
        z = torch.cat([
            x, source_attn.context, insert_attn.context, delete_attn.context
        ], 1)
    else:
        z = torch.cat([x, source_attn.context], 1)

    vocab_query_pos = self.vocab_projection_pos(z)
    vocab_query_neg = self.vocab_projection_neg(z)
    word_vocab = self.token_embedder.vocab
    word_embeds = self.token_embedder.embeds
    vocab_logit_pos = self.relu(torch.mm(vocab_query_pos, word_embeds.t()))  # (batch_size, vocab_size)
    vocab_logit_neg = self.relu(torch.mm(vocab_query_neg, word_embeds.t()))  # (batch_size, vocab_size)
    vocab_probs = self.vocab_softmax(vocab_logit_pos - vocab_logit_neg)
    # TODO(kelvin): prevent model from putting probability on UNK

    rnn_state = AttentionRNNState(hs, cs, source_attn, insert_attn, delete_attn)
    return DecoderCellOutput(rnn_state, vocab=word_vocab, vocab_probs=vocab_probs)
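# The attention modules above (source_attention, insert_attention,
# delete_attention) are not shown in this snippet; the decoder only assumes
# they return an object exposing a .context attribute. A minimal dot-product
# sketch consistent with that interface and with the memory assumed to have
# shape (batch_size, seq_len, dim); the real modules may differ.
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F

AttentionOutput = namedtuple('AttentionOutput', ['weights', 'context'])

class DotProductAttention(nn.Module):
    def forward(self, memory, query):
        # memory: (batch_size, seq_len, dim), query: (batch_size, dim)
        scores = torch.bmm(memory, query.unsqueeze(2)).squeeze(2)  # (batch_size, seq_len)
        weights = F.softmax(scores, dim=1)
        # weighted sum of memory rows -> (batch_size, dim)
        context = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)
        return AttentionOutput(weights=weights, context=context)

# usage mirroring the decoder: source_attn = DotProductAttention()(source_embeds, hs[0])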
def forward(self, rnn_state, rnn_input, advance):
    rnn_input_embed = torch.cat([rnn_input.x, rnn_input.agenda], 1)
    h, c = self.rnn_cell(rnn_input_embed, (rnn_state.h, rnn_state.c))

    # don't update the state if the sequence has terminated
    h = gated_update(rnn_state.h, h, advance)
    c = gated_update(rnn_state.c, c, advance)

    query = self.linear(h)
    word_vocab = self.token_embedder.vocab
    word_embeds = self.token_embedder.embeds
    vocab_logits = torch.mm(query, word_embeds.t())  # (batch_size, vocab_size)
    vocab_probs = self.softmax(vocab_logits)

    # no attention over source, insert and delete embeds
    rnn_state = SimpleRNNState(h, c)
    return DecoderCellOutput(rnn_state, vocab=word_vocab, vocab_probs=vocab_probs)
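# The state containers used by these decoders behave as plain immutable
# records (positional construction, attribute access). A guess at minimal
# definitions; the originals may be richer classes. AttentionRNNState is
# omitted because the two attention decoders in this section construct it
# with different fields, suggesting two revisions of that class.
from collections import namedtuple

SimpleRNNState = namedtuple('SimpleRNNState', ['h', 'c'])
MultilayeredRNNState = namedtuple('MultilayeredRNNState', ['hs', 'cs'])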
def test_gated_update():
    h = GPUVariable(torch.FloatTensor([
        [1, 2, 3],
        [4, 5, 6],
    ]))
    h_new = GPUVariable(torch.FloatTensor([
        [-1, 2, 3],
        [4, 8, 0],
    ]))
    update = GPUVariable(torch.FloatTensor([[0], [1]]))  # only update the second row

    out = gated_update(h, h_new, update)
    assert_tensor_equal(out, [[1, 2, 3], [4, 8, 0]])
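# A CPU-only variant of the test above, with the GPUVariable and
# assert_tensor_equal helpers dropped so it runs on plain tensors. The gated
# update is inlined as update * new + (1 - update) * old, matching the sketch
# given after the first encoder.
import torch

def test_gated_update_cpu():
    h = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
    h_new = torch.FloatTensor([[-1, 2, 3], [4, 8, 0]])
    update = torch.FloatTensor([[0], [1]])  # only update the second row
    out = update * h_new + (1 - update) * h  # inlined gated_update
    assert torch.equal(out, torch.FloatTensor([[1, 2, 3], [4, 8, 0]]))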
def forward(self, rnn_state, decoder_cell_input, advance):
    dci = decoder_cell_input
    mask = advance

    # x_augment will be concatenated to x at every layer: we are conditioning
    # on the attention from the previous time step and the agenda from the
    # encoder
    attn_contexts = torch.cat([attn.context for attn in rnn_state.input_attns], 1)
    if self.disable_attention:
        x_augment = dci.agenda
    else:
        x_augment = torch.cat([attn_contexts, dci.agenda], 1)

    hs, cs = [], []
    x = dci.x  # input word vector
    for layer in range(self.num_layers):
        rnn_cell = self.rnn_cells[layer]
        old_h, old_c = rnn_state.hs[layer], rnn_state.cs[layer]
        rnn_input = torch.cat([x, x_augment], 1)
        h, c = rnn_cell(rnn_input, (old_h, old_c))
        h = gated_update(old_h, h, mask)
        c = gated_update(old_c, c, mask)
        hs.append(h)
        cs.append(c)

        if layer == 0:
            x = h  # no skip connection on the first layer
        else:
            x = x + h

        # Recurrent Neural Network Regularization
        # https://arxiv.org/pdf/1409.2329.pdf
        # note that dropout doesn't touch the recurrent connections,
        # only connections going up the layers
        x = self.dropout(x)

    # compute attention using the bottom layer
    input_attns = [attn(dci.input_embeds[i], hs[0])
                   for i, attn in enumerate(self.input_attentions)]
    attn_contexts = torch.cat([attn.context for attn in input_attns], 1)

    if self.disable_attention:
        z = x
    else:
        # z has shape (batch_size, decoder_dim + encoder_dim + input_dim + input_dim)
        z = torch.cat([x, attn_contexts], 1)

    vocab_query_pos = self.vocab_projection_pos(z)
    vocab_query_neg = self.vocab_projection_neg(z)
    word_embeds = self.target_token_embedder.embeds
    vocab_logit_pos = self.relu(torch.mm(vocab_query_pos, word_embeds.t()))  # (batch_size, vocab_size)
    vocab_logit_neg = self.relu(torch.mm(vocab_query_neg, word_embeds.t()))  # (batch_size, vocab_size)
    vocab_probs = self.vocab_softmax(vocab_logit_pos - vocab_logit_neg)
    # TODO(kelvin): prevent model from putting probability on UNK

    rnn_state = AttentionRNNState(hs, cs, input_attns)

    # dci.token_embedder is a DynamicMultiVocabTokenEmbedder
    # NOTE: this is the same token embedder used by the SOURCE encoder
    dynamic_token_embedder = dci.token_embedder
    base_vocab = dynamic_token_embedder.base_vocab
    dynamic_vocabs = dynamic_token_embedder.dynamic_vocabs

    return AttentionDecoderCellOutput(rnn_state, base_vocab=base_vocab,
                                      dynamic_vocabs=dynamic_vocabs,
                                      vocab_probs=vocab_probs)
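# Standalone sketch of the two-headed vocab scoring used by both attention
# decoders above: a "positive" and a "negative" query are each matched against
# the embedding matrix through a ReLU, and the difference of the two logit
# tensors is softmaxed. All sizes below are made up for illustration.
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, z_dim, embed_dim, vocab_size = 4, 16, 8, 100
vocab_projection_pos = nn.Linear(z_dim, embed_dim)
vocab_projection_neg = nn.Linear(z_dim, embed_dim)
word_embeds = torch.randn(vocab_size, embed_dim)  # stands in for the token embedder

z = torch.randn(batch_size, z_dim)  # decoder output + attention contexts
logit_pos = F.relu(torch.mm(vocab_projection_pos(z), word_embeds.t()))  # (batch_size, vocab_size)
logit_neg = F.relu(torch.mm(vocab_projection_neg(z), word_embeds.t()))  # (batch_size, vocab_size)
vocab_probs = F.softmax(logit_pos - logit_neg, dim=1)  # (batch_size, vocab_size)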