Example No. 1
    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is None:
            return

        def reorder_state(state):
            # recurse into (possibly nested) lists of per-layer states
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            elif state is not None:
                # select rows along the batch dimension to follow the reordered beams
                return state.index_select(0, new_order)
            else:
                return None

        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
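During beam search, `new_order` tells each module how the beams were re-ranked at the last step, and every cached tensor must be reordered along its batch dimension to follow them. A minimal sketch of the underlying `index_select` pattern (shapes and values here are illustrative, not taken from the example above):

    import torch

    # cached state for 2 sentences x 2 beams = 4 rows along the batch dimension
    hidden = torch.arange(4.0).unsqueeze(1).repeat(1, 3)  # 4 x 3

    # beam search kept row 1 for both beams of sentence 0,
    # and row 2 for both beams of sentence 1
    new_order = torch.tensor([1, 1, 2, 2])

    reordered = hidden.index_select(0, new_order)  # rows 1, 1, 2, 2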
Example No. 2
    def extract_features(
        self, prev_output_tokens, encoder_out, incremental_state=None
    ):
        """
        Similar to *forward* but only returns features.
        """
        encoder_padding_masks, encoder_outs = {}, {}
        for modality in self.src_modalities:
            if encoder_out[modality] is not None:
                encoder_padding_masks[modality] = encoder_out[modality]['encoder_padding_mask']
                encoder_outs[modality] = encoder_out[modality]['encoder_out']
            else:
                encoder_padding_masks[modality] = None
                encoder_outs[modality] = None

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outputs, encoder_hiddens, encoder_cells, srclen = {}, {}, {}, {}
        for modality in self.src_modalities:
            if encoder_outs[modality] is not None:
                encoder_outputs[modality], encoder_hiddens[modality], encoder_cells[modality] = encoder_outs[modality][
                                                                                                :3]
                srclen[modality] = encoder_outputs[modality].size(0)
                flag = True
            else:
                srclen[modality] = None

        if flag:  # TODO
            # concatenate the available modalities' hidden/cell states
            encoder_hiddens = torch.cat(list(encoder_hiddens.values()))
            encoder_cells = torch.cat(list(encoder_cells.values()))

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        elif flag:
            # setup recurrent cells
            num_layers = len(self.layers)
            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
            prev_cells = [encoder_cells[i] for i in range(num_layers)]
            if self.encoder_hidden_proj is not None:
                prev_hiddens = [self.encoder_hidden_proj(x) for x in prev_hiddens]
                prev_cells = [self.encoder_cell_proj(x) for x in prev_cells]
            input_feed = x.new_zeros(bsz, self.hidden_size)
        else:
            # setup zero cells, since there is no encoder
            num_layers = len(self.layers)
            zero_state = x.new_zeros(bsz, self.hidden_size)
            prev_hiddens = [zero_state for i in range(num_layers)]
            prev_cells = [zero_state for i in range(num_layers)]
            input_feed = None

        assert any(v is not None for v in srclen.values()) or self.attention is None, \
            "attention is not supported if there are no encoder outputs"
        # per-modality attention is not implemented yet (see the TODO below),
        # so no attention scores are accumulated
        attn_scores = None
        outs = []
        for j in range(seqlen):
            # input feeding: concatenate context vector from previous time step
            if input_feed is not None:
                input = torch.cat((x[j, :, :], input_feed), dim=1)
            else:
                input = x[j]

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden, p=self.dropout_out, training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # apply attention using the last layer's hidden state
            if self.attention is not None:
                # TODO: attend over each modality's encoder outputs, e.g.
                # out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs, encoder_padding_masks)
                out = hidden  # fall back until multimodal attention is implemented
            else:
                out = hidden
            out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            if input_feed is not None:
                input_feed = out

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self, incremental_state, 'cached_state',
            (prev_hiddens, prev_cells, input_feed),
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        if hasattr(self, 'additional_fc') and self.adaptive_softmax is None:
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        if not self.training and self.need_attn and attn_scores is not None:
            attn_scores = attn_scores.transpose(0, 2)
        else:
            attn_scores = None
        return x, attn_scores
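The heart of `extract_features` is Luong-style input feeding: at every time step the previous step's output vector is concatenated with the current token embedding before it enters the first LSTM cell. A self-contained sketch of that pattern (layer count and sizes are illustrative):

    import torch
    import torch.nn as nn

    embed_dim, hidden_size, bsz, seqlen = 8, 16, 2, 5
    rnn = nn.LSTMCell(embed_dim + hidden_size, hidden_size)

    x = torch.randn(seqlen, bsz, embed_dim)      # T x B x C token embeddings
    hidden = torch.zeros(bsz, hidden_size)
    cell = torch.zeros(bsz, hidden_size)
    input_feed = torch.zeros(bsz, hidden_size)   # previous step's output vector

    outs = []
    for j in range(seqlen):
        # input feeding: concatenate the fed-back output with the embedding
        inp = torch.cat((x[j], input_feed), dim=1)
        hidden, cell = rnn(inp, (hidden, cell))
        input_feed = hidden                      # fed back at the next step
        outs.append(hidden)

    out = torch.stack(outs, dim=0)               # T x B x hidden_size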
Example No. 3
    def generate(self, models, sample, **kwargs):
        """Generate a batch of translations.

        Args:
            models (List[~fairseq.models.NccModel]): ensemble of models
            sample (dict): batch
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens
            bos_token (int, optional): beginning of sentence token
                (default: self.eos)
        """
        model = models[0]  # TODO: ensemble support; only the first model is used for now

        if not self.retain_dropout:
            model.eval()

        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v
            for k, v in sample['net_input'].items()
            if k != 'prev_output_tokens'
        }

        src_tokens = encoder_input['src_tokens']
        # batch dimension goes first followed by source lengths
        bsz = src_tokens[0].size(0)

        max_len = self.max_len_b
        assert self.min_len <= max_len, 'min_len cannot be larger than max_len, please adjust these!'

        # 1. encoder
        encoder_out = model.encoder(
            sample['net_input']['src_tokens'],
            src_lengths=sample['net_input']['src_lengths'],
            **kwargs)
        encoder_padding_mask = encoder_out['encoder_padding_mask']
        encoder_out = encoder_out['encoder_out']

        device = encoder_out[0].device

        prev_output_tokens = torch.zeros(bsz, 1).long().fill_(self.eos).to(
            device)  # <eos>
        incremental_state = None

        # get outputs from encoder
        if encoder_out is not None:
            encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]
            srclen = encoder_outs.size(1)
            encoder_outs = encoder_outs.transpose(0, 1)
        else:
            srclen = None

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(model.decoder,
                                                   incremental_state,
                                                   'cached_state')
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        elif encoder_out is not None:
            # setup recurrent cells
            prev_hiddens = [encoder_hiddens]
            prev_cells = [encoder_cells]
            # equivalent to torch.zeros(bsz, model.decoder.hidden_size).to(device).type_as(encoder_out[0]),
            # but the .new(...).fill_(0) form below also supports float16:
            # input_feed = encoder_out[0].new(bsz, model.decoder.hidden_size).fill_(0)
            input_feed = None
        else:
            # setup zero cells, since there is no encoder
            num_layers = len(model.decoder.layers)
            # .new(...).fill_(0) rather than torch.zeros(...) so that float16 is supported; equivalent to:
            # zero_state = torch.zeros(bsz, model.decoder.hidden_size).to(device).type_as(encoder_out[0])
            zero_state = encoder_out[0].new(bsz,
                                            model.decoder.hidden_size).fill_(0)
            prev_hiddens = [zero_state for i in range(num_layers)]
            prev_cells = [zero_state for i in range(num_layers)]
            input_feed = None

        assert srclen is not None or model.decoder.attention is None, \
            "attention is not supported if there are no encoder outputs"
        # attn_scores = torch.zeros(srclen, max_len, bsz).to(device) if model.decoder.attention is not None else None
        attn_scores = encoder_out[0].new(
            srclen, max_len,
            bsz).fill_(0) if model.decoder.attention is not None else None
        dec_preds = []

        # 2. generate
        for j in range(max_len):
            # embed tokens
            prev_output_tokens_emb = model.decoder.embed_tokens(
                prev_output_tokens)
            # T == 1 during generation, so drop the time dimension: B x 1 x C -> B x C
            prev_output_tokens_emb = prev_output_tokens_emb.squeeze(1)
            # input feeding: concatenate context vector from previous time step
            if input_feed is not None:
                input = torch.cat((prev_output_tokens_emb, input_feed), dim=1)
            else:
                input = prev_output_tokens_emb

            for i, rnn in enumerate(model.decoder.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            # apply attention using the last layer's hidden state
            if model.decoder.attention is not None:
                out, attn_scores[:, j, :] = model.decoder.attention(
                    hidden, encoder_outs, encoder_padding_mask)
            else:
                out = hidden
            # input feeding
            if input_feed is not None:
                input_feed = out

            decoded = model.decoder.output_layer(out)  # (batch_size, comment_dict_size)
            logprobs = F.log_softmax(decoded, dim=-1)  # (batch_size, comment_dict_size)
            prob_prev = torch.exp(logprobs)  # (batch_size, comment_dict_size)

            sample_max = True  # greedy decoding; set to False for multinomial sampling
            if sample_max:
                # greedy: pick the most probable token at each step
                sample_logprobs, predicted = torch.max(prob_prev, 1)
            else:
                # sample a token from the output distribution
                predicted = torch.multinomial(prob_prev, 1).squeeze(1)
            dec_preds.append(predicted.clone())

            # cache previous states (no-op except during incremental generation)
            utils.set_incremental_state(
                model.decoder,  # must match the module used in get_incremental_state above
                incremental_state,
                'cached_state',
                (prev_hiddens, prev_cells, input_feed),
            )
            prev_output_tokens = predicted.reshape(-1, 1)

        dec_preds = torch.stack(dec_preds, dim=1)

        predictions = []
        for pred in dec_preds.tolist():
            predictions.append([{
                'tokens': torch.Tensor(pred).type_as(dec_preds)
            }])

        return predictions
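The `sample_max` switch in the loop above chooses between greedy decoding and multinomial sampling over the same softmax distribution. The same choice in isolation (batch and vocabulary sizes are illustrative):

    import torch
    import torch.nn.functional as F

    bsz, vocab_size = 3, 10
    decoded = torch.randn(bsz, vocab_size)            # raw decoder logits
    prob = torch.exp(F.log_softmax(decoded, dim=-1))  # per-token probabilities

    # greedy: take the single most probable token per batch element
    max_probs, greedy_tokens = torch.max(prob, 1)            # (bsz,)

    # sampling: draw one token per batch element from the distribution
    sampled_tokens = torch.multinomial(prob, 1).squeeze(1)   # (bsz,)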