def get_hiddens(encoder, decoder, sentence, vocab, batch_size, max_length=MAX_LENGTH):
    """Return hidden vectors h_bar, h_tilda (same notation as Luong et al. (2015)).

    Greedily encodes and decodes `sentence`, collecting at every time step:
      - encoder_h_bar: the encoder hidden state (first layer, first batch element),
      - decoder_h_tilda: the decoder's attentional hidden state h_tilda.

    Returns (encoder_h_bar, decoder_h_tilda), each of shape (max_length, hidden_size).
    """
    with torch.no_grad():
        input_tensor = prep.tensorFromSentenceBatchWithPadding(vocab, [sentence])
        # because of batch, need expansion for input tensor:
        # replicate the single sentence batch_size times along dim 0.
        temp = input_tensor
        for _ in range(batch_size-1):
            temp = torch.cat((temp, input_tensor), 0)
        input_tensor = temp
        input_tensor = input_tensor.transpose(0, 1)  # -> (seq, batch)
        encoder_hidden = encoder.init_hidden(batch_size)
        encoder_outputs = torch.zeros(max_length, batch_size, encoder.hidden_size, device=device)
        encoder_h_bar = torch.zeros(max_length, encoder.hidden_size, device=device)
        # Step the encoder one token at a time so each step's hidden state
        # can be recorded.
        for ei in range(max_length):
            it = input_tensor[ei].view(batch_size, -1)
            encoder_output, encoder_hidden = encoder(it, encoder_hidden)
            encoder_outputs[ei] = encoder_output.transpose(1, 2).view(batch_size, encoder.hidden_size)
            # First layer/direction, first batch element (all rows are the
            # same sentence anyway).
            encoder_h_bar[ei] = encoder_hidden[0][0]
        decoder_input = torch.tensor([batch_size * [SOS_token]], device=device).view(batch_size, 1)  # SOS
        decoder_hidden = encoder_hidden
        decoder_h_tilda = torch.zeros(max_length, decoder.hidden_size, device=device)
        for di in range(max_length):
            # This decoder variant returns the attentional hidden state
            # h_tilda as a fourth value.
            decoder_output, decoder_hidden, decoder_attention, h_tilda = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            # Greedy decoding: feed the argmax token back as next input.
            topv, topi = decoder_output.data.topk(1)
            decoder_input = topi.squeeze().view(1, batch_size)
            decoder_h_tilda[di] = h_tilda[0]
        return encoder_h_bar, decoder_h_tilda
def trainIters(args, epoch, encoder, decoder, n_iters, pairs, vocab, train_loader,
               print_every=1000, plot_every=100, learning_rate=0.01):
    """Run one epoch of seq2seq training over `train_loader`.

    Builds fresh optimizers for encoder and decoder according to
    `args.optim` ('RMSprop', 'Adam' or 'SGD'), trains on every batch and
    prints a running average loss every `print_every` epochs.

    Raises:
        ValueError: if `args.optim` is not one of the supported optimizers.
    """
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    if args.optim == 'RMSprop':
        encoder_optimizer = optim.RMSprop(encoder.parameters(), lr=learning_rate)
        decoder_optimizer = optim.RMSprop(decoder.parameters(), lr=learning_rate)
    elif args.optim == 'Adam':
        encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
        decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    elif args.optim == 'SGD':
        encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
        decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    else:
        # Fail fast: the original silently fell through and crashed with a
        # NameError at the first train() call when args.optim was unknown.
        raise ValueError('unsupported optimizer: %s' % args.optim)

    criterion = nn.NLLLoss()
    num_iters = 0
    for _iter, (batch_input, batch_target) in enumerate(train_loader):
        input_tensor = prep.tensorFromSentenceBatchWithPadding(
            vocab, batch_input)
        target_tensor = prep.tensorFromSentenceBatchWithPadding(
            vocab, batch_target)
        loss = train(input_tensor, target_tensor, encoder, decoder,
                     encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss
        # Count the actual batch size: the original added the global
        # `batch_size`, which over-counts on a final partial batch.
        num_iters += len(batch_input)
    if epoch % print_every == 0:
        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        print_loss_avg = print_loss_total / max(num_iters, 1)
        print_loss_total = 0
        print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_iters),
                                     epoch, epoch / n_iters * 100, print_loss_avg))
