def __init__(self, params, vocab):
    """Assemble the Seq2Seq model: encoder, Bahdanau attention, decoder.

    Args:
        params: dict of hyper-parameters; must contain ``batch_size``,
            ``enc_units``, ``dec_units`` and ``attn_units``.
        vocab: vocabulary object (stored as-is; used by callers elsewhere).
    """
    super(Seq2Seq, self).__init__()
    # Pretrained embedding matrix shared by encoder and decoder.
    self.embedding_matrix = load_embedding_matrix()
    self.params = params
    self.vocab = vocab
    self.batch_size = params["batch_size"]
    self.enc_units = params["enc_units"]
    self.dec_units = params["dec_units"]
    self.attn_units = params["attn_units"]
    self.encoder = Encoder(embedding_matrix=self.embedding_matrix,
                           batch_size=self.batch_size,
                           enc_units=self.enc_units,
                           rnn_type='gru')
    # Fix: attention units come from params["attn_units"]; the previous code
    # passed att_units=self.dec_units, leaving self.attn_units unused
    # (the example driver calls BahdanauAttention(dec_units, attn_units)).
    self.att_model = BahdanauAttention(att_hidden_size=self.dec_units,
                                       att_units=self.attn_units)
    self.decoder = Decoder(
        embedding_matrix=self.embedding_matrix,
        dec_units=self.dec_units,
        rnn_type='gru',
        att_model=self.att_model,
    )
def select_att(self, att_model, **config):
    """Create an attention module from its (case-insensitive-ish) name.

    Args:
        att_model: one of the Bahdanau/additive or Luong/multiplicative
            aliases listed below.
        **config: Bahdanau needs ``att_hidden_size`` and ``att_units``;
            Luong needs ``att_hidden_size`` and ``attention_func``.

    Returns:
        A ``BahdanauAttention`` or ``LuongAttention`` instance.

    Raises:
        ValueError: if a required config key is missing or ``att_model``
            is not a recognized alias.
    """
    # The misspelled aliases ('bahdanau_bttention', 'luong_bttention', ...)
    # are kept for backward compatibility; correct spellings are added.
    bahdanau_aliases = [
        'bahdanau', 'bahdanau_attention', 'bahdanauattention',
        'bahdanau_bttention',
        'additive', 'additive_attention', 'additiveattention',
    ]
    luong_aliases = [
        'luong', 'luong_attention', 'luongattention',
        'luong_bttention', 'luongttention',
        'multiplicative', 'multiplicative_attention',
        'multiplicativeattention',
    ]
    if att_model in bahdanau_aliases:
        try:
            hidden_size = config['att_hidden_size']
            units = config['att_units']
        except KeyError as err:
            # Original code did `raise ('...')`, which raises
            # TypeError: exceptions must derive from BaseException.
            raise ValueError(
                'BahdanauAttention模型缺少参数:att_hidden_size, att_units! 请于创建Decoder时传入!'
            ) from err
        return BahdanauAttention(hidden_size, units)
    elif att_model in luong_aliases:
        try:
            hidden_size = config['att_hidden_size']
            attention_func = config['attention_func']
        except KeyError as err:
            # Fixed message: this branch builds LuongAttention, not Bahdanau.
            raise ValueError(
                'LuongAttention模型缺少参数:att_hidden_size, attention_func! 请于创建Decoder时传入!'
            ) from err
        return LuongAttention(hidden_size, attention_func)
    # Fail fast instead of silently returning None for unknown names.
    raise ValueError('Unknown attention model: {!r}'.format(att_model))
# example_input example_input_batch = torch.ones(size=(params['batch_size'], params['input_sequence_len']), dtype=torch.int32, device=device) sample_output, sample_hidden = model.encoder(example_input_batch) # 打印结果 print( 'Encoder output shape: (batch size, sequence length, units) {}'.format( sample_output.shape)) print('Encoder Hidden state shape: (batch size, units) {}'.format( sample_hidden.shape)) attention_layer = BahdanauAttention(params['dec_units'], params['attn_units']).to(device) context_vector, attention_weights = attention_layer( sample_hidden, sample_output) print("Attention context_vector shape: (batch size, units) {}".format( context_vector.shape)) print( "Attention weights shape: (batch_size, sequence_length, 1) {}".format( attention_weights.shape)) sample_decoder_output, _, _ = model.decoder( torch.ones(size=(params['batch_size'], 1), dtype=torch.int32, device=device), sample_hidden, sample_output) print('Decoder output shape: (batch_size, vocab size) {}'.format(
# example_input example_input_batch = torch.ones(size=(batch_size, input_sequence_len), dtype=torch.int32, device=device) # sample input # sample_hidden = encoder.initialize_hidden_state() # sample_hidden = torch.zeros(size=(batch_size, units)) sample_output, sample_hidden = encoder(example_input_batch) # 打印结果 print('Encoder output shape: (batch size, sequence length, hidden_dim) {}'. format(sample_output.shape)) print('Encoder Hidden state shape: (batch size, hidden_dim) {}'.format( sample_hidden.shape)) attention_layer = BahdanauAttention(units, units).to(device) context_vector, attention_weights = attention_layer( sample_hidden, sample_output) print("Attention context_vector shape: (batch size, units) {}".format( context_vector.shape)) print( "Attention weights shape: (batch_size, sequence_length, 1) {}".format( attention_weights.shape)) decoder = Decoder(embedding_matrix=embedding_matrix, dec_units=units, att_model=attention_layer).to(device) sample_decoder_output, state, attention_weights = decoder( torch.ones(size=(batch_size, 1), dtype=torch.int32, device=device), sample_hidden, sample_output)