Example #1
 def _set_input_buffer(self, incremental_state, buffer):
     utils.set_incremental_state(
         self,
         incremental_state,
         'attn_state',
         buffer,
     )
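Every example on this page pairs utils.set_incremental_state with utils.get_incremental_state. As a rough, minimal sketch of what such helpers might look like (an assumption for illustration, not fairseq's actual implementation): both read and write a plain dict, with keys qualified per module instance so that nested decoders do not clobber each other's cache.

# Minimal sketch (assumption): incremental state is a plain dict whose keys
# combine the module class name, a per-instance counter, and the user key.
_instance_counters = {}

def _full_key(module, key):
    cls_name = module.__class__.__name__
    if not hasattr(module, '_instance_id'):
        _instance_counters[cls_name] = _instance_counters.get(cls_name, 0) + 1
        module._instance_id = _instance_counters[cls_name]
    return '{}.{}.{}'.format(cls_name, module._instance_id, key)

def get_incremental_state(module, incremental_state, key):
    # Return the cached value for this module/key, or None if absent.
    if incremental_state is None:
        return None
    return incremental_state.get(_full_key(module, key), None)

def set_incremental_state(module, incremental_state, key, value):
    # No-op outside incremental inference (i.e. when incremental_state is None).
    if incremental_state is not None:
        incremental_state[_full_key(module, key)] = value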
Example #2
 def _set_input_buffer(self,
                       incremental_state,
                       buffer,
                       incremental_clone_id=""):
     self.incremental_clone_ids.add(incremental_clone_id)
     utils.set_incremental_state(self, incremental_state,
                                 "attn_state" + incremental_clone_id,
                                 buffer)
Example #3
    def forward(self,
                x,
                incremental_state=None,
                encoder_lstm_states=None,
                **unused):

        residual = x
        x = self.maybe_layer_norm(x, before=True)
        seqlen, bsz, _ = x.size()

        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            num_layers = len(self.layers)
            if encoder_lstm_states is not None:
                encoder_hiddens, encoder_cells = encoder_lstm_states
                prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
                prev_cells = [encoder_cells[i] for i in range(num_layers)]
            else:
                state_size = bsz, self.hidden_dim
                prev_hiddens = [
                    x.new_zeros(*state_size) for _ in range(num_layers)
                ]
                prev_cells = [
                    x.new_zeros(*state_size) for _ in range(num_layers)
                ]
                # prev_hiddens = [x.new_zeros(*state_size) for i in range(num_layers)]
                # prev_cells = [x.new_zeros(*state_size) for i in range(num_layers)]
            input_feed = x.new_zeros(bsz, self.hidden_dim)

        outs = []
        for j in range(seqlen):
            input = torch.cat((x[j, :, :], input_feed), dim=1)

            for i, rnn in enumerate(self.layers):

                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
                # input = F.dropout(hidden, p=self.dropout, training=self.training)
                input = hidden
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            out = hidden
            input_feed = out
            outs.append(out)

        utils.set_incremental_state(self, incremental_state, 'cached_state',
                                    (prev_hiddens, prev_cells, input_feed))

        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_dim)
        x = self.linear(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = x + residual
        x = self.maybe_layer_norm(x, after=True)

        return x, None
Example #4
 def reorder_incremental_state(self, incremental_state, new_order):
     super().reorder_incremental_state(incremental_state, new_order)
     encoder_out = utils.get_incremental_state(self, incremental_state,
                                               'encoder_out')
     if encoder_out is not None:
         encoder_out = tuple(
             eo.index_select(0, new_order) for eo in encoder_out)
         utils.set_incremental_state(self, incremental_state, 'encoder_out',
                                     encoder_out)
Example #5
    def forward_unprojected(self,
                            input_tokens,
                            encoder_out,
                            incremental_state=None):
        padded_tokens = F.pad(
            input_tokens,
            (self.history_len - 1, 0, 0, 0),
            "constant",
            self.dst_dict.eos(),
        )
        # We use incremental_state only to check whether we are decoding or not
        # self.training is false even for the forward pass through validation
        if incremental_state is not None:
            padded_tokens = padded_tokens[:, -self.history_len:]
        utils.set_incremental_state(self, incremental_state,
                                    "incremental_marker", True)

        bsz, seqlen = padded_tokens.size()
        seqlen -= self.history_len - 1

        # get outputs from encoder
        (encoder_outs, final_hidden, _, src_lengths, _) = encoder_out

        # padded_tokens has shape [batch_size, seq_len+history_len]
        x = self.embed_tokens(padded_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # Convolution needs shape [batch_size, channels, seq_len]
        x = self.history_conv(x.transpose(1, 2)).transpose(1, 2)
        x = F.dropout(x, p=self.dropout_out, training=self.training)

        # x has shape [batch_size, seq_len, channels]
        for i, layer in enumerate(self.layers):
            prev_x = x
            x = layer(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)
            if self.residual_level is not None and i >= self.residual_level:
                x = x + prev_x

        # Attention
        attn_out, attn_scores = self.attention(
            x.transpose(0, 1).contiguous().view(-1, self.hidden_dim),
            encoder_outs.repeat(1, seqlen, 1),
            src_lengths.repeat(seqlen),
        )
        if attn_out is not None:
            attn_out = attn_out.view(seqlen, bsz, -1).transpose(1, 0)
        attn_scores = attn_scores.view(-1, seqlen, bsz).transpose(0, 2)
        x = maybe_cat((x, attn_out), dim=2)

        # bottleneck layer
        if hasattr(self, "additional_fc"):
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)
        return x, attn_scores
Example #6
    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order) if state is not None else None

        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
Example #7
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        encoder_padding_mask = encoder_out['encoder_padding_mask'].t()
        encoder_out = encoder_out['encoder_out']

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]

        bsz, seqlen = prev_output_tokens.size()
        srclen = encoder_out.size(0)
        
        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, _ = cached_state
        else:         
            prev_hiddens = [encoder_out.mean(dim=0) for i in range(self.num_layers)]
            prev_cells = [encoder_out.mean(dim=0) for i in range(self.num_layers)]

        # get outputs from encoder
        #encoder_out = self.layer_norm(encoder_out)
        x = torch.stack([encoder_out[:,i,:].index_select(0, prev_output_tokens[i])
                        for i in range(encoder_out.size(1))], dim=1)
        x = self.layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        #encoder_out = F.dropout(encoder_out, p=self.dropout, training=self.training)

        attn_scores = x.new_zeros(bsz, seqlen, srclen)
        for j in range(seqlen):
            input = x[j, :, :]

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden, p=self.dropout, training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # apply attention using the last layer's hidden state
            attn_scores[:, j, :] = self.attention(hidden, encoder_out, encoder_padding_mask)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self, incremental_state, 'cached_state',
            (prev_hiddens, prev_cells, None),
        )

        return attn_scores, None
Example #8
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        if incremental_state is not None:
            # If the *incremental_state* argument is not ``None`` then we are
            # in incremental inference mode. While *prev_output_tokens* will
            # still contain the entire decoded prefix, we will only use the
            # last step and assume that the rest of the state is cached.
            prev_output_tokens = prev_output_tokens[:, -1:]

        # This remains the same as before.
        bsz, tgt_len = prev_output_tokens.size()
        final_encoder_hidden = encoder_out['final_hidden']
        final_encoder_hidden = final_encoder_hidden[0:bsz, :]
        x = self.embed_tokens(prev_output_tokens)
        x = self.dropout(x)
        x = torch.cat(
            [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
            dim=2,
        )

        # We will now check the cache and load the cached previous hidden and
        # cell states, if they exist, otherwise we will initialize them to
        # zeros (as before). We will use the ``utils.get_incremental_state()``
        # and ``utils.set_incremental_state()`` helpers.
        initial_state = utils.get_incremental_state(
            self,
            incremental_state,
            'prev_state',
        )
        if initial_state is None:
            # first time initialization, same as the original version
            initial_state = (
                final_encoder_hidden.unsqueeze(0),  # hidden
                torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
            )

        # Run one step of our LSTM.
        output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

        # Update the cache with the latest hidden and cell states.
        utils.set_incremental_state(
            self,
            incremental_state,
            'prev_state',
            latest_state,
        )

        # This remains the same as before
        x = output.transpose(0, 1)
        x = self.output_projection(x)
        return x, None
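Example #8 above is the tutorial-style LSTM decoder: during generation the caller keeps one incremental_state dict alive across time steps, so each forward call only has to process the newest token while the hidden and cell states stay cached. A hypothetical greedy-decoding driver, sketched under that assumption (the names decoder, encoder_out, bos and eos are placeholders, not part of the example above):

import torch

def greedy_decode(decoder, encoder_out, bos, eos, max_len=50):
    # One shared cache for the whole generation run.
    incremental_state = {}
    tokens = torch.full((1, 1), bos, dtype=torch.long)
    for _ in range(max_len):
        # With incremental_state set, only the last token is actually consumed.
        logits, _ = decoder(tokens, encoder_out, incremental_state)
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_tok], dim=1)
        if next_tok.item() == eos:
            break
    return tokens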
Example #9
    def reorder_incremental_state(self, incremental_state, new_order):
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        if not isinstance(new_order, Variable):
            new_order = Variable(new_order)
        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
    def set_pointer(self, incremental_state, p_choose):
        curr_pointer = self.get_pointer(incremental_state)
        if len(curr_pointer) == 0:
            buffer = torch.zeros_like(p_choose)
        else:
            buffer = self.get_pointer(incremental_state)["step"]

        buffer += (p_choose < 0.5).type_as(buffer)

        utils.set_incremental_state(
            self,
            incremental_state,
            'monotonic',
            {"step": buffer},
        )
Example #11
    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   "cached_state")
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, "cached_state",
                                    new_state)
    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)

        cumsum_probs = utils.get_incremental_state(self, incremental_state,
                                                   'cumsum_probs')
        if cumsum_probs is not None:
            new_cumsum_probs = cumsum_probs.index_select(0, new_order)
            utils.set_incremental_state(self, incremental_state,
                                        'cumsum_probs', new_cumsum_probs)

        nodes = utils.get_incremental_state(self, incremental_state, 'nodes')
        if nodes is not None:
            new_nodes = nodes.index_select(0, new_order)
            utils.set_incremental_state(self, incremental_state, 'nodes',
                                        new_nodes)
Example #13
 def reorder_incremental_state(self, incremental_state, new_order):
     cached_state = utils.get_incremental_state(self, incremental_state,
                                                'cached_state')
     if cached_state is not None:
         prev_hiddens, prev_cells, input_feed = cached_state
         prev_hiddens = [
             hidden.index_select(0, new_order) for hidden in prev_hiddens
         ]
         prev_cells = [
             cell.index_select(0, new_order) for cell in prev_cells
         ]
         input_feed = input_feed.index_select(0, new_order)
         utils.set_incremental_state(self, incremental_state,
                                     'cached_state',
                                     (prev_hiddens, prev_cells, input_feed))
Example #14
    def reorder_incremental_state(self, incremental_state, new_order):
        # Load the cached state.
        prev_state = utils.get_incremental_state(
            self, incremental_state, 'prev_state',
        )

        # Reorder batches according to *new_order*.
        reordered_state = (
            prev_state[0].index_select(1, new_order),  # hidden
            prev_state[1].index_select(1, new_order),  # cell
        )

        # Update the cached state.
        utils.set_incremental_state(
            self, incremental_state, 'prev_state', reordered_state,
        )
Example #15
 def _set_input_buffer(
     self,
     incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
     new_buffer,
 ):
     return utils.set_incremental_state(self, incremental_state,
                                        "input_buffer", new_buffer)
    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)

        cumsum_probs = utils.get_incremental_state(self, incremental_state,
                                                   'cumsum_probs')
        if cumsum_probs is not None:
            new_cumsum_probs = cumsum_probs.index_select(0, new_order)
            utils.set_incremental_state(self, incremental_state,
                                        'cumsum_probs', new_cumsum_probs)

        nodes = utils.get_incremental_state(self, incremental_state, 'nodes')
        if nodes is not None:
            new_order_list = new_order.tolist()
            new_nodes = [nodes[i] for i in new_order_list]
            utils.set_incremental_state(self, incremental_state, 'nodes',
                                        new_nodes)
Example #17
    def reorder_incremental_state(self, incremental_state, new_order):
        # parent reorders attention model
        super().reorder_incremental_state(incremental_state, new_order)

        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   "cached_state")
        if cached_state is None:
            return

        # Last 2 elements of prev_states are encoder projections
        # used for ONNX export
        for i, state in enumerate(cached_state[:-2]):
            cached_state[i] = state.index_select(1, new_order)

        utils.set_incremental_state(self, incremental_state, "cached_state",
                                    cached_state)
Example #18
    def _split_encoder_out(self, encoder_out, incremental_state):
        """Split and transpose encoder outputs.

        This is cached when doing incremental inference.
        """
        cached_result = utils.get_incremental_state(self, incremental_state, 'encoder_out')
        if cached_result is not None:
            return cached_result

        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(1, 2).contiguous()
        result = (encoder_a, encoder_b)

        if incremental_state is not None:
            utils.set_incremental_state(self, incremental_state, 'encoder_out', result)
        return result
    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)

        for state_name in ['wordlm_logprobs', 'out_logprobs', 'subword_cumlogprobs']:
            state = utils.get_incremental_state(self, incremental_state, state_name)
            if state is not None:
                new_state = state.index_select(0, new_order)
                utils.set_incremental_state(
                    self, incremental_state, state_name, new_state,
                )

        nodes = utils.get_incremental_state(self, incremental_state, 'nodes')
        if nodes is not None:
            new_order_list = new_order.tolist()
            new_nodes = [nodes[i] for i in new_order_list]
            utils.set_incremental_state(
                self, incremental_state, 'nodes', new_nodes,
            )
Example #20
    def forward(self,
                prev_output_tokens,
                encoder_out=None,
                incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bbsz = prev_output_tokens.size(0)
        vocab = len(self.dictionary)
        src_len = encoder_out.encoder_out.size(1)
        tgt_len = prev_output_tokens.size(1)

        # determine number of steps
        if incremental_state is not None:
            # cache step number
            step = utils.get_incremental_state(self, incremental_state, "step")
            if step is None:
                step = 0
            utils.set_incremental_state(self, incremental_state, "step",
                                        step + 1)
            steps = [step]
        else:
            steps = list(range(tgt_len))

        # define output in terms of raw probs
        if hasattr(self.args, "probs"):
            assert (self.args.probs.dim() == 3
                    ), "expected probs to have size bsz*steps*vocab"
            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
        else:
            probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
            for i, step in enumerate(steps):
                # args.beam_probs gives the probability for every vocab element,
                # starting with eos, then unknown, and then the rest of the vocab
                if step < len(self.args.beam_probs):
                    probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
                else:
                    probs[:, i, self.dictionary.eos()] = 1.0

        # random attention
        attn = torch.rand(bbsz, tgt_len, src_len)

        dev = prev_output_tokens.device
        return probs.to(dev), {"attn": [attn.to(dev)]}
Example #21
    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')
        if cached_state is None:
            return

        #EDITED
        def reorder_state(state, idx):
            if isinstance(state, list) or isinstance(state, tuple):
                return [reorder_state(state_i, idx) for state_i in state]
            return state.index_select(idx, new_order)

        new_state = [
            reorder_state(sub, idx)
            for (sub, idx) in zip(cached_state, [0, 0, 0, 1])
        ]
        utils.set_incremental_state(self, incremental_state, 'cached_state',
                                    new_state)
Example #22
    def reorder_incremental_state(self, incremental_state, new_order):
        """
        The ``FairseqIncrementalDecoder`` interface also requires implementing a
        ``reorder_incremental_state()`` method, which is used during beam search
        to select and reorder the incremental state.
        """
        # Load the cached state.
        prev_state = utils.get_incremental_state(
            self, incremental_state, 'prev_state',
        )

        # Reorder batches according to *new_order*.
        reordered_state = (
            prev_state[0].index_select(1, new_order),  # hidden
            prev_state[1].index_select(1, new_order),  # cell
        )

        # Update the cached state.
        utils.set_incremental_state(
            self, incremental_state, 'prev_state', reordered_state,
        )
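As the docstring above notes, beam search may drop or duplicate hypotheses between steps and then hands the decoder a tensor of surviving beam indices (new_order). A small self-contained illustration of what index_select does to a cached (hidden, cell) pair whose beam dimension is dim 1, as in this example; the tensors below are fabricated purely for demonstration:

import torch

# Fabricated cache: (num_layers=1, beams=3, hidden_dim=2)
hidden = torch.arange(6.0).view(1, 3, 2)
cell = torch.zeros(1, 3, 2)

# Beam 2 survived twice, beam 0 once, beam 1 was dropped.
new_order = torch.tensor([2, 0, 0])

reordered = (
    hidden.index_select(1, new_order),  # gather along the beam dimension
    cell.index_select(1, new_order),
)
print(reordered[0][0])  # rows now correspond to old beams 2, 0, 0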
Example #23
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bbsz = prev_output_tokens.size(0)
        vocab = len(self.dictionary)
        src_len = encoder_out.size(1)
        tgt_len = prev_output_tokens.size(1)

        # determine number of steps
        if incremental_state is not None:
            # cache step number
            step = utils.get_incremental_state(self, incremental_state, 'step')
            if step is None:
                step = 0
            utils.set_incremental_state(self, incremental_state, 'step', step + 1)
            steps = [step]
        else:
            steps = list(range(tgt_len))

        # define output in terms of raw probs
        if hasattr(self.args, 'probs'):
            assert self.args.probs.dim() == 3, \
                'expected probs to have size bsz*steps*vocab'
            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
        else:
            probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
            for i, step in enumerate(steps):
                # args.beam_probs gives the probability for every vocab element,
                # starting with eos, then unknown, and then the rest of the vocab
                if step < len(self.args.beam_probs):
                    probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
                else:
                    probs[:, i, self.dictionary.eos()] = 1.0

        # random attention
        attn = torch.rand(bbsz, tgt_len, src_len)

        return probs, attn
    def reorder_incremental_state(self, incremental_state, new_order):

        def apply_reorder_incremental_state(module):
            if module != self and hasattr(module, 'reorder_incremental_state'):
                module.reorder_incremental_state(
                    incremental_state,
                    new_order,
                )
        self.apply(apply_reorder_incremental_state)

        # document decoder
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state_rnn')
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        new_state = tuple(map(reorder_state, cached_state))

        utils.set_incremental_state(self, incremental_state, 'cached_state_rnn', new_state)
Example #25
    def masked_copy_incremental_state(self, incremental_state, another_state, mask):
        state = utils.get_incremental_state(self, incremental_state, 'encoder_out')
        if state is None:
            assert another_state is None
            return

        def mask_copy_state(state, another_state):
            if isinstance(state, list):
                assert isinstance(another_state, list) and len(state) == len(another_state)
                return [mask_copy_state(state_i, another_state_i) \
                    for state_i, another_state_i in zip(state, another_state)]
            if state is not None:
                assert state.size(0) == mask.size(0) and another_state is not None and \
                    state.size() == another_state.size()
                mask_unsqueezed = mask
                for _ in range(1, len(state.size())):
                    mask_unsqueezed = mask_unsqueezed.unsqueeze(-1)
                return torch.where(mask_unsqueezed, state, another_state)
            else:
                assert another_state is None
                return None

        new_state = tuple(map(mask_copy_state, state, another_state))
        utils.set_incremental_state(self, incremental_state, 'encoder_out', new_state)
Example #26
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, _, _ = encoder_out[:3]
        srclen = encoder_outs.size(0)

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            _, encoder_hiddens, encoder_cells = encoder_out[:3]
            num_layers = len(self.layers)
            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
            prev_cells = [encoder_cells[i] for i in range(num_layers)]
            input_feed = Variable(
                x.data.new(bsz, self.encoder_output_units).zero_())

        attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            input = torch.cat((x[j, :, :], input_feed), dim=1)

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden,
                                  p=self.dropout_out,
                                  training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # apply attention using the last layer's hidden state
            if self.attention is not None:
                out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
            else:
                out = hidden
            out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            input_feed = out

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(self, incremental_state, 'cached_state',
                                    (prev_hiddens, prev_cells, input_feed))

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        attn_scores = attn_scores.transpose(0, 2)

        # project back to size of vocabulary
        if hasattr(self, 'additional_fc'):
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)
        x = self.fc_out(x)

        return x, attn_scores
Example #27
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        padded_tokens = F.pad(
            prev_output_tokens,
            (self.history_len - 1, 0, 0, 0),
            "constant",
            self.dst_dict.eos(),
        )
        # We use incremental_state only to check whether we are decoding or not
        # self.training is false even for the forward pass through validation
        if incremental_state is not None:
            padded_tokens = padded_tokens[:, -self.history_len - 1:]
        utils.set_incremental_state(self, incremental_state,
                                    "incremental_marker", True)

        bsz, seqlen = padded_tokens.size()
        seqlen -= self.history_len - 1

        # get outputs from encoder
        (encoder_outs, final_hidden, _, src_lengths, _) = encoder_out

        # padded_tokens has shape [batch_size, seq_len+history_len]
        x = self.embed_tokens(padded_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # Convolution needs shape [batch_size, channels, seq_len]
        x = self.history_conv(x.transpose(1, 2)).transpose(1, 2)
        x = F.dropout(x, p=self.dropout_out, training=self.training)

        # x has shape [batch_size, seq_len, channels]
        for i, layer in enumerate(self.layers):
            prev_x = x
            x = layer(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)
            if self.residual_level is not None and i >= self.residual_level:
                x = x + prev_x

        # Attention
        attn_out, attn_scores = self.attention(
            x.transpose(0, 1).contiguous().view(-1, self.hidden_dim),
            encoder_outs.repeat(1, seqlen, 1),
            src_lengths.repeat(seqlen),
        )
        attn_out = attn_out.view(seqlen, bsz, -1).transpose(1, 0)
        attn_scores = attn_scores.view(-1, seqlen, bsz).transpose(0, 2)
        x = torch.cat((x, attn_out), dim=2)

        # bottleneck layer
        if hasattr(self, "additional_fc"):
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)

        output_projection_w = self.output_projection_w
        output_projection_b = self.output_projection_b

        # avoiding transpose of projection weights during ONNX tracing
        batch_time_hidden = torch.onnx.operators.shape_as_tensor(x)
        x_flat_shape = torch.cat(
            (torch.LongTensor([-1]), batch_time_hidden[2].view(1)))
        x_flat = torch.onnx.operators.reshape_from_tensor_shape(
            x, x_flat_shape)
        projection_flat = torch.matmul(output_projection_w, x_flat.t()).t()
        logits_shape = torch.cat(
            (batch_time_hidden[:2], torch.LongTensor([-1])))
        logits = (torch.onnx.operators.reshape_from_tensor_shape(
            projection_flat, logits_shape) + output_projection_b)
        return logits, attn_scores, None
 def setCached(self, incremental_state, key, value):
     utils.set_incremental_state(self, incremental_state, key, value)
    def forward(self, decoder_in_dict, encoder_out_dict,
                incremental_state=None, incr_doc_step=False, batch_idxs=None, new_incr_cached=None):

        prev_output_tokens = decoder_in_dict['prev_output_tokens']

        encoder_padding_mask = encoder_out_dict['encoder_padding_mask'] # [b x w]
        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.transpose(0, 1)
        srcbsz, srcdoclen, srcdim = encoder_out_dict['encoder_out'][0].size()  # [b x w x d]

        # summarise whole input for h0 decoder use, verbose but clearer
        src_summary_h0 =  encoder_out_dict['encoder_out'][0].mean(1) # [b x d]

        bsz, doclen, sentlen = prev_output_tokens.size() # these sizes are target ones

        start_doc = 0
        if incremental_state is not None:
            doclen = 1

        # get initial input embedding for document RNN decoder
        x = prev_output_tokens.data.new(bsz).fill_(self.sod_idx)
        x = self.decoder.embed_tokens(x)

        ## Decode sentence states ##

        # initialize previous states (or get from cache during incremental generation)
        cached_state_rnn = utils.get_incremental_state(self, incremental_state, 'cached_state_rnn')
        if incr_doc_step and cached_state_rnn is not None:
            # doing the first step of the ith (i>1) sentence in incremental generation
            prev_hiddens, prev_cells, input = cached_state_rnn
            outs = [input]

        elif incremental_state is not None \
                and new_incr_cached is not None:  # doing subsequent steps of a sentence in incremental generation
            bidxs, old_bsz, reorder_state = batch_idxs
            if reorder_state is not None:  # needed when some hypotheses have finished during generation
                # reducing decoding to lower nb of hypotheses
                new_incr_cached = new_incr_cached.index_select(0, reorder_state)
                outs = [new_incr_cached]
            else:
                outs = [new_incr_cached]
        else:
            # first state of the first sentence in incremental generation, or
            # first time coming here to generate the whole sentence in training/scoring;
            # the previous state is h0 with the encoder output summary
            outs = []
            encoder_hiddens_cells =  src_summary_h0 # [b x d]
            prev_hiddens = []
            prev_cells = []
            for i in range(len(self.layers)):
                prev_hiddens.append(encoder_hiddens_cells)
                prev_cells.append(encoder_hiddens_cells)
            input = x

        # attn of document decoder over input aggregated units (e.g. encoded sequence of paragraphs)
        attn_scores = x.data.new(srcdoclen, doclen, bsz).zero_()

        if (incremental_state is not None and incr_doc_step) \
                or incremental_state is None:
            for j in range(start_doc, doclen):
                for i, rnn in enumerate(self.layers):
                    # recurrent cell
                    hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                    # hidden state becomes the input to the next layer
                    input = hidden

                    # save state for next time step
                    prev_hiddens[i] = hidden
                    prev_cells[i] = cell

                    # apply attention using the last layer's hidden state (sentence vector)
                if self.wordAttention is not None:

                    # inputs to attention are of the form
                    # input: bsz x input_embed_dim
                    # source_hids: srclen x bsz x output_embed_dim
                    # either attend to the input representation by the cnn encoder or to its combination with the input embeddings
                    if self.hidemb:
                        attn_h, attn_scores_out = self.wordAttention(hidden, \
                                                            encoder_out_dict['encoder_out'][1].transpose(0, 1),\
                                                            encoder_padding_mask)
                    else:
                        attn_h, attn_scores_out = self.wordAttention(hidden, \
                                                                     encoder_out_dict['encoder_out'][0].transpose(0, 1), \
                                                                     encoder_padding_mask)

                    out = attn_h # [b x d]
                else:
                    out = hidden

                # input to next time step
                input = out
                new_incr_cached = out.clone()

                # save final output
                if incremental_state is not None:
                    outs.append(out)
                else:
                    outs.append(out.unsqueeze(1))

        ## Decode sentences ##

        # When training/validation, make all sentence s_t decoding steps in parallel here
        sent_states = None
        if incremental_state is not None:
            dec_outs = x.data.new(bsz, doclen, sentlen, len(self.decoder.dictionary)).zero_()
            # decode by sentence s_j
            for j in range(doclen):

                sp = self.embed_sent_positions(decoder_in_dict['sentence_position'])

                dec_out, word_atte_scores = self.decoder(
                    prev_output_tokens[:, j, :], outs[j], sp, encoder_out_dict, incremental_state,
                    firstfeed=self.firstfeed, normpos=self.normpos)
                # prev_output_tokens is [ b x s x w ], at each time step decode sentence j [b x w]
                # dec_out is [b x w x vocabulary]
                if j == 0:
                    dec_outs = dec_out
                else:
                    dec_outs = torch.cat((dec_outs, dec_out), 1)
                    # dec_outs is [bxs x w x  vocabulary], dim=0
                    # dec_outs is [b x s*w x  vocabulary], dim=1

        else:
            # decode everything in parallel
            sent_states = torch.cat(outs, dim=1).view(bsz*doclen, -1)
            ys = prev_output_tokens.view(bsz*doclen, -1)
            sp = make_sent_positions(prev_output_tokens, self.padding_idx).view(bsz*doclen)
            sp = self.embed_sent_positions(sp)

            # Replicate encoder_out_dict for the new nb of batches to do all in parallel

            ebsz, esrclen, edim = encoder_out_dict['encoder_out'][0].size()
            new_enc_out_dict = {}
            #repeat input for each target
            new_enc_out_dict['encoder_out'] = (
                encoder_out_dict['encoder_out'][0].view(ebsz, 1, esrclen, edim).expand(ebsz, doclen, esrclen, edim)
                                        .contiguous().view(ebsz*doclen, esrclen, edim),
                encoder_out_dict['encoder_out'][1].view(ebsz, 1, esrclen, edim).expand(ebsz, doclen, esrclen, edim)
                                        .contiguous().view(ebsz*doclen, esrclen, edim)
            )

            new_enc_out_dict['encoder_padding_mask'] = None
            if encoder_out_dict['encoder_padding_mask'] is not None:
                new_enc_out_dict['encoder_padding_mask'] = encoder_out_dict['encoder_padding_mask']\
                                                    .view(ebsz, 1, esrclen).expand(ebsz, doclen, esrclen)\
                                                    .contiguous().view(ebsz*doclen, -1)

            #decode all target sentences of all documents in parallel
            dec_out, word_atte_scores = self.decoder(ys, sent_states, sp, new_enc_out_dict,
                                                     firstfeed=self.firstfeed, normpos=self.normpos)
            dec_outs = dec_out.view(bsz, doclen*sentlen, len(self.decoder.dictionary))

        if incremental_state is not None and incr_doc_step:
            # only if we moved to the next document sentence
            # cache previous states (no-op except during incremental generation)
            utils.set_incremental_state(
                self, incremental_state, 'cached_state_rnn', (prev_hiddens, prev_cells, out))

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        attn_scores = attn_scores.transpose(0, 2)

        return dec_outs, (attn_scores, word_atte_scores), new_incr_cached, None # topic label predictions are None here
Example #30
    def extract_features(self,
                         prev_output_tokens,
                         encoder_out=None,
                         incremental_state=None,
                         **unused):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - attention weights of shape `(batch, tgt_len, src_len)`
        """
        if self.attention is not None:
            assert encoder_out is not None
            encoder_padding_mask = encoder_out['encoder_padding_mask']
            encoder_out = encoder_out['encoder_out']
            # get outputs from encoder
            encoder_outs = encoder_out[0]
            srclen = encoder_outs.size(0)

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            num_layers = len(self.layers)
            prev_hiddens = [x.new_zeros(bsz, self.hidden_size) \
                for i in range(num_layers)]
            prev_cells = [x.new_zeros(bsz, self.hidden_size) \
                for i in range(num_layers)]
            input_feed = x.new_zeros(bsz, self.encoder_output_units) \
                if self.attention is not None else None

        if self.attention is not None:
            attn_scores = x.new_zeros(srclen, seqlen, bsz)
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            input = torch.cat((x[j, :, :], input_feed), dim=1) \
                if input_feed is not None else x[j, :, :]

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
                if self.residual and i > 0:  # residual connection starts from the 2nd layer
                    prev_layer_hidden = input[:, :hidden.size(1)]

                # compute and apply attention using the 1st layer's hidden state
                if self.attention is not None:
                    if i == 0:
                        context, attn_scores[:, j, :], _ = self.attention(
                            hidden, encoder_outs, encoder_padding_mask)

                    # hidden state concatenated with context vector becomes the
                    # input to the next layer
                    input = torch.cat((hidden, context), dim=1)
                else:
                    input = hidden
                input = F.dropout(input,
                                  p=self.dropout_out,
                                  training=self.training)
                if self.residual and i > 0:
                    if self.attention is not None:
                        hidden_sum = input[:, :hidden.size(1)] + prev_layer_hidden
                        input = torch.cat(
                            (hidden_sum, input[:, hidden.size(1):]), dim=1)
                    else:
                        input = input + prev_layer_hidden

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # input feeding
            input_feed = context if self.attention is not None else None

            # save final output
            outs.append(input)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self,
            incremental_state,
            'cached_state',
            (prev_hiddens, prev_cells, input_feed),
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, -1)
        assert x.size(2) == self.hidden_size + self.encoder_output_units

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        if not self.training and self.attention is not None and self.need_attn:
            attn_scores = attn_scores.transpose(0, 2)
        else:
            attn_scores = None

        return x, attn_scores
Example #31
 def _set_input_buffer(self, incremental_state, new_buffer):
     return utils.set_incremental_state(self, incremental_state,
                                        'input_buffer', new_buffer)
Example #32
    def _forward_given_embeddings(
        self,
        embed_out,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
        possible_translation_tokens=None,
        timestep=None,
    ):
        x = embed_out
        (encoder_x, src_tokens,
         encoder_padding_mask) = self._unpack_encoder_out(encoder_out)
        bsz, seqlen = prev_output_tokens.size()

        state_outputs = []
        if incremental_state is not None:
            prev_states = utils.get_incremental_state(self, incremental_state,
                                                      "cached_state")
            if prev_states is None:
                prev_states = self._init_prev_states(encoder_out)

            # final 2 states of list are projected key and value
            saved_state = {
                "prev_key": prev_states[-2],
                "prev_value": prev_states[-1]
            }
            self.attention._set_input_buffer(incremental_state, saved_state)

        if incremental_state is not None:
            # first num_layers pairs of states are (prev_hidden, prev_cell)
            # for each layer
            h_prev = prev_states[0]
            c_prev = prev_states[1]
        else:
            h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
            c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

        x = self._concat_latent_code(x, encoder_out)
        x, (h_next, c_next) = self.initial_rnn_layer(x, (h_prev, c_prev))
        if incremental_state is not None:
            state_outputs.extend([h_next, c_next])

        x = F.dropout(x, p=self.dropout, training=self.training)

        if self.proj_encoder_layer is not None:
            encoder_x = self.proj_encoder_layer(encoder_x)

        attention_in = x
        if self.proj_layer is not None:
            attention_in = self.proj_layer(x)

        attention_out, attention_weights = self.attention(
            query=attention_in,
            key=encoder_x,
            value=encoder_x,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=(not self.training),
        )

        for i, layer in enumerate(self.extra_rnn_layers):
            residual = x
            rnn_input = torch.cat([x, attention_out], dim=2)
            rnn_input = self._concat_latent_code(rnn_input, encoder_out)

            if incremental_state is not None:
                # first num_layers pairs of states are (prev_hidden, prev_cell)
                # for each layer
                h_prev = prev_states[2 * i + 2]
                c_prev = prev_states[2 * i + 3]
            else:
                h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
                c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

            x, (h_next, c_next) = layer(rnn_input, (h_prev, c_prev))
            if incremental_state is not None:
                state_outputs.extend([h_next, c_next])
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = x + residual

        x = torch.cat([x, attention_out], dim=2)
        x = self._concat_latent_code(x, encoder_out)
        if self.bottleneck_layer is not None:
            x = self.bottleneck_layer(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if (self.vocab_reduction_module is not None
                and possible_translation_tokens is None):
            decoder_input_tokens = prev_output_tokens.contiguous()
            possible_translation_tokens = self.vocab_reduction_module(
                src_tokens, decoder_input_tokens=decoder_input_tokens)

        output_weights = self.embed_out
        if possible_translation_tokens is not None:
            output_weights = output_weights.index_select(
                dim=0, index=possible_translation_tokens)

        logits = F.linear(x, output_weights)

        if incremental_state is not None:
            # encoder projections can be reused at each incremental step
            state_outputs.extend([prev_states[-2], prev_states[-1]])
            utils.set_incremental_state(self, incremental_state,
                                        "cached_state", state_outputs)

        return logits, attention_weights, possible_translation_tokens
Example #33
 def _set_input_buffer(self, incremental_state, new_buffer):
     return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)