Example #1
    def forward(self, src, src_length):
        encoder_output, encoder_final_state = self.encoder(src, src_length)

        encoder_final_state = [(encoder_final_state[0][i],
                                encoder_final_state[1][i])
                               for i in range(self.num_layers)]

        # Initial states for the decoder
        decoder_initial_states = [
            encoder_final_state,
            self.decoder.lstm_attention.cell.get_initial_states(
                batch_ref=encoder_output, shape=[self.hidden_size])
        ]
        # Build an attention mask so the decoder does not attend to padding positions
        src_mask = (src != self.eos_id).astype(paddle.get_default_dtype())

        encoder_padding_mask = (src_mask - 1.0) * self.INF
        encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])

        # Tile the batch dimension by beam_size for beam search
        encoder_output = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
            encoder_output, self.beam_size)
        encoder_padding_mask = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
            encoder_padding_mask, self.beam_size)

        # Dynamic decoding with beam search
        seq_output, _ = nn.dynamic_decode(
            decoder=self.beam_search_decoder,
            inits=decoder_initial_states,
            max_step_num=self.max_out_len,
            encoder_output=encoder_output,
            encoder_padding_mask=encoder_padding_mask)
        return seq_output
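
Example #1 never shows how self.beam_search_decoder is built, so a construction
sketch may help. The pattern below mirrors the nn.BeamSearchDecoder call in
Example #2; attribute names such as self.bos_id and self.decoder.output_fc are
assumptions for illustration, not taken from the original snippet.

    # Likely set up in __init__ (a sketch; bos_id and output_fc are assumed names)
    self.beam_search_decoder = nn.BeamSearchDecoder(
        cell=self.decoder.lstm_attention.cell,
        start_token=self.bos_id,
        end_token=self.eos_id,
        beam_size=self.beam_size,
        embedding_fn=self.decoder.trg_embedder,
        output_fn=self.decoder.output_fc)

Note that the extra keyword arguments passed to nn.dynamic_decode above
(encoder_output, encoder_padding_mask) are forwarded to the decoder cell at
every step, which is how the attention inside lstm_attention sees them.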
Example #2
    def forward(self, trg):
        # Sample the latent code z from a standard normal prior
        # (generation path, so no encoder pass is needed)
        latent_z = paddle.normal(shape=(trg.shape[0], self.latent_size))
        dec_first_hidden_cell = self.fc(latent_z)
        dec_first_hidden, dec_first_cell = paddle.split(
            dec_first_hidden_cell, 2, axis=-1)
        if self.num_layers > 1:
            dec_first_hidden = paddle.split(dec_first_hidden, self.num_layers)
            dec_first_cell = paddle.split(dec_first_cell, self.num_layers)
        else:
            dec_first_hidden = [dec_first_hidden]
            dec_first_cell = [dec_first_cell]
        dec_initial_states = [[h, c]
                              for h, c in zip(dec_first_hidden, dec_first_cell)]

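        # Per step: softmax over the vocab, sample one token id, one-hot encode it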
        output_fc = lambda x: F.one_hot(
            paddle.multinomial(
                F.softmax(paddle.squeeze(self.decoder.output_fc(x), [1]))),
            num_classes=self.vocab_size)

        latent_z = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
            latent_z, self.beam_size)

        decoder = nn.BeamSearchDecoder(
            cell=self.decoder.lstm.cell,
            start_token=self.start_token,
            end_token=self.end_token,
            beam_size=self.beam_size,
            embedding_fn=self.decoder.trg_embedder,
            output_fn=output_fc)

        outputs, _ = nn.dynamic_decode(
            decoder,
            inits=dec_initial_states,
            max_step_num=self.max_out_len,
            latent_z=latent_z)
        return outputs
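
For reference, the same BeamSearchDecoder/dynamic_decode pattern also works
outside a model class. A minimal self-contained sketch, assuming the Paddle 2.x
API, with made-up sizes and token ids (0 as <bos> and 1 as <eos> are
assumptions):

    import paddle
    import paddle.nn as nn

    vocab_size, hidden_size, beam_size = 100, 32, 4

    cell = nn.LSTMCell(hidden_size, hidden_size)
    decoder = nn.BeamSearchDecoder(
        cell=cell,
        start_token=0,   # assumed <bos> id
        end_token=1,     # assumed <eos> id
        beam_size=beam_size,
        embedding_fn=nn.Embedding(vocab_size, hidden_size),
        output_fn=nn.Linear(hidden_size, vocab_size))

    # Zero (h, c) states for a batch of 2; the decoder tiles them by beam_size.
    init_states = cell.get_initial_states(
        batch_ref=paddle.zeros([2, hidden_size]))

    outputs, final_states = nn.dynamic_decode(
        decoder, inits=init_states, max_step_num=10)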
Example #3
from paddle.nn import dynamic_decode


def model_forward(model, init_hidden, init_cell):
    # Decode from the given initial states; dynamic_decode returns
    # (outputs, final_states), so [0] keeps only the decoded outputs.
    return dynamic_decode(model.beam_search_decoder,
                          [init_hidden, init_cell],
                          max_step_num=model.max_step_num,
                          impute_finished=True,
                          is_test=True)[0]
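
In this helper, impute_finished=True copies cell states through unchanged for
beams that have already emitted the end token, so padding steps cannot corrupt
them, and is_test=True selects inference-time behavior. A possible call, with
purely illustrative state shapes (model is assumed to wrap a 2-layer LSTM of
hidden size 32 over a batch of 8):

    import paddle

    # Hypothetical [num_layers, batch, hidden_size] LSTM states.
    init_hidden = paddle.zeros([2, 8, 32])
    init_cell = paddle.zeros([2, 8, 32])
    ids = model_forward(model, init_hidden, init_cell)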