Example #1
    def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
        """
        See StdRNNDecoder._run_forward_pass() for description
        of arguments and return values.
        """
        # Additional args check.
        input_feed = self.state["input_feed"].squeeze(0)
        input_feed_batch, _ = input_feed.size()
        _, tgt_batch, _ = tgt.size()
        aeq(tgt_batch, input_feed_batch)
        # END Additional args check.

        # Initialize local and return variables.
        dec_outs = []
        attns = {"std": []}
        if self._copy:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        dec_state = self.state["hidden"]
        coverage = self.state["coverage"].squeeze(0) \
            if self.state["coverage"] is not None else None

        # Input feed concatenates hidden state with
        # input at every time step.
        for emb_t in emb.split(1):
            emb_t = emb_t.squeeze(0)
            decoder_input = torch.cat([emb_t, input_feed], 1)
            rnn_output, dec_state = self.rnn(decoder_input, dec_state)
            decoder_output, p_attn = self.attn(rnn_output,
                                               memory_bank.transpose(0, 1),
                                               memory_lengths=memory_lengths)

            decoder_output = self.dropout(decoder_output)
            input_feed = decoder_output

            dec_outs += [decoder_output]
            attns["std"] += [p_attn]

            # Update the coverage attention.
            if self._coverage:
                coverage = coverage + p_attn \
                    if coverage is not None else p_attn
                attns["coverage"] += [coverage]

            # Run the forward pass of the copy attention layer.
            if self._copy and not self._reuse_copy_attn:
                _, copy_attn = self.copy_attn(decoder_output,
                                              memory_bank.transpose(0, 1))
                attns["copy"] += [copy_attn]
            elif self._copy:
                attns["copy"] = attns["std"]
        # Return result.
        return dec_state, dec_outs, attns
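For readers new to input feeding, here is a minimal self-contained sketch of the loop above (illustrative only: `input_feed_decode`, the LSTMCell, and the plain dot-product attention are stand-ins for this decoder's `self.rnn` / `self.attn`, not the repository's classes):

import torch
import torch.nn as nn
import torch.nn.functional as F

def input_feed_decode(emb, memory_bank, cell, attn_proj, hidden_size):
    """emb: [tgt_len, batch, emb_dim], memory_bank: [src_len, batch, hidden]."""
    tgt_len, batch, _ = emb.size()
    h = torch.zeros(batch, hidden_size)
    c = torch.zeros(batch, hidden_size)
    input_feed = torch.zeros(batch, hidden_size)
    dec_outs, attns = [], []
    for emb_t in emb.split(1):                      # one time step at a time
        emb_t = emb_t.squeeze(0)                    # [batch, emb_dim]
        h, c = cell(torch.cat([emb_t, input_feed], 1), (h, c))
        # dot-product attention over the source states
        scores = torch.bmm(memory_bank.transpose(0, 1), h.unsqueeze(2)).squeeze(2)
        p_attn = F.softmax(scores, dim=-1)          # [batch, src_len]
        context = torch.bmm(p_attn.unsqueeze(1),
                            memory_bank.transpose(0, 1)).squeeze(1)
        dec_out = torch.tanh(attn_proj(torch.cat([context, h], 1)))
        input_feed = dec_out                        # fed into the next step
        dec_outs.append(dec_out)
        attns.append(p_attn)
    return torch.stack(dec_outs), torch.stack(attns)

# toy shapes
emb_dim, hidden, batch, src_len, tgt_len = 8, 16, 2, 5, 3
cell = nn.LSTMCell(emb_dim + hidden, hidden)
attn_proj = nn.Linear(2 * hidden, hidden)
outs, attns = input_feed_decode(torch.randn(tgt_len, batch, emb_dim),
                                torch.randn(src_len, batch, hidden),
                                cell, attn_proj, hidden)
print(outs.shape, attns.shape)   # torch.Size([3, 2, 16]) torch.Size([3, 2, 5])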
Example #2
    def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
        """
        Private helper for running the specific RNN forward pass.
        Must be overriden by all subclasses.
        Args:
            tgt (LongTensor): a sequence of input tokens tensors
                                 [len x batch x nfeats].
            memory_bank (FloatTensor): output(tensor sequence) from the
                          encoder RNN of size (src_len x batch x hidden_size).
            state (FloatTensor): hidden state from the encoder RNN for
                                 initializing the decoder.
            memory_lengths (LongTensor): the source memory_bank lengths.
        Returns:
            dec_state (Tensor): final hidden state from the decoder.
            dec_outs ([FloatTensor]): an array of output of every time
                                     step from the decoder.
            attns (dict of (str, [FloatTensor]): a dictionary of different
                            type of attention Tensor array of every time
                            step from the decoder.
        """
        assert not self._copy  # TODO, no support yet.
        assert not self._coverage  # TODO, no support yet.

        # Initialize local and return variables.
        tgt = tgt.unsqueeze(2)
        attns = {}
        emb = self.embeddings(tgt)

        # Run the forward pass of the RNN.
        if isinstance(self.rnn, nn.GRU):
            rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0])
        else:
            rnn_output, dec_state = self.rnn(emb, self.state["hidden"])

        # Check
        tgt_len, tgt_batch, _ = tgt.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(tgt_len, output_len)
        aeq(tgt_batch, output_batch)
        # END

        # Calculate the attention.
        dec_outs, p_attn = self.attn(rnn_output.transpose(0, 1).contiguous(),
                                     memory_bank.transpose(0, 1),
                                     memory_lengths=memory_lengths)
        attns["std"] = p_attn

        dec_outs = self.dropout(dec_outs)
        return dec_state, dec_outs, attns
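The `isinstance(self.rnn, nn.GRU)` branch exists because a GRU takes a single hidden-state tensor while an LSTM takes an `(h, c)` tuple; a quick standalone check with toy sizes (not the repository's configuration):

import torch
import torch.nn as nn

emb = torch.randn(4, 2, 8)                    # [len, batch, emb_dim]
gru, lstm = nn.GRU(8, 16), nn.LSTM(8, 16)

h0 = torch.zeros(1, 2, 16)
out_g, h_g = gru(emb, h0)                     # GRU: hidden state is one tensor
out_l, (h_l, c_l) = lstm(emb, (h0, torch.zeros(1, 2, 16)))  # LSTM: (h, c) tuple
print(out_g.shape, out_l.shape)               # both [len, batch, hidden]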
Example #3
    def forward(self, inputs):
        """
        Computes the embedding for words and features.
        Args:
            inputs (`LongTensor`): index tensor `[len x batch]`
        Return:
            `FloatTensor`: word embedding `[len x batch x embedding_size]`
        """

        in_length, in_batch = inputs.size()

        embedded = self.embedding(inputs)

        if self.dropout is not None:
            embedded = self.dropout(embedded)

        out_length, out_batch, embedded_size = embedded.size()

        aeq(in_length, out_length)
        aeq(in_batch, out_batch)
        aeq(embedded_size, self.embedding_size)

        return embedded
Example #4
    def forward(self, inputs, is_dropout=True):
        """
        Computes the embedding for words and features.
        Args:
            inputs (`LongTensor`): index tensor `[len x batch]`
            is_dropout (bool): apply dropout to the embedding when True.
        Return:
            `FloatTensor`: word embedding `[len x batch x embedding_size]`

        dim = inputs.dim()
        if dim == 2:
            # with batch
            in_length, in_batch = inputs.size()

        embedded = self.embedding(inputs)

        if self.dropout is not None and is_dropout:
            embedded = self.dropout(embedded)

        if dim == 2:
            out_length, out_batch, embedded_size = embedded.size()
            aeq(in_length, out_length)
            aeq(in_batch, out_batch)
            aeq(embedded_size, self.embedding_size)

        return embedded
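For reference, a toy run of the embed-then-dropout pattern used by both forward variants above; `nn.Embedding` and `nn.Dropout` stand in for the module's `self.embedding` and `self.dropout`, and the sizes are made up:

import torch
import torch.nn as nn

vocab_size, embedding_size = 100, 32
embedding = nn.Embedding(vocab_size, embedding_size)
dropout = nn.Dropout(0.3)

inputs = torch.randint(0, vocab_size, (7, 4))   # [len, batch] index tensor
embedded = dropout(embedding(inputs))           # [len, batch, embedding_size]

assert embedded.shape == (7, 4, embedding_size)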
Example #5
    def score(self, decoder_output, encoder_outputs):
        """
        Args:
          decoder_output (`FloatTensor`): sequence of queries `[batch_size x tgt_len x hidden_size]`
          encoder_outputs (`FloatTensor`): sequence of sources `[batch_size x src_len x hidden_size]`
        Returns:
          :obj:`FloatTensor`:
           raw attention scores (unnormalized) for each src index
          `[batch_size x tgt_len x src_len]`
        """

        # Check decoder_output sizes
        src_batch, src_len, src_dim = encoder_outputs.size()
        tgt_batch, tgt_len, tgt_dim = decoder_output.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.hidden_size, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                decoder_output = self.linear_in(decoder_output)

            # (batch_size, t_len, d) x (batch_size, d, s_len) --> (batch_size, t_len, s_len)
            return torch.bmm(decoder_output, encoder_outputs.transpose(1, 2))
        else:
            hidden_size = self.hidden_size
            wq = self.linear_query(decoder_output.view(-1, hidden_size))
            wq = wq.view(tgt_batch, tgt_len, 1, hidden_size)
            wq = wq.expand(tgt_batch, tgt_len, src_len, hidden_size)

            uh = self.linear_context(encoder_outputs.contiguous().view(
                -1, hidden_size))
            uh = uh.view(src_batch, 1, src_len, hidden_size)
            uh = uh.expand(src_batch, tgt_len, src_len, hidden_size)

            # (batch_size, t_len, s_len, d)
            wquh = self.tanh(wq + uh)

            return self.v(wquh.view(-1,
                                    hidden_size)).view(tgt_batch, tgt_len,
                                                       src_len)
Example #6
    def score(self, h_t, h_s):
        """
        Args:
          h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
          h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`
        Returns:
          :obj:`FloatTensor`:
           raw attention scores (unnormalized) for each src index
          `[batch x tgt_len x src_len]`
        """

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
                h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = torch.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
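As a shape sanity check, here is a hedged standalone version of the three scoring variants (`dot`, `general`, `mlp`); the freshly created `nn.Linear` layers stand in for the module's own `linear_in`, `linear_query`, `linear_context`, and `v`:

import torch
import torch.nn as nn

batch, src_len, tgt_len, dim = 2, 5, 3, 16
h_s = torch.randn(batch, src_len, dim)          # encoder states
h_t = torch.randn(batch, tgt_len, dim)          # decoder queries

# "dot": plain batched matmul
dot_scores = torch.bmm(h_t, h_s.transpose(1, 2))            # [batch, tgt_len, src_len]

# "general": project the queries first
linear_in = nn.Linear(dim, dim, bias=False)
gen_scores = torch.bmm(linear_in(h_t), h_s.transpose(1, 2)) # [batch, tgt_len, src_len]

# "mlp" (Bahdanau-style): v^T tanh(W q + U h)
linear_query = nn.Linear(dim, dim, bias=True)
linear_context = nn.Linear(dim, dim, bias=False)
v = nn.Linear(dim, 1, bias=False)
wq = linear_query(h_t).unsqueeze(2)             # [batch, tgt_len, 1, dim]
uh = linear_context(h_s).unsqueeze(1)           # [batch, 1, src_len, dim]
mlp_scores = v(torch.tanh(wq + uh)).squeeze(-1) # [batch, tgt_len, src_len]

print(dot_scores.shape, gen_scores.shape, mlp_scores.shape)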
Example #7
    def _check_args(self, src, lengths=None, hidden=None):
        _, n_batch, _ = src.size()
        if lengths is not None:
            n_batch_, = lengths.size()
            aeq(n_batch, n_batch_)
Example #8
    def _check_args(self, tgt, memory_bank, state):
        assert isinstance(state, RNNDecoderState)
        tgt_len, tgt_batch = tgt.size()
        _, memory_batch, _ = memory_bank.size()
        aeq(tgt_batch, memory_batch)
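`aeq` is called throughout these checks but never shown. In OpenNMT-style code it is a small assert-all-equal helper, roughly like the sketch below (an assumption about the convention, not necessarily this repository's exact definition):

def aeq(*args):
    """Assert that all arguments are equal (used for shape checks)."""
    first = args[0]
    assert all(arg == first for arg in args[1:]), \
        "Not all arguments have the same value: " + str(args)

aeq(4, 4, 4)        # passes silently
# aeq(4, 5)         # would raise AssertionError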
Example #9
    def forward(self,
                decoder_output,
                encoder_outputs,
                encoder_inputs_length=None):
        """
        Args:
          decoder_output (`FloatTensor`): query vectors `[batch_size x tgt_len x hidden_size]`
          encoder_outputs (`FloatTensor`): source vectors `[batch_size x src_len x hidden_size]`
          encoder_inputs_length (`LongTensor`): the source context lengths `[batch_size]`
        Returns:
          (`FloatTensor`, `FloatTensor`):
          * Computed vector `[tgt_len x batch_size x hidden_size]`
          * Attention distributions for each query
             `[tgt_len x batch_size x src_len]`
        """

        # one step decoder_output
        if decoder_output.dim() == 2:
            one_step = True
            # insert one dimension
            decoder_output = decoder_output.unsqueeze(1)
        else:
            one_step = False

        batch_size, sourceL, hidden_size = encoder_outputs.size()
        batch_size_, targetL, hidden_size_ = decoder_output.size()

        aeq(batch_size, batch_size_)
        aeq(hidden_size, hidden_size_)
        aeq(self.hidden_size, hidden_size)

        # compute attention scores, as in Luong et al.
        align = self.score(decoder_output,
                           encoder_outputs)  #[batch_size, t_len, s_len]

        if encoder_inputs_length is not None:
            # build a mask from the source lengths
            mask = sequence_mask(encoder_inputs_length)
            mask = mask.to(device=encoder_outputs.device)

            mask = mask.unsqueeze(1)  # Make it broadcastable.

            # Fill padded positions (where the mask is zero) with -inf
            # so that softmax assigns them zero weight.
            align.data.masked_fill_(1 - mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.softmax(align)  # [batch_size, t_len, s_len]

        # each context vector c_t is the weighted average
        # over all the source hidden states
        context_vector = torch.bmm(
            align_vectors, encoder_outputs)  # [batch_size, t_len, hidden_size]

        # concatenate
        concated_cv = torch.cat((context_vector, decoder_output),
                                dim=2)  # [batch_size, t_len, 2*hidden_size]
        attn_h = self.linear_out(
            concated_cv)  # [batch_size, t_len, hidden_size]

        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)  # tanh activation

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_size_, hidden_size_ = attn_h.size()
            aeq(batch_size, batch_size_)
            aeq(hidden_size, hidden_size_)
            batch_size_, sourceL_ = align_vectors.size()
            aeq(batch_size, batch_size_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(
                0, 1).contiguous()  # [t_len, batch_size, hidden_size]
            align_vectors = align_vectors.transpose(
                0, 1).contiguous()  # [t_len, batch_size, s_len]

            # Check output sizes
            targetL_, batch_size_, hidden_size_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch_size, batch_size_)
            aeq(hidden_size, hidden_size_)
            targetL_, batch_size_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch_size, batch_size_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
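The `one_step` handling above only normalizes a single-step query `[batch, hidden_size]` to the general `[batch, tgt_len, hidden_size]` layout and undoes it on the way out; a tiny illustration of that round trip:

import torch

query = torch.randn(4, 16)            # [batch, hidden]: one decoding step
one_step = query.dim() == 2
if one_step:
    query = query.unsqueeze(1)        # [batch, 1, hidden], i.e. tgt_len == 1
# ... attention would run here on the [batch, tgt_len, hidden] layout ...
if one_step:
    query = query.squeeze(1)          # back to [batch, hidden]
print(query.shape)                    # torch.Size([4, 16])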
Example #10
    def _check_args(self, input, lengths=None, hidden=None):
        s_len, n_batch = input.size()
        if lengths is not None:
            n_batch_, = lengths.size()
            aeq(n_batch, n_batch_)
Example #11
    def _run_forward_pass(self,
                          inputs,
                          encoder_outputs,
                          decoder_state,
                          encoder_inputs_length=None):
        """
        Private helper for running the specific RNN forward pass.
        Must be overrided by all subclasses.
        Args:
            inputs (LongTensor): a sequence of input tokens tensors
                                 [inputs_len x batch].
            encoder_outputs (FloatTensor): output(tensor sequence) from the encoder
                        RNN of size (src_len x batch x hidden_size).
            decoder_state (FloatTensor): hidden decoder_state from the encoder RNN for
                                 initializing the decoder.
            encoder_inputs_length (LongTensor): the source encoder_outputs lengths.
        Returns:
            decoder_final (Variable): final hidden decoder_state from the decoder.
            decoder_outputs ([FloatTensor]): an array of output of every time
                                     step from the decoder.
            attns (dict of (str, [FloatTensor]): a dictionary of different
                            type of attention Tensor array of every time
                            step from the decoder.
        """

        # Initialize local and return variables.
        attns = {}

        embedded = self.embedding(inputs)

        # Run the forward pass of the RNN.
        if isinstance(self.rnn, nn.GRU) or isinstance(self.rnn, nn.RNN):
            decoder_output, decoder_final = self.rnn(embedded,
                                                     decoder_state.hidden[0])
        else:
            # LSTM
            decoder_output, decoder_final = self.rnn(embedded,
                                                     decoder_state.hidden)

        # Check
        inputs_len, tgt_batch = inputs.size()
        output_len, output_batch, _ = decoder_output.size()

        aeq(inputs_len, output_len)
        aeq(tgt_batch, output_batch)

        # Calculate the attention.
        if self.attn_type is not None:
            # decoder_output [len, batch, hidden_size] and encoder_outputs
            # [src_len, batch, hidden_size] are made batch-first for attn.
            decoder_output, p_attn = self.attn(decoder_output.transpose(0, 1),
                                               encoder_outputs.transpose(0, 1),
                                               encoder_inputs_length)
            attns["std"] = p_attn

        # dropout
        decoder_output = self.dropout(decoder_output)

        return decoder_final, decoder_output, attns
Example #12
    def _check_args(self, inputs, encoder_outputs, decoder_state):
        assert isinstance(decoder_state, RNNDecoderState)
        inputs_len, tgt_batch = inputs.size()
        _, memory_batch, _ = encoder_outputs.size()
        aeq(tgt_batch, memory_batch)
Example #13
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """
        Args:
          source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)
        Returns:
          (`FloatTensor`, `FloatTensor`):
          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)

        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(1 - mask, -float('inf'))

        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors
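`sequence_mask` is not shown in this section; the sketch below assumes it returns a `[batch, src_len]` boolean mask that is True inside each sequence (with a bool mask, `~mask` plays the role of `1 - mask` above) and spells out the masking-then-softmax pattern standalone:

import torch
import torch.nn.functional as F

def sequence_mask(lengths, max_len=None):
    """Boolean mask [batch, max_len], True where a position is within length."""
    max_len = max_len or int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)

align = torch.randn(2, 3, 5)                  # [batch, tgt_len, src_len] raw scores
lengths = torch.tensor([5, 2])                # real source lengths per batch entry
mask = sequence_mask(lengths, max_len=align.size(-1)).unsqueeze(1)  # broadcastable
align.masked_fill_(~mask, -float('inf'))      # padded positions get -inf
align_vectors = F.softmax(align, dim=-1)      # padded positions get zero weight
print(align_vectors[1, 0])                    # only the first two entries are non-zero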