Example #1
 def _set_input_buffer(self, incremental_state, buffer):
     utils.set_incremental_state(
         self,
         incremental_state,
         'f',
         buffer,
     )
Example #2
 def _set_input_buffer(self, incremental_state, buffer):
     set_incremental_state(
         self,
         incremental_state,
         'attn_state',
         buffer,
     )
Example #3
    def forward(self,
                x,
                encoder_out=None,
                encoder_padding_mask=None,
                incremental_state=None,
                **kwargs):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
        self.lstm.flatten_parameters()
        if incremental_state is not None:
            x = x[-1:, :, :]
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells = cached_state
        else:
            prev_hiddens = encoder_out.mean(dim=0, keepdim=True)
            prev_cells = encoder_out.mean(dim=0, keepdim=True)

        residual = x
        x = self.layer_norm1(x)
        x, hidden = self.lstm(x, (prev_hiddens, prev_cells))
        hiddens, cells = hidden
        x = residual + x

        x = self.layer_norm2(x)
        x, attn = self.attention(
            query=x,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            enc_dec_attn_constraint_mask=utils.get_incremental_state(
                self, incremental_state, 'enc_dec_attn_constraint_mask'))
        x = F.dropout(x, self.dropout, training=self.training)

        if incremental_state is not None:
            #prev_hiddens = torch.cat((prev_hiddens, hiddens), dim=0)
            #prev_cells = torch.cat((prev_cells, cells), dim=0)
            prev_hiddens = hiddens
            prev_cells = cells
            utils.set_incremental_state(
                self,
                incremental_state,
                'cached_state',
                (prev_hiddens, prev_cells),
            )

        x = residual + x
        attn_logits = attn[1]
        #if len(attn_logits.size()) > 3:
        #    attn_logits = attn_logits[:, 0]
        return x, attn_logits
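The layer above caches (prev_hiddens, prev_cells) under 'cached_state', which is what allows it to truncate its input to the last time step (x = x[-1:, :, :]) during incremental decoding. A rough sketch of a step-by-step driver loop; layer, x, encoder_out, encoder_padding_mask, and max_steps are hypothetical stand-ins:

    incremental_state = {}  # shared cache; the layer keys into it per instance
    outputs = []
    for t in range(max_steps):
        # pass the growing prefix; with incremental_state set, the layer
        # itself slices off everything but the newest time step
        y, attn_logits = layer(
            x[: t + 1],
            encoder_out=encoder_out,
            encoder_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
        )
        outputs.append(y)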
Example #4
    def clear_buffer(self,
                     input,
                     encoder_out=None,
                     encoder_padding_mask=None,
                     incremental_state=None):
        if incremental_state is not None:
            prev_hiddens = encoder_out.mean(dim=0, keepdim=True)
            prev_cells = encoder_out.mean(dim=0, keepdim=True)
            utils.set_incremental_state(self, incremental_state,
                                        'cached_state',
                                        (prev_hiddens, prev_cells))
    def reorder_incremental_state(self, incremental_state, new_order):
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        if not isinstance(new_order, Variable):
            new_order = Variable(new_order)
        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
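reorder_incremental_state is needed because beam search flattens the beams into the batch dimension and re-ranks them after every step, so each cached tensor must be permuted with the same index. A toy, self-contained illustration of the index_select step with made-up tensors:

    import torch

    cached = torch.arange(6.).view(3, 2)        # cached state for 3 beams
    new_order = torch.tensor([2, 0, 1])         # beam order after re-ranking
    cached = cached.index_select(0, new_order)  # rows follow the new order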
Example #6
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, _, _ = encoder_out
        srclen = encoder_outs.size(0)

        x = self.embed_tokens(prev_output_tokens)  # (bsz, seqlen, embed_dim)
        x = F.dropout(x, p=self.dropout_in, training=self.training)
        embed_dim = x.size(2)

        x = x.transpose(0, 1)  # (seqlen, bsz, embed_dim)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')

        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            _, encoder_hiddens, encoder_cells = encoder_out
            num_layers = len(self.layers)
            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
            prev_cells = [encoder_cells[i] for i in range(num_layers)]
            input_feed = Variable(x.data.new(bsz, embed_dim).zero_())

        attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            input = torch.cat((x[j, :, :], input_feed), dim=1)

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden,
                                  p=self.dropout_out,
                                  training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # apply attention using the last layer's hidden state
            out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
            out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            input_feed = out

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(self, incremental_state, 'cached_state',
                                    (prev_hiddens, prev_cells, input_feed))

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
        # T x B x C -> B x T x C
        x = x.transpose(1, 0)
        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        attn_scores = attn_scores.transpose(0, 2)

        x = self.fc_out(x)

        return x, attn_scores
Example #7
 def _set_input_buffer(self, incremental_state, new_buffer):
     return utils.set_incremental_state(self, incremental_state,
                                        'input_buffer', new_buffer)
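All of these examples delegate to the same pair of helpers in fairseq/utils.py. Conceptually they are just a namespaced write and read on a plain dict. The simplified sketch below (using Python's id() in place of fairseq's per-class instance counter) captures the idea:

    def set_incremental_state(module, incremental_state, key, value):
        # a no-op unless an incremental_state dict was actually passed in
        if incremental_state is not None:
            full_key = '{}.{}.{}'.format(type(module).__name__, id(module), key)
            incremental_state[full_key] = value

    def get_incremental_state(module, incremental_state, key):
        # returns None both when there is no dict and when the key is unset
        if incremental_state is None:
            return None
        full_key = '{}.{}.{}'.format(type(module).__name__, id(module), key)
        return incremental_state.get(full_key, None)

Keying on the module instance is what lets many layers share a single incremental_state dict without their 'cached_state' entries colliding.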
Example #8
 def set_buffer(self, name, tensor, incremental_state):
     return utils.set_incremental_state(self, incremental_state, name,
                                        tensor)
    def forward(
        self,
        phase,
        epoch,
        fixed_max_len,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
    ):
        if phase == 'MLE':
            if incremental_state is not None:
                prev_output_tokens = prev_output_tokens[:, -1:]
            bsz, seqlen = prev_output_tokens.size()
            # print("generator.py LSTMDecoder forward", seqlen)

            # get outputs from encoder
            encoder_outs, _, _ = encoder_out
            srclen = encoder_outs.size(0)

            x1 = self.embed_tokens(prev_output_tokens)  # (bsz, seqlen, embed_dim)
            x2 = F.dropout(x1, p=self.dropout_in, training=self.training)
            embed_dim = x2.size(2)

            x3 = x2.transpose(0, 1)  # (seqlen, bsz, embed_dim)
            x = x3.detach()
            # initialize previous states (or get from cache during incremental generation)
            cached_state = utils.get_incremental_state(self, incremental_state,
                                                       'cached_state')

            if cached_state is not None:
                prev_hiddens, prev_cells, input_feed = cached_state
            else:
                _, encoder_hiddens, encoder_cells = encoder_out
                num_layers = len(self.layers)
                prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
                prev_cells = [encoder_cells[i] for i in range(num_layers)]
                input_feed = x.data.new(bsz, embed_dim).zero_()

            attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
            outs = []
            p = 0  # scheduled-sampling probability; default in case the loop below never sets it
            for j in range(fixed_max_len):
                # input feeding: concatenate the context vector from the
                # previous time step (teacher forcing)
                # input_feed is the attention vector combining the decoder
                # hidden state with the encoder outputs
                # x holds the embedded input (prev_output_tokens)
                input = torch.cat((x[j, :, :], input_feed), dim=1)
                for i, rnn in enumerate(self.layers):
                    # recurrent cell
                    hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
                    # hidden state becomes the input to the next layer
                    input = F.dropout(hidden,
                                      p=self.dropout_out,
                                      training=self.training)
                    # save state for next time step
                    prev_hiddens[i] = hidden
                    prev_cells[i] = cell
                    decoder_hidden = hidden

                # apply attention using the last layer's hidden state
                out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
                out = F.dropout(out,
                                p=self.dropout_out,
                                training=self.training)
                # input feeding
                input_feed = out
                # save final output
                out_1 = out.unsqueeze(0)
                out_2 = out_1.transpose(1, 0)
                out_3 = self.fc_out(out_2)  # out_3 = [batch,1, num_vocab]
                outs.append(out_3)
                word = torch.argmax(out_3, dim=-1)  # word = [batch,1]
                out_4 = self.embed_tokens(word).squeeze(1)  # out_4 = [batch, dim]
                if j < fixed_max_len - 1:
                    p = self.calculate_p(epoch, x[j + 1, :, :], out_4)
                    is_teacher = random.random() > p
                    if not is_teacher:
                        x[j + 1, :, :] = out_4[:, :]

                    # cache previous states (no-op except during incremental generation)
            utils.set_incremental_state(self, incremental_state,
                                        'cached_state',
                                        (prev_hiddens, prev_cells, input_feed))

            attn_scores = attn_scores.transpose(0, 2)
            x = torch.cat(outs, dim=1).view(bsz, seqlen, -1)  # x = [batch, len, num_vocab]
            return x, attn_scores, p

        elif phase == 'PG':
            if incremental_state is not None:
                prev_output_tokens = prev_output_tokens[:, -1:]
            bsz, seqlen = prev_output_tokens.size()
            # print("generator.py LSTMDecoder forward", seqlen)

            # get outputs from encoder
            encoder_outs, _, _ = encoder_out
            srclen = encoder_outs.size(0)

            x = self.embed_tokens(prev_output_tokens)  # (bsz, seqlen, embed_dim)
            x = F.dropout(x, p=self.dropout_in, training=self.training)
            embed_dim = x.size(2)
            x = x.transpose(0, 1)  # (seqlen, bsz, embed_dim)
            # initialize previous states (or get from cache during incremental generation)
            cached_state = utils.get_incremental_state(self, incremental_state,
                                                       'cached_state')
            if cached_state is not None:
                prev_hiddens, prev_cells, input_feed = cached_state
            else:
                _, encoder_hiddens, encoder_cells = encoder_out
                num_layers = len(self.layers)
                prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
                prev_cells = [encoder_cells[i] for i in range(num_layers)]
                input_feed = x.data.new(bsz, embed_dim).zero_()

            attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
            outs = []
            for j in range(fixed_max_len):
                # input feeding: concatenate the context vector from the
                # previous time step (teacher forcing)
                # input_feed is the attention vector combining the decoder
                # hidden state with the encoder outputs
                # x holds the embedded input (prev_output_tokens)
                input = torch.cat((x[j, :, :], input_feed), dim=1)
                for i, rnn in enumerate(self.layers):
                    # recurrent cell
                    hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
                    # hidden state becomes the input to the next layer
                    input = F.dropout(hidden,
                                      p=self.dropout_out,
                                      training=self.training)
                    # save state for next time step
                    prev_hiddens[i] = hidden
                    prev_cells[i] = cell
                # apply attention using the last layer's hidden state
                out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
                out = F.dropout(out,
                                p=self.dropout_out,
                                training=self.training)
                # input feeding
                input_feed = out
                # save final output
                outs.append(out)
                if j < fixed_max_len - 1:
                    x[j + 1, :, :] = input[:, :]
            # cache previous states (no-op except during incremental generation)
            utils.set_incremental_state(self, incremental_state,
                                        'cached_state',
                                        (prev_hiddens, prev_cells, input_feed))

            # collect outputs across time steps
            x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
            # T x B x C -> B x T x C
            x = x.transpose(1, 0)
            # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
            attn_scores = attn_scores.transpose(0, 2)

            x = self.fc_out(x)
            p = 0
            return x, attn_scores, p

        elif phase == 'test':
            if incremental_state is not None:
                prev_output_tokens = prev_output_tokens[:, -1:]
            bsz, seqlen = prev_output_tokens.size()

            # get outputs from encoder
            encoder_outs, _, _ = encoder_out
            srclen = encoder_outs.size(0)

            x = self.embed_tokens(prev_output_tokens)  # (bsz, seqlen, embed_dim)
            x = F.dropout(x, p=self.dropout_in, training=self.training)
            embed_dim = x.size(2)

            x = x.transpose(0, 1)  # (seqlen, bsz, embed_dim)

            # initialize previous states (or get from cache during incremental generation)
            cached_state = utils.get_incremental_state(self, incremental_state,
                                                       'cached_state')

            if cached_state is not None:
                prev_hiddens, prev_cells, input_feed = cached_state
            else:
                _, encoder_hiddens, encoder_cells = encoder_out
                num_layers = len(self.layers)
                prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
                prev_cells = [encoder_cells[i] for i in range(num_layers)]
                input_feed = x.data.new(bsz, embed_dim).zero_()

            attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
            outs = []
            for j in range(seqlen):
                # input feeding: concatenate context vector from previous time step
                # input_feed is the attention vector combining the decoder
                # hidden state with the encoder outputs
                # x holds the embedded input (prev_output_tokens)
                input = torch.cat((x[j, :, :], input_feed), dim=1)

                for i, rnn in enumerate(self.layers):
                    # recurrent cell
                    hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                    # hidden state becomes the input to the next layer
                    input = F.dropout(hidden,
                                      p=self.dropout_out,
                                      training=self.training)

                    # save state for next time step
                    prev_hiddens[i] = hidden
                    prev_cells[i] = cell

                # apply attention using the last layer's hidden state
                out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
                out = F.dropout(out,
                                p=self.dropout_out,
                                training=self.training)

                # input feeding
                input_feed = out

                # save final output
                outs.append(out)

            # cache previous states (no-op except during incremental generation)
            utils.set_incremental_state(self, incremental_state,
                                        'cached_state',
                                        (prev_hiddens, prev_cells, input_feed))

            # collect outputs across time steps
            x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
            # T x B x C -> B x T x C
            x = x.transpose(1, 0)
            # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
            attn_scores = attn_scores.transpose(0, 2)

            x = self.fc_out(x)
            p = 0
            return x, attn_scores, p
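At inference time the 'test' branch above is driven one token at a time: the same incremental_state dict is passed on every call, so the decoder restores (prev_hiddens, prev_cells, input_feed) from the cache instead of recomputing the whole prefix. A rough sketch, with decoder, encoder_out, bsz, bos_idx, and max_len as hypothetical stand-ins:

    import torch

    incremental_state = {}
    tokens = torch.full((bsz, 1), bos_idx, dtype=torch.long)
    for _ in range(max_len):
        logits, attn_scores, _ = decoder(
            'test', epoch=0, fixed_max_len=0,  # both unused in the test branch
            prev_output_tokens=tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
        )
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        tokens = torch.cat([tokens, next_token], dim=1)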