Example #1
    def score(self, h_t, h_s):
        """
        Args:
            h_t (FloatTensor): sequence of queries [batch x tgt_len x h_t_dim]
            h_s (FloatTensor): sequence of sources [batch x src_len x h_s_dim]
        Returns:
            raw attention scores for each src index [batch x tgt_len x src_len]
        """

        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        utils.aeq(src_batch, tgt_batch)
        #utils.aeq(src_dim, tgt_dim)

        if self.attn_type == "bilinear":
            h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
            h_t_ = self.linear_in(h_t_)
            h_t = h_t_.view(tgt_batch, tgt_len, src_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, s_dim) x (batch, s_dim, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = torch.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
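The two branches above compute the same [batch x tgt_len x src_len] score tensor in different ways. Below is a minimal standalone sketch of both scoring styles with dummy tensors; the layer names (linear_in, linear_query, linear_context, v) mirror the attributes used above but are created locally here, and broadcasting replaces the explicit expand calls.

import torch
import torch.nn as nn

batch, tgt_len, src_len, dim = 2, 3, 5, 8
h_t = torch.randn(batch, tgt_len, dim)   # decoder queries
h_s = torch.randn(batch, src_len, dim)   # encoder states

# Bilinear scoring: project the queries, then a batched dot product.
linear_in = nn.Linear(dim, dim, bias=False)
scores_bilinear = torch.bmm(linear_in(h_t), h_s.transpose(1, 2))
print(scores_bilinear.shape)             # torch.Size([2, 3, 5])

# MLP (additive) scoring: broadcast query/context projections, tanh, project to a scalar.
linear_query = nn.Linear(dim, dim)
linear_context = nn.Linear(dim, dim, bias=False)
v = nn.Linear(dim, 1, bias=False)
wq = linear_query(h_t).unsqueeze(2)      # (batch, tgt_len, 1, dim)
uh = linear_context(h_s).unsqueeze(1)    # (batch, 1, src_len, dim)
scores_mlp = v(torch.tanh(wq + uh)).squeeze(-1)
print(scores_mlp.shape)                  # torch.Size([2, 3, 5])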
Example #2
    def _run_forward_pass(self, input, context, state, mask):
        """
        Only used for beam search.
        Only compatible with runs that use attention. TODO: implement a version without attention.
        Private helper for running the specific RNN forward pass.
        Must be overridden by all subclasses.
        Args:
            input (LongTensor): a sequence of input tokens tensors
                                of size (len x batch x nfeats).
            context (FloatTensor): output(tensor sequence) from the encoder
                        RNN of size (src_len x batch x hidden_size).
            state (FloatTensor): hidden state from the encoder RNN for
                                 initializing the decoder.
            mask (Tensor): mask over the source positions, forwarded to
                           the attention layer.
        Returns:
            hidden (Variable): final hidden state from the decoder.
            outputs ([FloatTensor]): an array of output of every time
                                     step from the decoder.
            attns (dict of (str, [FloatTensor])): a dictionary of the different
                            types of attention Tensors, one entry per time
                            step from the decoder.
            coverage (FloatTensor, optional): coverage from the decoder.
        """
        # Initialize local and return variables.
        attns = {"std": []}
        coverage = None
        emb = self.embedding(input)

        # Project the encoder final state if its size differs from the embedding size.
        if emb.size(2) != state.hidden[0].size(2):
            state.hidden = [self.project_encoder(state.hidden[0])]

        # Run the forward pass of the RNN.
        if isinstance(self.rnn, StackedGRU):
            rnn_output, hidden = self.rnn(emb, state.hidden[0])
        else:
            rnn_output, hidden = self.rnn(emb, state.hidden)
        # Result Check
        input_len, input_batch, _ = input.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(input_len, output_len)
        aeq(input_batch, output_batch)
        # END Result Check

        # Calculate the attention.
        attn_outputs, attn_scores = self.attention(
            rnn_output.contiguous(),  # (output_len, batch, d)
            context.transpose(0, 1).contiguous(),  # (batch, context_len, d) after transpose
            mask
        )
        attns["std"] = attn_scores
        outputs = attn_outputs  # (input_len, batch, d)

        # Return result.
        return hidden, outputs, attns, coverage
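The aeq shape checks above come from a utility that is not shown in these snippets. A minimal re-creation, assuming aeq simply asserts that all of its arguments are equal:

def aeq(*args):
    """Assert that all arguments are equal; used here as a cheap shape check."""
    first = args[0]
    assert all(arg == first for arg in args[1:]), \
        "values do not match: %s" % str(args)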
Example #3
    def forward(self, inp):
        """
        Return the embeddings for words, and features if there are any.
        Args:
            inp (LongTensor): batch x len x nfeat
        Return:
            emb (Tensor): batch x len x self.embedding_size
        """
        if inp.dim() == 2:
            # batch x len
            emb = self.word_lookup_table(inp)
            return emb

        in_batch, in_length, nfeat = inp.size()
        # The number of features must equal the number of embedding lookup tables.
        aeq(nfeat, len(self.emb_luts))

        if len(self.emb_luts) == 1:
            emb = self.word_lookup_table(inp.squeeze(2))
        else:
            feat_inputs = (feat.squeeze(2) for feat in inp.split(1, dim=2))
            features = [
                lut(feat) for lut, feat in zip(self.emb_luts, feat_inputs)
            ]
            emb = self.merge(features)

        out_batch, out_length, emb_size = emb.size()
        aeq(in_batch, out_batch)
        aeq(in_length, out_length)
        aeq(emb_size, self.embedding_size())

        return emb
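For the multi-feature path, each feature column gets its own lookup table and the results are merged. A minimal sketch with dummy vocab sizes, assuming merge concatenates along the embedding dimension (the actual merge in this codebase may differ):

import torch
import torch.nn as nn

batch, length = 2, 4
vocab_sizes = [100, 10, 10]              # word vocab plus two feature vocabs (assumed)
emb_dims = [32, 8, 8]

emb_luts = nn.ModuleList(nn.Embedding(v, d) for v, d in zip(vocab_sizes, emb_dims))

# inp: batch x len x nfeat, one column of ids per feature
inp = torch.stack([torch.randint(0, v, (batch, length)) for v in vocab_sizes], dim=2)

feat_inputs = (feat.squeeze(2) for feat in inp.split(1, dim=2))
features = [lut(feat) for lut, feat in zip(emb_luts, feat_inputs)]
emb = torch.cat(features, dim=-1)        # assumed merge: concatenation
print(emb.shape)                         # torch.Size([2, 4, 48])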
Example #4
    def _run_forward_pass(self, ph_sel, phrase_bank, phrase_lengths=None):
        """
        Args:
            ph_sel (LongTensor): token ids for the selected phrases of each
                                 RR sentence
                                 [batch x max_sent_num x max_ph_per_sent x max_ph_len]
            phrase_bank (FloatTensor): embeddings for phrase collections
                            [batch x len x nfeats]
            phrase_lengths (LongTensor): the lengths of phrase collections
        Returns:
            dec_state (Tensor): final hidden state from the decoder.
            dec_outs ([FloatTensor]): an array of output of every time step
                                      from the decoder.
            ph_attn ([FloatTensor]): phrase attention Tensor array of every
                                     time step from the decoder.
            ph_attn_raw ([FloatTensor]): raw (unnormalized) phrase attention
                                         scores for every time step.
            readouts (FloatTensor): readout projections of the decoder outputs
                                    [batch_size x max_sent_num x ph_vocab_size].
        """
        ph_sel_emb = self.embedding(ph_sel)
        ph_sel_emb = torch.sum(ph_sel_emb, -2)  # sum over all tokens in each phrase
        ph_sum_emb = torch.sum(ph_sel_emb, -2)  # sum over all phrases in each sentence
        rnn_output, dec_state = self.LSTM(ph_sum_emb, self.state["hidden"])
        self.rnn_output = rnn_output

        ph_batch, ph_len, _ = ph_sum_emb.size()

        output_batch, output_len, _ = rnn_output.size()

        utils.aeq(ph_len, output_len)
        utils.aeq(ph_batch, output_batch)

        dec_outs, ph_attn, ph_attn_raw = self.ph_attn(
            rnn_output.contiguous(),
            phrase_bank.contiguous(),
            memory_lengths=phrase_lengths,
            use_softmax=False)

        readouts = self.readout(dec_outs)

        # dec_outs: [batch_size x max_sent_num x ph_emb_dim]
        # readouts: [batch_size x max_sent_num x ph_vocab_size]

        return dec_state, dec_outs, ph_attn, ph_attn_raw, readouts
Example #5
    def original_forward(self, input, context, state, mask):
        """
        Only used for beam search. Forward through the decoder.
        Args:
            input (LongTensor): a sequence of input tokens tensors
                                of size (len x batch x nfeats).
            context (FloatTensor): output(tensor sequence) from the encoder
                        RNN of size (src_len x batch x hidden_size).
            state (FloatTensor): hidden state from the encoder RNN for
                                 initializing the decoder.
            mask (Tensor): mask over the source positions, forwarded to
                           the attention layer.
        Returns:
            outputs (FloatTensor): a Tensor sequence of output from the decoder
                                   of shape (len x batch x hidden_size).
            state (FloatTensor): final hidden state from the decoder.
            attns (dict of (str, FloatTensor)): a dictionary of the different
                                types of attention Tensors from the decoder,
                                each of shape (src_len x batch).
        """
        # Args Check
        assert isinstance(state, RNNDecoderState)
        input_len, input_batch, _ = input.size()
        contxt_len, contxt_batch, _ = context.size()
        aeq(input_batch, contxt_batch)
        # END Args Check

        # Run the forward pass of the RNN.
        hidden, outputs, attns, coverage = self._run_forward_pass(input, context, state, mask)

        # Update the state with the result.
        final_output = outputs[-1]
        state.update_state(
            hidden,
            final_output.unsqueeze(0),
            coverage.unsqueeze(0) if coverage is not None else None)

        # Concatenates sequence of tensors along a new dimension.
        outputs = torch.stack(outputs)
        for k in attns:
            attns[k] = torch.stack(attns[k])

        return outputs, state, attns
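The final torch.stack call turns the per-step Python list into a single tensor with a new leading time dimension; a tiny sketch with dummy sizes:

import torch

steps = [torch.randn(4, 16) for _ in range(7)]   # 7 steps of (batch=4, hidden=16)
stacked = torch.stack(steps)
print(stacked.shape)                             # torch.Size([7, 4, 16])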
Example #6
    def _run_forward_pass(self, ph_sel, phrase_bank, phrase_lengths=None):
        """
        Args:
            ph_sel (batch_size x max_sent_num x max_ph_per_sent x max_ph_len):
                token ids for selected phrases
            phrase_bank (batch_size x max_ph_bank_size x dim): embeddings for
                phrases in each phrase bank for each sample
            phrase_lengths (batch_size): size of phrase bank for each sample

        Returns:
            dec_state (Tuple of C and H): final hidden state from the decoder.
            dec_outs (batch_size x max_sent_num x dim): an array of output of every time step
                                      from the decoder.
            ph_attn_probs (batch_size x max_sent_num x max_ph_bank_size):
                normalized phrase attention for every time step from the decoder.
            ph_attn_logits (batch_size x max_sent_num x max_ph_bank_size):
                raw phrase attention scores for every time step.
            stype_logits: readout projections of the decoder outputs.
        """
        ph_sel_emb = self.embedding(ph_sel)
        ph_sel_emb = torch.sum(ph_sel_emb, -2) # sum over all tokens in each phrase
        ph_sum_emb = torch.sum(ph_sel_emb, -2) # sum over all phrases in each sentence
        rnn_output, dec_state = self.LSTM(ph_sum_emb, self.state["hidden"])
        self.rnn_output = rnn_output

        batch_size, max_sent_num, _ = ph_sum_emb.size()

        output_batch, output_len, _ = rnn_output.size()

        utils.aeq(max_sent_num, output_len)
        utils.aeq(batch_size, output_batch)

        dec_outs, ph_attn_probs, ph_attn_logits = self.ph_attn(
            rnn_output.contiguous(),
            phrase_bank.contiguous(),
            memory_lengths=phrase_lengths,
            use_softmax=False
        )

        stype_logits = self.readout(dec_outs)
        return dec_state, dec_outs, ph_attn_probs, ph_attn_logits, stype_logits
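The two torch.sum calls collapse the token and phrase dimensions in turn, producing one vector per sentence. A minimal shape walkthrough with dummy sizes:

import torch

batch, max_sent_num, max_ph_per_sent, max_ph_len, emb_dim = 2, 3, 4, 5, 16

# Token embeddings for the selected phrases, i.e. the shape of self.embedding(ph_sel).
ph_sel_emb = torch.randn(batch, max_sent_num, max_ph_per_sent, max_ph_len, emb_dim)

per_phrase = ph_sel_emb.sum(-2)   # sum over tokens   -> (batch, sent, phrase, dim)
ph_sum_emb = per_phrase.sum(-2)   # sum over phrases  -> (batch, sent, dim)
print(ph_sum_emb.shape)           # torch.Size([2, 3, 16])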
Example #7
    def forward(self,
                input,
                context,
                context_lengths=None,
                context_max_len=None):
        """
        input (FloatTensor): batch x tgt_len x dim: output of the decoder RNN.
        context (FloatTensor): batch x src_len x dim: source hidden states.
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = context.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)

        # compute attention scores, as in Luong et al.
        align = self.score(input, context)

        if context_lengths is not None:
            mask = self.sequence_mask(context_lengths, context_max_len)
            mask = mask.unsqueeze(1)
            # (bz, max_len) -> (bz, 1, max_len), so mask can broadcast
            # Padded positions get -inf so the softmax assigns them zero weight.
            align.masked_fill_(~mask.bool(), -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = torch.softmax(align, dim=-1)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, context)
        # concatenate
        concat_c = torch.cat([c, input], -1)
        # linear_out
        if self.linear_out is None:
            attn_h = concat_c
        else:
            attn_h = self.linear_out(concat_c)
            if self.attn_type in ["general", "dot"]:
                attn_h = self.tanh(attn_h)
        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)
        # (batch, targetL, dim_), (batch, targetL, sourceL)
        return attn_h, align_vectors
    def score(self, h_t, h_s):
        """
        h_t (FloatTensor): batch x tgt_len x dim
        h_s (FloatTensor): batch x src_len x dim
        returns scores (FloatTensor): batch x tgt_len x src_len:
            raw attention scores for each src index
        """

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_hidden > 0:
                h_t = self.transform_in(h_t)
                h_s = self.transform_in(h_s)
            if self.attn_type == "general":
                h_t = self.linear_in(h_t)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = self.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
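Putting forward and score together, the layer scores the queries against the source states, masks padded positions, softmaxes, takes the weighted average, and mixes it back with the query. A standalone sketch of the dot-product variant under those assumptions, with a locally created linear_out:

import torch
import torch.nn as nn

batch, tgt_len, src_len, dim = 2, 3, 5, 8
query = torch.randn(batch, tgt_len, dim)       # decoder RNN output
context = torch.randn(batch, src_len, dim)     # encoder hidden states
context_lengths = torch.tensor([5, 3])

align = torch.bmm(query, context.transpose(1, 2))           # raw scores (batch, tgt, src)

# Mask padded source positions before the softmax.
mask = torch.arange(src_len).unsqueeze(0) < context_lengths.unsqueeze(1)
align = align.masked_fill(~mask.unsqueeze(1), float('-inf'))

align_vectors = torch.softmax(align, dim=-1)                # (batch, tgt, src)
c = torch.bmm(align_vectors, context)                       # weighted average (batch, tgt, dim)

linear_out = nn.Linear(2 * dim, dim)
attn_h = torch.tanh(linear_out(torch.cat([c, query], -1)))  # (batch, tgt, dim)
print(attn_h.shape, align_vectors.shape)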
Example #9
    def forward(self,
                query,
                memory_bank,
                memory_lengths=None,
                use_softmax=True):
        """
        Args:
            query (FloatTensor): query vectors [batch x tgt_len x dim]
            memory_bank (FloatTensor): source vectors [batch x src_len x dim]
            memory_lengths (LongTensor): source context lengths [batch]
            use_softmax (bool): if True, normalize alignment scores with a
                softmax; otherwise apply an element-wise sigmoid to each score
        Returns:
            (FloatTensor, FloatTensor, FloatTensor)
            computed attention weighted average: [batch x tgt_len x dim]
            attention distribution: [batch x tgt_len x src_len]
            raw attention scores before normalization: [batch x tgt_len x src_len]
        """

        if query.dim() == 2:
            one_step = True
            query = query.unsqueeze(1)
        else:
            one_step = False

        src_batch, src_len, src_dim = memory_bank.size()
        query_batch, query_len, query_dim = query.size()
        utils.aeq(src_batch, query_batch)
        #utils.aeq(src_dim, query_dim)

        align = self.score(query, memory_bank)

        if memory_lengths is not None:
            mask = utils.sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)
            align.masked_fill_(~mask.bool(), -float('inf'))
        if use_softmax:
            align_vectors = F.softmax(
                align.view(src_batch * query_len, src_len), -1)
            align_vectors = align_vectors.view(src_batch, query_len, src_len)
        else:
            align_vectors = torch.sigmoid(align)

        c = torch.bmm(align_vectors, memory_bank)
        # c is the attention weighted context representation
        # [batch x tgt_len x hidden_size]

        concat_c = torch.cat([c, query], 2).view(src_batch * query_len,
                                                 src_dim + query_dim)
        attn_h = self.linear_out(concat_c).view(src_batch, query_len,
                                                query_dim)
        if self.attn_type == "bilinear":
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            batch_, dim_ = attn_h.size()
            utils.aeq(src_batch, batch_)
            utils.aeq(src_dim, dim_)
            batch_, src_l_ = align_vectors.size()
            utils.aeq(src_batch, batch_)
            utils.aeq(src_len, src_l_)

        else:

            batch_, target_l_, dim_ = attn_h.size()
            utils.aeq(target_l_, query_len)
            utils.aeq(batch_, query_batch)
            utils.aeq(dim_, query_dim)

            batch_, target_l_, source_l_ = align_vectors.size()
            utils.aeq(target_l_, query_len)
            utils.aeq(batch_, query_batch)
            utils.aeq(source_l_, src_len)

        return attn_h, align_vectors, align
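The use_softmax flag switches between a distribution over source positions and an independent score per position. A short sketch of the difference on dummy scores:

import torch

align = torch.randn(2, 3, 5)        # raw scores (batch, tgt_len, src_len)

soft = torch.softmax(align, dim=-1) # use_softmax=True: each row sums to 1
print(soft.sum(-1))                 # all ones

sig = torch.sigmoid(align)          # use_softmax=False: independent [0, 1] score per position
print(sig.min().item(), sig.max().item())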
Example #10
    def forward(self,
                query,
                memory_bank,
                memory_lengths=None,
                use_softmax=True):
        """
        Args:
            query (FloatTensor): query vectors [batch x tgt_len x q_dim]
            memory_bank (FloatTensor): source vectors [batch x src_len x k_dim]
            memory_lengths (LongTensor): source context lengths [batch]
            use_softmax (bool): use softmax to produce alignment score,
                otherwise use sigmoid for keyphrase selection
        Returns:
            attn_h (FloatTensor, batch x tgt_len x k_dim): weighted value vectors after attention
            align_vectors (FloatTensor, batch x tgt_len x src_len): normalized attention scores
            align (FloatTensor, batch x tgt_len x src_len): raw attention scores used for loss calculation
        """

        if query.dim() == 2:
            one_step = True
            query = query.unsqueeze(1)
        else:
            one_step = False

        src_batch, src_len, src_dim = memory_bank.size()
        query_batch, query_len, query_dim = query.size()
        utils.aeq(src_batch, query_batch)
        #utils.aeq(src_dim, query_dim)

        align = self.score(query, memory_bank)

        if memory_lengths is not None:
            mask = utils.sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1).long()
            align.masked_fill_((1 - mask).bool(), -float('inf'))

        if use_softmax:
            align_vectors = self.softmax(
                align.view(src_batch * query_len, src_len))
            align_vectors = align_vectors.view(src_batch, query_len, src_len)
        else:
            align_vectors = self.sigmoid(align)

        c = torch.bmm(align_vectors, memory_bank)
        # c is the attention weighted context representation
        # [batch x tgt_len x hidden_size]

        concat_c = torch.cat([c, query], 2).view(src_batch * query_len,
                                                 src_dim + query_dim)
        attn_h = self.linear_out(concat_c).view(src_batch, query_len,
                                                query_dim)
        attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            batch_, dim_ = attn_h.size()
            utils.aeq(src_batch, batch_)
            utils.aeq(src_dim, dim_)
            batch_, src_l_ = align_vectors.size()
            utils.aeq(src_batch, batch_)
            utils.aeq(src_len, src_l_)

        else:

            batch_, target_l_, dim_ = attn_h.size()
            utils.aeq(target_l_, query_len)
            utils.aeq(batch_, query_batch)
            utils.aeq(dim_, query_dim)

            batch_, target_l_, source_l_ = align_vectors.size()
            utils.aeq(target_l_, query_len)
            utils.aeq(batch_, query_batch)
            utils.aeq(source_l_, src_len)

        return attn_h, align_vectors, align
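utils.sequence_mask is used above but not shown; a minimal re-creation, assuming it marks valid positions as True up to each length:

import torch

def sequence_mask(lengths, max_len=None):
    """True for valid positions, False for padding."""
    max_len = max_len or int(lengths.max())
    return (torch.arange(max_len, device=lengths.device).unsqueeze(0)
            < lengths.unsqueeze(1))

print(sequence_mask(torch.tensor([5, 3]), max_len=5))
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False]])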