    def forward(self, input_embeds_list):
        """

        Args:
            input_embeds_list (list[SequenceBatchElement]): where each element is of shape (batch_size, input_dim)

        Returns:
            hidden_states_list (list[SequenceBatchElement]) where each element is (batch_size, hidden_dim)
        """
        batch_size = input_embeds_list[0].values.size(0)

        h = tile_state(self.h0, batch_size)  # (batch_size, hidden_dim)
        c = tile_state(self.c0, batch_size)  # (batch_size, hidden_dim)

        hidden_states_list = []

        for t, x in enumerate(input_embeds_list):
            # x.values has shape (batch_size, input_dim)
            # x.mask has shape (batch_size, 1)
            h_new, c_new = self.rnn_cell(x.values, (h, c))
            h = gated_update(h, h_new, x.mask)
            c = gated_update(c, c_new, x.mask)
            hidden_states_list.append(
                SequenceBatchElement(self.dropout(h), x.mask))

        return hidden_states_list
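
The encoder above relies on a few helpers that are not defined in this excerpt. A minimal sketch of what they could look like, assuming SequenceBatchElement is just a (values, mask) pair and h0/c0 are learned parameters of shape (1, hidden_dim); the names and shapes here are assumptions, not taken from the original code:

from collections import namedtuple

# Hypothetical container matching the x.values / x.mask access pattern above.
SequenceBatchElement = namedtuple('SequenceBatchElement', ['values', 'mask'])


def tile_state(state, batch_size):
    """Tile a (1, hidden_dim) initial state to (batch_size, hidden_dim)."""
    return state.expand(batch_size, state.size(1))
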
Example #2
    def forward(self, rnn_state, rnn_input, advance):
        x = torch.cat([rnn_input.x, rnn_input.agenda], 1)
        hs, cs = [], []
        for layer in range(self.num_layers):
            rnn_cell = self.rnn_cells[layer]

            # collect the h, c belonging to the previous time-step at the corresponding depth
            h_prev_t, c_prev_t = rnn_state.hs[layer], rnn_state.cs[layer]

            # forward pass and masking
            h, c = rnn_cell(x, (h_prev_t, c_prev_t))
            h = gated_update(h_prev_t, h, advance)
            c = gated_update(c_prev_t, c, advance)
            hs.append(h)
            cs.append(c)

            if layer == 0:
                x = h  # no skip connection on the first layer
            else:
                x = x + h

        query = self.linear(x)
        word_vocab = self.token_embedder.vocab
        word_embeds = self.token_embedder.embeds
        vocab_logits = torch.mm(query,
                                word_embeds.t())  # (batch_size, vocab_size)
        vocab_probs = self.softmax(vocab_logits)

        rnn_state = MultilayeredRNNState(hs, cs)

        return DecoderCellOutput(rnn_state,
                                 vocab=word_vocab,
                                 vocab_probs=vocab_probs)
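
MultilayeredRNNState is used above but not defined in this excerpt. Judging from the rnn_state.hs[layer] / rnn_state.cs[layer] accesses, a simple per-layer container would be consistent; this is a sketch, not the original definition:

from collections import namedtuple

# One hidden and one cell state per layer, each of shape (batch_size, hidden_dim).
MultilayeredRNNState = namedtuple('MultilayeredRNNState', ['hs', 'cs'])
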
Example #3
    def forward(self, rnn_state, decoder_cell_input, advance):
        dci = decoder_cell_input
        mask = advance

        # this will be concatenated to x at every layer
        # we are conditioning on the attention from the previous time step and the agenda from the encoder
        x_augment = torch.cat([
            rnn_state.source_attn.context, rnn_state.insert_attn.context,
            rnn_state.delete_attn.context, dci.agenda
        ], 1)

        hs, cs = [], []
        x = dci.x  # input word vector
        for layer in range(self.num_layers):
            rnn_cell = self.rnn_cells[layer]
            old_h, old_c = rnn_state.hs[layer], rnn_state.cs[layer]
            rnn_input = torch.cat([x, x_augment], 1)
            h, c = rnn_cell(rnn_input, (old_h, old_c))
            h = gated_update(old_h, h, mask)
            c = gated_update(old_c, c, mask)
            hs.append(h)
            cs.append(c)

            if layer == 0:
                x = h  # no skip connection on the first layer
            else:
                x = x + h

        # compute attention using bottom layer
        source_attn = self.source_attention(dci.source_embeds, hs[0])
        insert_attn = self.insert_attention(dci.insert_embeds, hs[0])
        delete_attn = self.delete_attention(dci.delete_embeds, hs[0])
        if not self.no_insert_delete_attn:
            z = torch.cat([
                x, source_attn.context, insert_attn.context,
                delete_attn.context
            ], 1)
        else:
            z = torch.cat([x, source_attn.context], 1)

        # z has shape (batch_size, decoder_dim + encoder_dim + input_dim + input_dim),
        # or (batch_size, decoder_dim + encoder_dim) when insert/delete attention is disabled

        vocab_query_pos = self.vocab_projection_pos(z)
        vocab_query_neg = self.vocab_projection_neg(z)
        word_vocab = self.token_embedder.vocab
        word_embeds = self.token_embedder.embeds
        vocab_logit_pos = self.relu(torch.mm(
            vocab_query_pos, word_embeds.t()))  # (batch_size, vocab_size)
        vocab_logit_neg = self.relu(torch.mm(
            vocab_query_neg, word_embeds.t()))  # (batch_size, vocab_size)
        vocab_probs = self.vocab_softmax(vocab_logit_pos - vocab_logit_neg)
        # TODO(kelvin): prevent model from putting probability on UNK

        rnn_state = AttentionRNNState(hs, cs, source_attn, insert_attn,
                                      delete_attn)

        return DecoderCellOutput(rnn_state,
                                 vocab=word_vocab,
                                 vocab_probs=vocab_probs)
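
AttentionRNNState is likewise only constructed, never defined, in this excerpt. A plausible sketch, assuming it simply bundles the per-layer states with the three attention outputs so the next time step can condition on them:

from collections import namedtuple

# Per-layer hidden/cell states plus this step's attention results.
AttentionRNNState = namedtuple(
    'AttentionRNNState',
    ['hs', 'cs', 'source_attn', 'insert_attn', 'delete_attn'])
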
    def forward(self, rnn_state, rnn_input, advance):
        rnn_input_embed = torch.cat([rnn_input.x, rnn_input.agenda], 1)
        h, c = self.rnn_cell(rnn_input_embed, (rnn_state.h, rnn_state.c))

        # don't update if sequence has terminated
        h = gated_update(rnn_state.h, h, advance)
        c = gated_update(rnn_state.c, c, advance)

        query = self.linear(h)
        word_vocab = self.token_embedder.vocab
        word_embeds = self.token_embedder.embeds
        vocab_logits = torch.mm(query, word_embeds.t())  # (batch_size, vocab_size)
        vocab_probs = self.softmax(vocab_logits)

        # no attention over source, insert and delete embeds
        rnn_state = SimpleRNNState(h, c)

        return DecoderCellOutput(rnn_state, vocab=word_vocab, vocab_probs=vocab_probs)
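
A rough sketch of how a decoder cell like this might be unrolled step by step; the initial_state helper and the rnn_state attribute on DecoderCellOutput are assumptions for illustration, not part of the code above:

# advance_masks[t] is a (batch_size, 1) float tensor that is 1 while a sequence
# is still being generated and 0 after it terminates, so gated_update freezes
# h and c for finished sequences.
rnn_state = decoder_cell.initial_state(batch_size)  # assumed helper
for rnn_input_t, advance_t in zip(rnn_inputs, advance_masks):
    out = decoder_cell(rnn_state, rnn_input_t, advance_t)
    rnn_state = out.rnn_state  # assumed attribute name
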
def test_gated_update():
    h = GPUVariable(torch.FloatTensor([
        [1, 2, 3],
        [4, 5, 6],
    ]))
    h_new = GPUVariable(torch.FloatTensor([
        [-1, 2, 3],
        [4, 8, 0],
    ]))
    update = GPUVariable(torch.FloatTensor([[0], [1]]))  # only update the second row

    out = gated_update(h, h_new, update)

    assert_tensor_equal(out, [[1, 2, 3], [4, 8, 0]])
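
The test pins down the semantics of gated_update: rows whose update mask is 1 take the new state, rows whose mask is 0 keep the old one. A minimal implementation consistent with the test (the real one may differ in details):

def gated_update(h, h_new, update):
    # update has shape (batch_size, 1) and broadcasts across the hidden dim;
    # update == 1 selects h_new, update == 0 keeps h.
    return update * h_new + (1 - update) * h
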
    def forward(self, rnn_state, decoder_cell_input, advance):
        dci = decoder_cell_input
        mask = advance

        # this will be concatenated to x at every layer
        # we are conditioning on the attention from the previous time step and the agenda from the encoder
        attn_contexts = torch.cat(
            [attn.context for attn in rnn_state.input_attns], 1)
        if self.disable_attention:
            x_augment = dci.agenda
        else:
            x_augment = torch.cat([attn_contexts, dci.agenda], 1)

        hs, cs = [], []
        x = dci.x  # input word vector
        for layer in range(self.num_layers):
            rnn_cell = self.rnn_cells[layer]
            old_h, old_c = rnn_state.hs[layer], rnn_state.cs[layer]
            rnn_input = torch.cat([x, x_augment], 1)
            h, c = rnn_cell(rnn_input, (old_h, old_c))
            h = gated_update(old_h, h, mask)
            c = gated_update(old_c, c, mask)
            hs.append(h)
            cs.append(c)

            if layer == 0:
                x = h  # no skip connection on the first layer
            else:
                x = x + h

            # Recurrent Neural Network Regularization
            # https://arxiv.org/pdf/1409.2329.pdf
            x = self.dropout(x)
            # note that dropout doesn't touch the recurrent connections,
            # only the connections going up between the layers

        # compute attention using bottom layer
        input_attns = [
            attn(dci.input_embeds[i], hs[0])
            for i, attn in enumerate(self.input_attentions)
        ]
        attn_contexts = torch.cat([attn.context for attn in input_attns], 1)
        if self.disable_attention:
            z = x
        else:
            z = torch.cat([x, attn_contexts], 1)

        # z has shape (batch_size, decoder_dim + summed attention context dims),
        # or just (batch_size, decoder_dim) when attention is disabled
        vocab_query_pos = self.vocab_projection_pos(z)
        vocab_query_neg = self.vocab_projection_neg(z)
        word_embeds = self.target_token_embedder.embeds
        vocab_logit_pos = self.relu(torch.mm(
            vocab_query_pos, word_embeds.t()))  # (batch_size, vocab_size)
        vocab_logit_neg = self.relu(torch.mm(
            vocab_query_neg, word_embeds.t()))  # (batch_size, vocab_size)
        vocab_probs = self.vocab_softmax(vocab_logit_pos - vocab_logit_neg)
        # TODO(kelvin): prevent model from putting probability on UNK

        rnn_state = AttentionRNNState(hs, cs, input_attns)

        # DynamicMultiVocabTokenEmbedder
        # NOTE: this is the same token embedder used by the SOURCE encoder
        dynamic_token_embedder = dci.token_embedder
        base_vocab = dynamic_token_embedder.base_vocab
        dynamic_vocabs = dynamic_token_embedder.dynamic_vocabs

        return AttentionDecoderCellOutput(rnn_state,
                                          base_vocab=base_vocab,
                                          dynamic_vocabs=dynamic_vocabs,
                                          vocab_probs=vocab_probs)
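
Each attention module above is called with a set of input embeddings and the bottom-layer hidden state, and the result exposes a .context field that gets concatenated into the decoder input. A hypothetical return type consistent with that usage, purely for illustration:

from collections import namedtuple

# weights: (batch_size, seq_len) attention distribution over the inputs;
# context: (batch_size, attn_dim) weighted sum of the input embeddings.
AttentionOutput = namedtuple('AttentionOutput', ['weights', 'context'])
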