def evaluate(encoder, decoder, sentence, vocab, batch_size, max_length=MAX_LENGTH):
    """Greedily decode a batch of input sentences; return a list of word lists."""
    with torch.no_grad():
        src = prep.tensorFromSentenceBatchWithPadding(vocab, sentence)
        hidden = encoder.init_hidden(batch_size)
        enc_states = torch.zeros(max_length, batch_size, encoder.hidden_size, device=device)
        src = src.transpose(0, 1)

        # Feed the source one time step at a time, recording encoder outputs
        # for the decoder's attention.
        for step in range(max_length):
            step_input = src[step].view(batch_size, -1)
            enc_out, hidden = encoder(step_input, hidden)
            enc_states[step] = enc_out.transpose(1, 2).view(
                batch_size, encoder.hidden_size)

        dec_input = torch.tensor([batch_size * [SOS_token]], device=device).view(batch_size, 1)  # SOS
        dec_hidden = hidden
        decoded_words_batch = [[] for _ in range(batch_size)]

        for _ in range(max_length):
            dec_out, dec_hidden, _attn, _ = decoder(
                dec_input, dec_hidden, enc_states)
            # Greedy decoding: argmax token becomes the next decoder input.
            _, best = dec_out.data.topk(1)
            dec_input = best.squeeze().view(1, batch_size)
            # Collect the predicted word for every sentence in the batch.
            for idx, logits in enumerate(dec_out):
                token = logits.data.topk(1)[1].item()
                if token == EOS_token:
                    decoded_words_batch[idx].append('<EOS>')
                else:
                    decoded_words_batch[idx].append(vocab.index2word[token])
        return decoded_words_batch
def get_attn_hidden_avg(encoder, decoder, sentence, vocab, batch_size, max_length=MAX_LENGTH):
    """Decode `sentence` greedily; return the decoder hidden state averaged over time steps."""
    with torch.no_grad():
        # Replicate the single sentence so the tensor matches the batch size.
        batch = [sentence for _ in range(batch_size)]
        src = prep.tensorFromSentenceBatchWithPadding(vocab, batch)
        src = src.transpose(0, 1)

        enc_hidden = encoder.init_hidden(batch_size)
        enc_states = torch.zeros(max_length, batch_size, encoder.hidden_size, device=device)
        for step in range(max_length):
            step_input = src[step].view(batch_size, -1)
            enc_out, enc_hidden = encoder(step_input, enc_hidden)
            enc_states[step] = enc_out.transpose(1, 2).view(
                batch_size, encoder.hidden_size)

        dec_input = torch.tensor([batch_size * [SOS_token]], device=device).view(batch_size, 1)  # SOS
        dec_hidden = enc_hidden
        hidden_trace = torch.zeros(max_length, batch_size, encoder.hidden_size, device=device)
        for step in range(max_length):
            dec_out, dec_hidden, _attn = decoder(
                dec_input, dec_hidden, enc_states)
            # Greedy decoding: feed the argmax token back in.
            _, best = dec_out.data.topk(1)
            dec_input = best.squeeze().view(1, batch_size)
            hidden_trace[step] = dec_hidden[0]

        # (max_length, batch, hidden) -> (batch, max_length, hidden); keep the
        # first batch element's trace and average across decoding steps.
        trace = hidden_trace.transpose(0, 1)[0]
        return torch.mean(trace, 0)
def get_embed(encoder, sentence, vocab, batch_size, max_length=MAX_LENGTH):
    """Encode `sentence` and return the final encoder hidden state as a (1, 1, -1) embedding."""
    with torch.no_grad():
        # Expand the single sentence into a full batch for the encoder.
        batch = [sentence] * batch_size
        input_tensor = prep.tensorFromSentenceBatchWithPadding(vocab, batch)
        hidden = encoder.init_hidden(batch_size)
        _, hidden = encoder(input_tensor, hidden)
        # Treat the last hidden state (first layer, first batch element) as
        # the sentence embedding.
        return hidden[0][0].view(1, 1, -1)
