Example #1
    def forward(self,
                x,
                encoder_out=None,
                encoder_padding_mask=None,
                incremental_state=None,
                **kwargs):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
        self.lstm.flatten_parameters()
        if incremental_state is not None:
            # incremental decoding: only the newest time step is fed to the layer
            x = x[-1:, :, :]
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells = cached_state
        else:
            # no cached state yet: initialize the LSTM state from the
            # mean-pooled encoder output
            prev_hiddens = encoder_out.mean(dim=0, keepdim=True)
            prev_cells = encoder_out.mean(dim=0, keepdim=True)

        residual = x
        x = self.layer_norm1(x)
        x, hidden = self.lstm(x, (prev_hiddens, prev_cells))
        hiddens, cells = hidden
        x = residual + x

        x = self.layer_norm2(x)
        x, attn = self.attention(
            query=x,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            enc_dec_attn_constraint_mask=utils.get_incremental_state(
                self, incremental_state, 'enc_dec_attn_constraint_mask'))
        x = F.dropout(x, self.dropout, training=self.training)

        if incremental_state is not None:
            #prev_hiddens = torch.cat((prev_hiddens, hiddens), dim=0)
            #prev_cells = torch.cat((prev_cells, cells), dim=0)
            prev_hiddens = hiddens
            prev_cells = cells
            utils.set_incremental_state(
                self,
                incremental_state,
                'cached_state',
                (prev_hiddens, prev_cells),
            )

        x = residual + x
        attn_logits = attn[1]
        #if len(attn_logits.size()) > 3:
        #    attn_logits = attn_logits[:, 0]
        return x, attn_logits
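These layers thread an `incremental_state` dict through the forward pass and stash per-layer tensors in it between decoding steps via `utils.get_incremental_state` / `utils.set_incremental_state`. A minimal sketch of how such helpers can look, assuming a plain dict keyed per module instance; this is an illustration, not necessarily the `utils` implementation imported above.

def get_incremental_state(module, incremental_state, key):
    """Return a value cached for this module and key, or None if absent."""
    if incremental_state is None:
        return None
    full_key = '{}.{}.{}'.format(module.__class__.__name__, id(module), key)
    return incremental_state.get(full_key, None)


def set_incremental_state(module, incremental_state, key, value):
    """Cache a value for this module and key (no-op outside incremental decoding)."""
    if incremental_state is not None:
        full_key = '{}.{}.{}'.format(module.__class__.__name__, id(module), key)
        incremental_state[full_key] = value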
Example #2
    def forward(self,
                x,
                encoder_out=None,
                encoder_padding_mask=None,
                incremental_state=None,
                **kwargs):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
        residual = x
        x = self.layer_norm1(x)
        x = self.conv(x, incremental_state=incremental_state)
        x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        x = self.layer_norm2(x)
        x, attn = self.attention(
            query=x,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            enc_dec_attn_constraint_mask=utils.get_incremental_state(
                self, incremental_state, 'enc_dec_attn_constraint_mask'))
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        attn_logits = attn[1]
        #if len(attn_logits.size()) > 3:
        #    attn_logits = attn_logits[:, 0]
        return x, attn_logits
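Unlike Example #1's LSTM, this layer's only step-to-step state is whatever `self.conv(x, incremental_state=...)` caches. A hedged sketch of such a convolution, assuming it is a causal Conv1d that keeps the last `kernel_size - 1` input frames in a buffer so it can be advanced one frame at a time; the class name and caching scheme here are illustrative, not the module actually used above.

import torch
import torch.nn as nn

class IncrementalCausalConv(nn.Module):
    def __init__(self, channels, kernel_size):
        super().__init__()
        assert kernel_size >= 2
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(channels, channels, kernel_size)

    def forward(self, x, buffer=None):
        # x: (T, B, C); buffer: (B, C, kernel_size - 1) of past frames, or None
        x = x.permute(1, 2, 0)  # -> (B, C, T)
        if buffer is None:
            # zeros act as left padding, so the output stays causal
            buffer = x.new_zeros(x.size(0), x.size(1), self.kernel_size - 1)
        x = torch.cat([buffer, x], dim=2)
        new_buffer = x[:, :, -(self.kernel_size - 1):]  # history for the next call
        y = self.conv(x)  # output length equals the number of new frames
        return y.permute(2, 0, 1), new_buffer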
Example #3
    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        **kwargs,
    ):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
            self.layer_norm3.training = layer_norm_training
        residual = x
        x = self.layer_norm1(x)
        x, _ = self.self_attn(query=x,
                              key=x,
                              value=x,
                              key_padding_mask=self_attn_padding_mask,
                              incremental_state=incremental_state,
                              attn_mask=self_attn_mask)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x

        residual = x
        x = self.layer_norm2(x)
        x, attn = self.encoder_attn(
            query=x,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            enc_dec_attn_constraint_mask=utils.get_incremental_state(
                self, incremental_state, 'enc_dec_attn_constraint_mask'))
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x

        residual = x
        x = self.layer_norm3(x)
        x = self.ffn(x, incremental_state=incremental_state)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        attn_logits = attn[1]
        #if len(attn_logits.size()) > 3:
        #    indices = attn_logits.softmax(-1).max(-1).values.sum(-1).argmax(-1)
        #    attn_logits = attn_logits.gather(1,
        #        indices[:, None, None, None].repeat(1, 1, attn_logits.size(-2), attn_logits.size(-1))).squeeze(1)
        return x, attn_logits
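Example #3 is the standard pre-norm Transformer decoder layer: masked self-attention, encoder-decoder attention, then a feed-forward block, each wrapped as LayerNorm -> sublayer -> dropout -> residual. A stripped-down sketch of the same flow on top of torch.nn.MultiheadAttention, leaving out the incremental-state and constraint-mask plumbing (class and argument names here are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

class PreNormDecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads)
        self.encoder_attn = nn.MultiheadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                 nn.Linear(d_ff, d_model))
        self.dropout = dropout

    def forward(self, x, encoder_out, self_attn_mask=None,
                encoder_padding_mask=None):
        # x, encoder_out: (T, B, C)
        residual = x
        h = self.norm1(x)
        h, _ = self.self_attn(h, h, h, attn_mask=self_attn_mask)
        x = residual + F.dropout(h, self.dropout, training=self.training)

        residual = x
        h = self.norm2(x)
        h, attn = self.encoder_attn(h, encoder_out, encoder_out,
                                    key_padding_mask=encoder_padding_mask)
        x = residual + F.dropout(h, self.dropout, training=self.training)

        residual = x
        h = self.ffn(self.norm3(x))
        x = residual + F.dropout(h, self.dropout, training=self.training)
        return x, attn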
    def reorder_incremental_state(self, incremental_state, new_order):
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        if not isinstance(new_order, Variable):
            new_order = Variable(new_order)
        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
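`reorder_incremental_state` exists because beam search stores one cache row per (sentence, beam) hypothesis and periodically duplicates or reshuffles those rows; `new_order` is simply an index tensor over rows. (The `Variable` wrapper is legacy pre-0.4 PyTorch; plain tensors behave the same.) A self-contained toy illustration of what `reorder_state` does to a single cached tensor:

import torch

# three cached rows; beam search keeps row 2 twice and row 0 once, row 1 is dropped
cached = torch.arange(12.).view(3, 4)
new_order = torch.tensor([2, 2, 0])
print(cached.index_select(0, new_order))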
    def _get_input_buffer(self, incremental_state):
        return get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}
Example #6
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, _, _ = encoder_out
        srclen = encoder_outs.size(0)

        x = self.embed_tokens(prev_output_tokens)  # (bsz, seqlen, embed_dim)
        x = F.dropout(x, p=self.dropout_in, training=self.training)
        embed_dim = x.size(2)

        x = x.transpose(0, 1)  # (seqlen, bsz, embed_dim)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')

        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            _, encoder_hiddens, encoder_cells = encoder_out
            num_layers = len(self.layers)
            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
            prev_cells = [encoder_cells[i] for i in range(num_layers)]
            input_feed = Variable(x.data.new(bsz, embed_dim).zero_())

        attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            input = torch.cat((x[j, :, :], input_feed), dim=1)

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden,
                                  p=self.dropout_out,
                                  training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # apply attention using the last layer's hidden state
            out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
            out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            input_feed = out

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(self, incremental_state, 'cached_state',
                                    (prev_hiddens, prev_cells, input_feed))

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
        # T x B x C -> B x T x C
        x = x.transpose(1, 0)
        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        attn_scores = attn_scores.transpose(0, 2)

        x = self.fc_out(x)

        return x, attn_scores
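The loop above calls `self.attention(hidden, encoder_outs)` once per target step and expects back an output vector of shape (bsz, dim) plus attention weights of shape (srclen, bsz). A hedged sketch of such a module, using dot-product scores and the common concat-and-project output; the attention module actually used in the example may differ.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DotProductAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.out_proj = nn.Linear(2 * hidden_dim, hidden_dim)

    def forward(self, hidden, encoder_outs):
        # hidden: (B, C); encoder_outs: (S, B, C)
        scores = (encoder_outs * hidden.unsqueeze(0)).sum(dim=2)     # (S, B)
        weights = F.softmax(scores, dim=0)                           # over source
        context = (weights.unsqueeze(2) * encoder_outs).sum(dim=0)   # (B, C)
        out = torch.tanh(self.out_proj(torch.cat([context, hidden], dim=1)))
        return out, weights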
    def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(self, incremental_state,
                                           'input_buffer')
    def forward(
        self,
        phase,
        epoch,
        fixed_max_len,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
    ):
        if phase == 'MLE':
            if incremental_state is not None:
                prev_output_tokens = prev_output_tokens[:, -1:]
            bsz, seqlen = prev_output_tokens.size()
            # print("generator.py LSTMDecoder forward", seqlen)

            # get outputs from encoder
            encoder_outs, _, _ = encoder_out
            srclen = encoder_outs.size(0)

            x1 = self.embed_tokens(
                prev_output_tokens)  # (bsz, seqlen, embed_dim)
            x2 = F.dropout(x1, p=self.dropout_in, training=self.training)
            embed_dim = x2.size(2)

            x3 = x2.transpose(0, 1)  # (seqlen, bsz, embed_dim)
            x = x3.detach()
            # initialize previous states (or get from cache during incremental generation)
            cached_state = utils.get_incremental_state(self, incremental_state,
                                                       'cached_state')

            if cached_state is not None:
                prev_hiddens, prev_cells, input_feed = cached_state
            else:
                _, encoder_hiddens, encoder_cells = encoder_out
                num_layers = len(self.layers)
                prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
                prev_cells = [encoder_cells[i] for i in range(num_layers)]
                input_feed = x.data.new(bsz, embed_dim).zero_()

            attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
            outs = []
            p_list = []
            for j in range(fixed_max_len):
                # input feeding: concatenate context vector from previous time step
                # teacher forcing
                # input_feed is the attention vector combining the decoder hidden
                # state with the encoder output
                # x has the input (prev_output_tokens) length
                input = torch.cat((x[j, :, :], input_feed), dim=1)
                for i, rnn in enumerate(self.layers):
                    # recurrent cell
                    hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
                    # hidden state becomes the input to the next layer
                    input = F.dropout(hidden,
                                      p=self.dropout_out,
                                      training=self.training)
                    # save state for next time step
                    prev_hiddens[i] = hidden
                    prev_cells[i] = cell
                    decoder_hidden = hidden

                # apply attention using the last layer's hidden state
                out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
                out = F.dropout(out,
                                p=self.dropout_out,
                                training=self.training)
                # input feeding
                input_feed = out
                # save final output
                out_1 = out.unsqueeze(0)
                out_2 = out_1.transpose(1, 0)
                out_3 = self.fc_out(out_2)  # out_3 = [batch,1, num_vocab]
                outs.append(out_3)
                word = torch.argmax(out_3, dim=-1)  # word = [batch,1]
                out_4 = self.embed_tokens(word).squeeze(1)  # out_4 = [batch, embed_dim]
                if j < fixed_max_len - 1:
                    p = self.calculate_p(epoch, x[j + 1, :, :], out_4)
                    is_teacher = random.random() > p
                    if not is_teacher:
                        x[j + 1, :, :] = out_4[:, :]

            # cache previous states (no-op except during incremental generation)
            utils.set_incremental_state(self, incremental_state,
                                        'cached_state',
                                        (prev_hiddens, prev_cells, input_feed))

            attn_scores = attn_scores.transpose(0, 2)
            x = torch.cat(outs, dim=1).view(bsz, seqlen, -1)  # x = [batch, len, num_vocab]
            return x, attn_scores, p

        elif phase == 'PG':
            if incremental_state is not None:
                prev_output_tokens = prev_output_tokens[:, -1:]
            bsz, seqlen = prev_output_tokens.size()
            # print("generator.py LSTMDecoder forward", seqlen)

            # get outputs from encoder
            encoder_outs, _, _ = encoder_out
            srclen = encoder_outs.size(0)

            x = self.embed_tokens(
                prev_output_tokens)  # (bsz, seqlen, embed_dim)
            x = F.dropout(x, p=self.dropout_in, training=self.training)
            embed_dim = x.size(2)
            x = x.transpose(0, 1)  # (seqlen, bsz, embed_dim)
            # initialize previous states (or get from cache during incremental generation)
            cached_state = utils.get_incremental_state(self, incremental_state,
                                                       'cached_state')
            if cached_state is not None:
                prev_hiddens, prev_cells, input_feed = cached_state
            else:
                _, encoder_hiddens, encoder_cells = encoder_out
                num_layers = len(self.layers)
                prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
                prev_cells = [encoder_cells[i] for i in range(num_layers)]
                input_feed = x.data.new(bsz, embed_dim).zero_()

            attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
            outs = []
            for j in range(fixed_max_len):
                # input feeding: concatenate context vector from previous time step
                # teacher forcing
                # input_feed is the attention vector combining the decoder hidden
                # state with the encoder output
                # x has the input (prev_output_tokens) length
                input = torch.cat((x[j, :, :], input_feed), dim=1)
                for i, rnn in enumerate(self.layers):
                    # recurrent cell
                    hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
                    # hidden state becomes the input to the next layer
                    input = F.dropout(hidden,
                                      p=self.dropout_out,
                                      training=self.training)
                    # save state for next time step
                    prev_hiddens[i] = hidden
                    prev_cells[i] = cell
                # apply attention using the last layer's hidden state
                out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
                out = F.dropout(out,
                                p=self.dropout_out,
                                training=self.training)
                # input feeding
                input_feed = out
                # save final output
                outs.append(out)
                if j < fixed_max_len - 1:
                    x[j + 1, :, :] = input[:, :]
            # cache previous states (no-op except during incremental generation)
            utils.set_incremental_state(self, incremental_state,
                                        'cached_state',
                                        (prev_hiddens, prev_cells, input_feed))

            # collect outputs across time steps
            x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
            # T x B x C -> B x T x C
            x = x.transpose(1, 0)
            # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
            attn_scores = attn_scores.transpose(0, 2)

            x = self.fc_out(x)
            p = 0
            return x, attn_scores, p

        elif phase == 'test':
            if incremental_state is not None:
                prev_output_tokens = prev_output_tokens[:, -1:]
            bsz, seqlen = prev_output_tokens.size()

            # get outputs from encoder
            encoder_outs, _, _ = encoder_out
            srclen = encoder_outs.size(0)

            x = self.embed_tokens(
                prev_output_tokens)  # (bsz, seqlen, embed_dim)
            x = F.dropout(x, p=self.dropout_in, training=self.training)
            embed_dim = x.size(2)

            x = x.transpose(0, 1)  # (seqlen, bsz, embed_dim)

            # initialize previous states (or get from cache during incremental generation)
            cached_state = utils.get_incremental_state(self, incremental_state,
                                                       'cached_state')

            if cached_state is not None:
                prev_hiddens, prev_cells, input_feed = cached_state
            else:
                _, encoder_hiddens, encoder_cells = encoder_out
                num_layers = len(self.layers)
                prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
                prev_cells = [encoder_cells[i] for i in range(num_layers)]
                input_feed = x.data.new(bsz, embed_dim).zero_()

            attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
            outs = []
            for j in range(seqlen):
                # input feeding: concatenate context vector from previous time step
                # teacher forcing
                # input_feed is the attention vector combining the decoder hidden
                # state with the encoder output
                # x has the input (prev_output_tokens) length
                input = torch.cat((x[j, :, :], input_feed), dim=1)

                for i, rnn in enumerate(self.layers):
                    # recurrent cell
                    hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                    # hidden state becomes the input to the next layer
                    input = F.dropout(hidden,
                                      p=self.dropout_out,
                                      training=self.training)

                    # save state for next time step
                    prev_hiddens[i] = hidden
                    prev_cells[i] = cell

                # apply attention using the last layer's hidden state
                out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
                out = F.dropout(out,
                                p=self.dropout_out,
                                training=self.training)

                # input feeding
                input_feed = out

                # save final output
                outs.append(out)

            # cache previous states (no-op except during incremental generation)
            utils.set_incremental_state(self, incremental_state,
                                        'cached_state',
                                        (prev_hiddens, prev_cells, input_feed))

            # collect outputs across time steps
            x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
            # T x B x C -> B x T x C
            x = x.transpose(1, 0)
            # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
            attn_scores = attn_scores.transpose(0, 2)

            x = self.fc_out(x)
            p = 0
            return x, attn_scores, p
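In the 'MLE' branch above, `self.calculate_p(epoch, x[j + 1, :, :], out_4)` yields the probability of feeding the model's own prediction back in instead of the gold embedding (scheduled sampling): the gold input is kept only when `random.random() > p`. The implementation of `calculate_p` is not shown; a common epoch-based choice is the inverse-sigmoid decay of Bengio et al. (2015), sketched below. Ignoring the extra embedding arguments is an assumption of this sketch, and `k` is an assumed constant.

import math

def scheduled_sampling_p(epoch, k=10.0):
    """Probability of sampling the model's own previous output at this epoch."""
    teacher_forcing_p = k / (k + math.exp(epoch / k))  # decays toward 0
    return 1.0 - teacher_forcing_p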