def forward(self, src, src_length):
    # Encode the source sequence
    encoder_output, encoder_final_state = self.encoder(src, src_length)
    encoder_final_state = [(encoder_final_state[0][i], encoder_final_state[1][i])
                           for i in range(self.num_layers)]

    # Build the decoder's initial states from the encoder's final states
    decoder_initial_states = [
        encoder_final_state,
        self.decoder.lstm_attention.cell.get_initial_states(
            batch_ref=encoder_output, shape=[self.hidden_size])
    ]

    # Build the attention mask to avoid attending to padding positions
    src_mask = (src != self.eos_id).astype(paddle.get_default_dtype())
    encoder_padding_mask = (src_mask - 1.0) * self.INF
    encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])

    # Tile the batch dimension by beam_size
    encoder_output = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
        encoder_output, self.beam_size)
    encoder_padding_mask = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
        encoder_padding_mask, self.beam_size)

    # Dynamic decoding with beam search
    seq_output, _ = nn.dynamic_decode(
        decoder=self.beam_search_decoder,
        inits=decoder_initial_states,
        max_step_num=self.max_out_len,
        encoder_output=encoder_output,
        encoder_padding_mask=encoder_padding_mask)
    return seq_output
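# A minimal, self-contained sketch of the nn.BeamSearchDecoder /
# nn.dynamic_decode pattern used above, following Paddle's public API.
# All sizes and layer choices here (vocab of 100, hidden size 32, a GRU
# cell) are illustrative assumptions, not the model from the snippet above.
import paddle
import paddle.nn as nn

trg_embedder = nn.Embedding(100, 32)
output_layer = nn.Linear(32, 100)
decoder_cell = nn.GRUCell(input_size=32, hidden_size=32)
decoder = nn.BeamSearchDecoder(decoder_cell,
                               start_token=0,
                               end_token=1,
                               beam_size=4,
                               embedding_fn=trg_embedder,
                               output_fn=output_layer)

# A fake "encoder output", used only as a batch-size reference for the
# decoder cell's zero initial states.
encoder_output = paddle.ones((4, 8, 32), dtype=paddle.get_default_dtype())
predicted_ids, _ = nn.dynamic_decode(
    decoder=decoder,
    inits=decoder_cell.get_initial_states(encoder_output),
    max_step_num=10)  # predicted token ids for each beam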
def forward(self, trg):
    # Sample the latent code z from the standard normal prior and map it
    # to the decoder's first hidden and cell states
    latent_z = paddle.normal(shape=(trg.shape[0], self.latent_size))
    dec_first_hidden_cell = self.fc(latent_z)
    dec_first_hidden, dec_first_cell = paddle.split(
        dec_first_hidden_cell, 2, axis=-1)
    if self.num_layers > 1:
        dec_first_hidden = paddle.split(dec_first_hidden, self.num_layers)
        dec_first_cell = paddle.split(dec_first_cell, self.num_layers)
    else:
        dec_first_hidden = [dec_first_hidden]
        dec_first_cell = [dec_first_cell]
    dec_initial_states = [[h, c]
                          for h, c in zip(dec_first_hidden, dec_first_cell)]

    # Sample the next token id from the output distribution and feed it
    # back as a one-hot vector
    output_fc = lambda x: F.one_hot(
        paddle.multinomial(
            F.softmax(paddle.squeeze(self.decoder.output_fc(x), [1]))),
        num_classes=self.vocab_size)

    # Tile the batch dimension by beam_size
    latent_z = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
        latent_z, self.beam_size)
    decoder = nn.BeamSearchDecoder(
        cell=self.decoder.lstm.cell,
        start_token=self.start_token,
        end_token=self.end_token,
        beam_size=self.beam_size,
        embedding_fn=self.decoder.trg_embedder,
        output_fn=output_fc)

    # Dynamic decoding, passing latent_z as an extra input to each step
    outputs, _ = nn.dynamic_decode(
        decoder,
        inits=dec_initial_states,
        max_step_num=self.max_out_len,
        latent_z=latent_z)
    return outputs
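# A tiny standalone check of the sampling trick used in output_fc above:
# softmax normalizes the logits, paddle.multinomial draws one token id per
# row, and F.one_hot turns that id back into a vector over the vocabulary.
# The batch size and vocabulary size here are illustrative assumptions.
import paddle
import paddle.nn.functional as F

vocab_size = 8
logits = paddle.randn([2, vocab_size])                     # [batch, vocab]
probs = F.softmax(logits)                                  # rows sum to 1
sampled_ids = paddle.multinomial(probs)                    # [batch, 1]
one_hot = F.one_hot(sampled_ids, num_classes=vocab_size)   # [batch, 1, vocab]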
from paddle.nn import dynamic_decode


def model_forward(model, init_hidden, init_cell):
    # Run beam search decoding from the given initial (hidden, cell) states
    # and return only the predicted ids
    return dynamic_decode(model.beam_search_decoder,
                          [init_hidden, init_cell],
                          max_step_num=model.max_step_num,
                          impute_finished=True,
                          is_test=True)[0]