def get_ende_hidden(encoder, decoder, sentence, vocab, batch_size, max_length=MAX_LENGTH):
    """Encode then greedily decode `sentence`; return the final encoder/decoder hidden vectors (C_Q, C_A)."""
    with torch.no_grad():
        single = prep.tensorFromSentenceBatchWithPadding(vocab, [sentence])
        # Tile the single sentence along the batch dimension, then put the
        # sequence dimension first.
        src = torch.cat([single] * batch_size, 0).transpose(0, 1)

        enc_hidden = encoder.init_hidden(batch_size)
        enc_states = torch.zeros(max_length, batch_size, encoder.hidden_size, device=device)
        for step in range(max_length):
            step_input = src[step].view(batch_size, -1)
            enc_out, enc_hidden = encoder(step_input, enc_hidden)
            enc_states[step] = enc_out.transpose(1, 2).view(
                batch_size, encoder.hidden_size)

        dec_input = torch.tensor([batch_size * [SOS_token]], device=device).view(batch_size, 1)  # SOS
        dec_hidden = enc_hidden
        for _ in range(max_length):
            dec_out, dec_hidden, _attn = decoder(
                dec_input, dec_hidden, enc_states)
            # Greedy decoding: argmax token becomes the next input.
            _, best = dec_out.data.topk(1)
            dec_input = best.squeeze().view(1, batch_size)

        # Question / answer context vectors: final hidden states of encoder
        # and decoder (first layer, first batch element).
        C_Q = enc_hidden[0][0].view(1, -1)
        C_A = dec_hidden[0][0].view(1, -1)
        return C_Q, C_A
def get_embed_avg(encoder, vocab, sentence):
    """Return the average of the encoder's fine-tuned word embeddings for `sentence`."""
    with torch.no_grad():
        tokens = prep.tensorFromSentenceBatchWithPadding(vocab, [sentence])
        # Embed the single (padded) sentence and average over its tokens.
        vectors = encoder.embedding(tokens)[0]
        return vectors.mean(dim=0)
def get_embed_ans_pivot(encoder, decoder, sentence, vocab, batch_size, max_length=MAX_LENGTH):
    """Sentence embedding test v3: answer vector attentioned in light of the question vector."""
    with torch.no_grad():
        single = prep.tensorFromSentenceBatchWithPadding(vocab, [sentence])
        # Tile the single sentence along the batch dimension.
        src = torch.cat([single] * batch_size, 0)

        # NOTE: unlike the other helpers, the input is fed to the encoder in
        # one call without transposing — preserved from the original.
        enc_hidden = encoder.init_hidden(batch_size)
        _, enc_hidden = encoder(src, enc_hidden)

        dec_input = torch.tensor([[SOS_token] * batch_size], device=device).view(1, batch_size)  # SOS
        dec_hidden = enc_hidden
        for _ in range(max_length):
            dec_out, dec_hidden = decoder(dec_input, dec_hidden)
            # Greedy decoding: argmax token becomes the next input.
            _, best = dec_out.data.topk(1)
            dec_input = best.squeeze().view(1, batch_size)

        # C_Q / C_A: final encoder / decoder hidden vectors (question / answer).
        C_Q = enc_hidden[0][0].view(1, -1)
        C_A = dec_hidden[0][0].view(1, -1)
        # Attend the question vector by the answer-derived alignment weights.
        L = torch.matmul(C_Q.transpose(0, 1), C_A)
        A_A = softmax(L)
        return torch.matmul(C_Q, A_A)
def get_embed_concat(encoder, decoder, sentence, vocab, batch_size, max_length=MAX_LENGTH):
    """Sentence embedding test v1: concat the encoder and decoder final hidden vectors."""
    with torch.no_grad():
        single = prep.tensorFromSentenceBatchWithPadding(vocab, [sentence])
        # Tile the single sentence along the batch dimension, then put the
        # sequence dimension first.
        src = torch.cat([single] * batch_size, 0).transpose(0, 1)

        enc_hidden = encoder.init_hidden(batch_size)
        _, enc_hidden = encoder(src, enc_hidden)

        dec_input = torch.tensor([[SOS_token] * batch_size], device=device).view(1, batch_size)  # SOS
        dec_hidden = enc_hidden
        for _ in range(max_length):
            dec_out, dec_hidden = decoder(dec_input, dec_hidden)
            # Greedy decoding: argmax token becomes the next input.
            _, best = dec_out.data.topk(1)
            dec_input = best.squeeze().view(1, batch_size)

        # Concatenate question (encoder) and answer (decoder) context vectors.
        C_Q = enc_hidden[0][0].view(1, -1)
        C_A = dec_hidden[0][0].view(1, -1)
        return torch.cat((C_Q, C_A), 0)