def __init__(self,
             src_vocab_size,
             trg_vocab_size,
             embed_dim,
             hidden_size,
             num_layers,
             dropout_prob=0.,
             bos_id=0,
             eos_id=1,
             beam_size=4,
             max_out_len=256):
    # Collect all constructor arguments, then strip the inference-only
    # ones before forwarding the rest to the training model's constructor.
    args = dict(locals())
    args.pop("self")
    args.pop("__class__", None)
    self.bos_id = args.pop("bos_id")
    self.beam_size = args.pop("beam_size")
    self.max_out_len = args.pop("max_out_len")
    self.num_layers = num_layers
    super(Seq2SeqAttnInferModel, self).__init__(**args)

    # Dynamic decoder for inference: expands each source sequence into
    # `beam_size` hypotheses, reusing the trained attention cell,
    # embedding layer, and output projection.
    self.beam_search_decoder = nn.BeamSearchDecoder(
        self.decoder.lstm_attention.cell,
        start_token=bos_id,
        end_token=eos_id,
        beam_size=beam_size,
        embedding_fn=self.decoder.embedder,
        output_fn=self.decoder.output_layer)
def __init__(self,
             vocab_size,
             embed_dim,
             hidden_size,
             num_layers,
             bos_id=0,
             eos_id=1,
             beam_size=4,
             max_out_len=256):
    self.bos_id = bos_id
    self.beam_size = beam_size
    self.max_out_len = max_out_len
    self.num_layers = num_layers
    super(Seq2SeqAttnInferModel, self).__init__(
        vocab_size, embed_dim, hidden_size, num_layers, eos_id)

    # Dynamic decoder for inference
    self.beam_search_decoder = nn.BeamSearchDecoder(
        self.decoder.lstm_attention.cell,
        start_token=bos_id,
        end_token=eos_id,
        beam_size=beam_size,
        embedding_fn=self.decoder.embedder,
        output_fn=self.decoder.output_layer)
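For context, here is a minimal usage sketch for the first constructor variant above. It assumes PaddleNLP-style seq2seq classes where the infer model's forward takes `(src, src_length)`; the vocabulary sizes, batch shapes, and checkpoint path are hypothetical placeholders, not values from the original.

import paddle

model = Seq2SeqAttnInferModel(
    src_vocab_size=10000,   # hypothetical source vocabulary size
    trg_vocab_size=10000,   # hypothetical target vocabulary size
    embed_dim=512,
    hidden_size=512,
    num_layers=2,
    dropout_prob=0.2,
    bos_id=0,
    eos_id=1,
    beam_size=4,
    max_out_len=256)

# Load trained parameters (the path is a placeholder).
model.set_state_dict(paddle.load("checkpoints/final.pdparams"))
model.eval()

# src: int64 token ids [batch_size, src_len]; src_length: valid lengths.
src = paddle.randint(low=2, high=10000, shape=[4, 12], dtype="int64")
src_length = paddle.full([4], 12, dtype="int64")
with paddle.no_grad():
    finished_seq = model(src, src_length)  # beam-search token ids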
def forward(self, trg):
    # Sample the latent vector z from the standard normal prior; no
    # encoder is run at inference time, so `trg` only supplies the
    # batch size.
    latent_z = paddle.normal(shape=(trg.shape[0], self.latent_size))

    # Project z into the decoder's initial hidden and cell states, one
    # (h, c) pair per LSTM layer.
    dec_first_hidden_cell = self.fc(latent_z)
    dec_first_hidden, dec_first_cell = paddle.split(
        dec_first_hidden_cell, 2, axis=-1)
    if self.num_layers > 1:
        dec_first_hidden = paddle.split(dec_first_hidden, self.num_layers)
        dec_first_cell = paddle.split(dec_first_cell, self.num_layers)
    else:
        dec_first_hidden = [dec_first_hidden]
        dec_first_cell = [dec_first_cell]
    dec_initial_states = [
        [h, c] for h, c in zip(dec_first_hidden, dec_first_cell)
    ]

    # Instead of feeding raw logits to the beam, sample a token id from
    # the softmax distribution and return it one-hot encoded, which makes
    # decoding stochastic.
    output_fc = lambda x: F.one_hot(
        paddle.multinomial(
            F.softmax(paddle.squeeze(
                self.decoder.output_fc(x), [1]))),
        num_classes=self.vocab_size)

    # Tile z across beams so every hypothesis conditions on the same
    # latent vector.
    latent_z = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
        latent_z, self.beam_size)

    decoder = nn.BeamSearchDecoder(
        cell=self.decoder.lstm.cell,
        start_token=self.start_token,
        end_token=self.end_token,
        beam_size=self.beam_size,
        embedding_fn=self.decoder.trg_embedder,
        output_fn=output_fc)

    outputs, _ = nn.dynamic_decode(
        decoder,
        inits=dec_initial_states,
        max_step_num=self.max_out_len,
        latent_z=latent_z)
    return outputs
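A short usage sketch for this sampling-based decoder follows. It assumes `infer_model` is an instance of the class defining the `forward` above (a VAE-style seq2seq infer model); the instance name and batch size are placeholders. Note that `trg` is only read for its batch dimension, so a dummy tensor suffices.

import paddle

infer_model.eval()

# One latent vector is sampled per row of `trg`, so its contents are unused.
batch_size = 8
trg = paddle.zeros([batch_size, 1], dtype="int64")

with paddle.no_grad():
    outputs = infer_model(trg)  # generated token ids, one beam per sample

Because `output_fc` returns a one-hot sample drawn from the softmax rather than the raw logits, each beam effectively follows a randomly sampled token at every step, so repeated calls yield diverse generations rather than the single deterministic top-k result of standard beam search.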