def translate_rnn(line: str, line_number: int, s2s: S2S_basic.S2S or S2S_attention.S2S, src_vocab: Vocab, tgt_vocab: Vocab, lang_vec: dict, device: torch.device):
    """Greedily translate one source line with an RNN (basic or attention) seq2seq model.

    The line is wrapped in start/end tokens, optionally shifted by a language
    embedding vector (looked up via the second token, e.g. ``<en>``), encoded,
    then decoded token-by-token with argmax until the end token or a length
    cap of 3x the source length.

    Fix vs. previous revision: the old code asserted that the token following
    the start token was a ``<...>`` language token, which made the warning
    branch below unreachable and crashed on lines without a language token.
    The assert is removed so such lines fall through to the no-language-vector
    path with a warning, matching translate_transformer's behavior.

    :param line: raw source sentence (whitespace-tokenized).
    :param line_number: index of the line, used only for the warning message.
    :param s2s: trained RNN seq2seq model.
    :param src_vocab: source vocabulary (token <-> index).
    :param tgt_vocab: target vocabulary (token <-> index).
    :param lang_vec: maps language tokens (e.g. "<en>") to embedding-sized vectors.
    :param device: device on which to build input tensors.
    :return: list of predicted target tokens (no start/end tokens).
    """
    line = " ".join([src_vocab.start_token, line, src_vocab.end_token])
    line = line.split()
    # Token right after the start token; if it looks like "<lang>", a language
    # vector is added to the input embeddings.
    lang_token = line[1]
    # inputs: (input_length,)
    inputs = torch.tensor([src_vocab.get_index(token) for token in line], device=device)
    # inputs: (input_length, 1) — batch size 1, sequence-major layout
    inputs = inputs.view(-1, 1)
    if lang_token.startswith("<") and lang_token.endswith(">"):
        # add language vector
        # input_embedding: (input_length, 1, embedding_size)
        # lang_encoding: (embedding_size, ) — broadcast over length and batch
        lang_encoding = torch.tensor(lang_vec[lang_token], device=device)
        input_embedding = s2s.encoder.embedding(inputs) + lang_encoding
    else:
        input_embedding = s2s.encoder.embedding(inputs)
        print("line {} does not add language embedding".format(line_number))
    encoder_output, encoder_hidden_state = s2s.encoder.rnn(input_embedding)
    # Fold bidirectional encoder states into the decoder's expected layout.
    decoder_hidden_state = combine_bidir_hidden_state(s2s, encoder_hidden_state)
    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)
    # Cap output at 3x the source length (excluding start/end tokens).
    max_length = (inputs.size(0) - 2) * 3
    pred_line = []
    for i in range(max_length):
        # decoder_output: (1, 1, vocab_size)
        # decoder_hidden_state: (num_layers * num_directions, batch_size, hidden_size)
        decoder_output, decoder_hidden_state = decode_batch(
            s2s, decoder_input, decoder_hidden_state, encoder_output)
        # pred: (1, 1) — greedy choice over the vocabulary
        pred = torch.argmax(decoder_output, dim=2)
        pred_token = tgt_vocab.get_token(pred[0, 0].item())
        if pred_token == tgt_vocab.end_token:
            break
        # Feed the prediction back in as the next decoder input.
        decoder_input = pred
        pred_line.append(pred_token)
    return pred_line
def convert_index_to_token(pred_list: List[List], tgt_vocab: Vocab, batch_size: int, end_token_index: int):
    """Convert batched index predictions into per-sentence token lists.

    ``pred_list`` is laid out time-major: pred_list[t][b] is the index
    predicted at time step t for batch element b. For each batch element,
    tokens are collected until the end-token index is seen (the end token
    itself is dropped).

    :param pred_list: (tgt_length, batch_size) nested list of vocab indices.
    :param tgt_vocab: target vocabulary used for index -> token lookup.
    :param batch_size: number of columns (sentences) to extract.
    :param end_token_index: vocab index that terminates a sentence.
    :return: list of ``batch_size`` token lists.
    """
    sentences = []
    for column in range(batch_size):
        tokens = []
        # Walk down this column of the time-major grid until the end token.
        for row in pred_list:
            index = row[column]
            if index == end_token_index:
                break
            tokens.append(tgt_vocab.get_token(index))
        sentences.append(tokens)
    return sentences
def beam_search_transformer(s2s: transformer.S2S, data_tensor: torch.tensor, tgt_vocab: Vocab, beam_size: int, device: torch.device):
    """Beam-search decode a single source sentence with a Transformer model.

    The source is tiled ``beam_size`` times and encoded once; each decoding
    step expands every live beam over the full vocabulary, keeps the top
    ``beam_size`` flattened (beam, token) candidates, and moves finished
    hypotheses (those ending in the end token) into ``complete_seqs``.
    Decoding stops when all beams finish or after 3x the source length.

    :param s2s: trained Transformer seq2seq model.
    :param data_tensor: source indices, shape (1, input_length).
    :param tgt_vocab: target vocabulary.
    :param beam_size: number of hypotheses kept per step.
    :param device: device for newly created tensors.
    :return: token list of the highest-scoring finished hypothesis,
        with start/end tokens stripped.
    """
    # src: (1, input_length)
    src = data_tensor
    # Tile the single sentence so each beam shares the same encoder input.
    src = src.expand(beam_size, -1)
    src_mask = s2s.make_src_mask(src)
    encoder_src = s2s.encoder(src, src_mask)
    # Hard cap on hypothesis length: 3x source length.
    max_length = src.size(1) * 3
    # tgt: (1, 1)
    tgt = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                       device=device)
    # tgt: (beam_size, 1) — every beam starts from the start token.
    tgt = tgt.expand(beam_size, -1)
    # Running log-prob score of each live beam.
    scores = torch.zeros(beam_size, device=device)
    complete_seqs = []          # finished hypotheses (index lists)
    complete_seqs_scores = []   # their cumulative log-probs
    step = 1
    while True:
        tgt_mask = s2s.make_tgt_mask(tgt)
        # output: (beam_size, step, vocab_size) — scores per target position
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)
        # Keep only the last position's distribution.
        # output: (1 * beam_size, vocab_size)
        output = output[:, -1, :]
        # output: (1 * beam_size, vocab_size)
        output = F.log_softmax(output, dim=-1)
        # Cumulative score of extending each beam with each token.
        # sub_sentence_scores: (1 * beam_size, vocab_size)
        sub_sentence_scores = output + scores.unsqueeze(1)
        if step == 1:
            # All beams are identical on the first step; expand only beam 0
            # to avoid beam_size duplicates of the same candidates.
            pred_prob, pred_indices = sub_sentence_scores[0].topk(beam_size,
                                                                  dim=-1)
        else:
            # sub_sentence_scores: (beam_size * vocab_size) — flatten so the
            # topk is taken jointly over (beam, token) pairs.
            sub_sentence_scores = sub_sentence_scores.view(-1)
            pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                               dim=-1)
        # Decompose flattened indices: beam index and token index.
        # NOTE(review): Tensor.floor_divide is deprecated in newer PyTorch in
        # favor of torch.div(..., rounding_mode="floor") — confirm the pinned
        # torch version before changing.
        # beam_id: (beam_size, )
        beam_id = pred_indices.floor_divide(len(tgt_vocab))
        # token_id: (beam_size, )
        token_id = pred_indices % len(tgt_vocab)
        # Append the chosen token to its parent beam's prefix.
        # next_tgt: (beam_size, step + 1)
        next_tgt = torch.cat([tgt[beam_id], token_id.unsqueeze(1)], dim=1)
        if step == max_length:
            # Length cap reached: flush all live beams as "complete".
            complete_seqs.extend(next_tgt.tolist())
            complete_seqs_scores.extend(pred_prob.tolist())
            break
        # Collect beams whose newest token is the end token.
        complete_indices = []
        for i, indices in enumerate(token_id):
            if tgt_vocab.get_token(indices.item()) == tgt_vocab.end_token:
                complete_indices.append(i)
        if len(complete_indices) > 0:
            complete_seqs.extend(next_tgt[complete_indices].tolist())
            # Re-flatten the finished (beam, token) pairs so their scores can
            # be read from (and masked out of) sub_sentence_scores.
            complete_pred_indices = beam_id[complete_indices] * len(
                tgt_vocab) + token_id[complete_indices]
            if step == 1:
                complete_seqs_scores.extend(
                    sub_sentence_scores[0][complete_pred_indices].tolist())
                if len(complete_indices) == beam_size:
                    break
                # Mask finished candidates with a large negative score and
                # re-select the topk so the freed beam slots are refilled.
                sub_sentence_scores[0][complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores[0].topk(
                    beam_size, dim=-1)
            else:
                complete_seqs_scores.extend(
                    sub_sentence_scores[complete_pred_indices].tolist())
                if len(complete_indices) == beam_size:
                    break
                sub_sentence_scores[complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                                   dim=-1)
            # Recompute beam/token decomposition for the refreshed topk.
            # beam_id: (beam_size, )
            beam_id = pred_indices.floor_divide(len(tgt_vocab))
            # token_id: (beam_size, )
            token_id = pred_indices % len(tgt_vocab)
            # next_tgt: (beam_size, step + 1)
            next_tgt = torch.cat([tgt[beam_id], token_id.unsqueeze(1)], dim=1)
        step += 1
        tgt = next_tgt
        scores = pred_prob
    # Pick the best finished hypothesis by raw cumulative log-prob
    # (no length normalization).
    best_sentence_id = 0
    for i in range(len(complete_seqs_scores)):
        if complete_seqs_scores[i] > complete_seqs_scores[best_sentence_id]:
            best_sentence_id = i
    best_sentence = complete_seqs[best_sentence_id]
    # Strip start/end tokens and map indices back to tokens.
    best_sentence = [
        tgt_vocab.get_token(index) for index in best_sentence[1:-1]
    ]
    return best_sentence
def beam_search_rnn(s2s: S2S_attention.S2S or S2S_basic.S2S, data_tensor: torch.tensor, tgt_vocab: Vocab, beam_size: int, device: torch.device):
    """Beam-search decode a single source sentence with an RNN seq2seq model.

    The beam is carried in the batch dimension (batch_size == beam_size):
    the source is tiled across beams, and at every step each beam is expanded
    over the vocabulary, the top ``beam_size`` flattened (beam, token) pairs
    are kept, and the recurrent hidden state is re-gathered to follow the
    surviving beams. Finished hypotheses are collected until all beams end
    or the 3x-source-length cap is hit.

    :param s2s: trained RNN (attention or basic) seq2seq model.
    :param data_tensor: source indices, shape (input_length, 1), sequence-major.
    :param tgt_vocab: target vocabulary.
    :param beam_size: number of hypotheses kept per step.
    :param device: device for newly created tensors.
    :return: token list of the highest-scoring finished hypothesis,
        with start/end tokens stripped.
    """
    # batch_size == beam_size
    # inputs: (input_length, beam_size) — same sentence tiled per beam
    inputs = data_tensor
    inputs = inputs.expand(-1, beam_size)
    encoder_output, encoder_hidden_state = s2s.encoder(inputs)
    # decoder_input: (1, beam_size) — every beam starts from the start token
    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)
    decoder_input = decoder_input.expand(-1, beam_size)
    # decoder_hidden_state: (num_layers, beam_size, hidden_size)
    decoder_hidden_state = combine_bidir_hidden_state(s2s,
                                                      encoder_hidden_state)
    # Hard cap on hypothesis length: 3x source length.
    max_length = inputs.size(0) * 3
    # Running log-prob score of each live beam.
    scores = torch.zeros(beam_size, device=device)
    complete_seqs = []          # finished hypotheses (index lists)
    complete_seqs_scores = []   # their cumulative log-probs
    step = 1
    while True:
        # Feed only the most recent token of each beam.
        # output: (1, beam_size, vocab_size)
        # decoder_hidden_state: (num_layers, beam_size, hidden_size)
        output, decoder_hidden_state = decode_batch(
            s2s, decoder_input[-1].unsqueeze(0), decoder_hidden_state,
            encoder_output)
        output = F.log_softmax(output, dim=-1)
        # Cumulative score of extending each beam with each token.
        # sub_sentence_scores: (beam_size, vocab_size)
        sub_sentence_scores = scores.unsqueeze(1) + output.squeeze(0)
        if step == 1:
            # All beams are identical on the first step; expand only beam 0
            # to avoid beam_size duplicates of the same candidates.
            pred_prob, pred_indices = sub_sentence_scores[0].topk(beam_size,
                                                                  dim=-1)
        else:
            # sub_sentence_scores: (beam_size * vocab_size) — flatten so the
            # topk is taken jointly over (beam, token) pairs.
            sub_sentence_scores = sub_sentence_scores.view(-1)
            pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                               dim=-1)
        # Decompose flattened indices: beam index and token index.
        # NOTE(review): Tensor.floor_divide is deprecated in newer PyTorch in
        # favor of torch.div(..., rounding_mode="floor") — confirm the pinned
        # torch version before changing.
        # beam_id: (beam_size, )
        beam_id = pred_indices.floor_divide(len(tgt_vocab))
        # token_id: (beam_size, )
        token_id = pred_indices % len(tgt_vocab)
        # Append the chosen token to its parent beam's prefix.
        # decoder_input: (step, beam_size)
        # next_decoder_input: (step + 1, beam_size)
        next_decoder_input = torch.cat(
            [decoder_input[:, beam_id], token_id.unsqueeze(0)], dim=0)
        if step == max_length:
            # Length cap reached: flush all live beams as "complete".
            complete_seqs.extend(next_decoder_input.t().tolist())
            complete_seqs_scores.extend(pred_prob.tolist())
            break
        # Collect beams whose newest token is the end token.
        complete_indices = []
        for i, indices in enumerate(token_id):
            if tgt_vocab.get_token(indices.item()) == tgt_vocab.end_token:
                complete_indices.append(i)
        if len(complete_indices) > 0:
            complete_seqs.extend(
                next_decoder_input[:, complete_indices].t().tolist())
            # Re-flatten the finished (beam, token) pairs so their scores can
            # be read from (and masked out of) sub_sentence_scores.
            complete_pred_indices = beam_id[complete_indices] * len(
                tgt_vocab) + token_id[complete_indices]
            if step == 1:
                complete_seqs_scores.extend(
                    sub_sentence_scores[0][complete_pred_indices].tolist())
                if len(complete_pred_indices) == beam_size:
                    break
                # Mask finished candidates with a large negative score and
                # re-select the topk so the freed beam slots are refilled.
                sub_sentence_scores[0][complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores[0].topk(
                    beam_size, dim=-1)
            else:
                complete_seqs_scores.extend(
                    sub_sentence_scores[complete_pred_indices].tolist())
                if len(complete_pred_indices) == beam_size:
                    break
                sub_sentence_scores[complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                                   dim=-1)
            # Recompute beam/token decomposition for the refreshed topk.
            beam_id = pred_indices.floor_divide(len(tgt_vocab))
            token_id = pred_indices % len(tgt_vocab)
            next_decoder_input = torch.cat(
                [decoder_input[:, beam_id], token_id.unsqueeze(0)], dim=0)
        step += 1
        # Re-gather the recurrent state so each slot follows its parent beam
        # (LSTM carries an (h, c) tuple; GRU/vanilla RNN a single tensor).
        if isinstance(decoder_hidden_state, tuple):
            h, c = decoder_hidden_state
            h = h[:, beam_id]
            c = c[:, beam_id]
            decoder_hidden_state = (h, c)
        else:
            decoder_hidden_state = decoder_hidden_state[:, beam_id]
        decoder_input = next_decoder_input
        scores = pred_prob
    # Pick the best finished hypothesis by raw cumulative log-prob
    # (no length normalization).
    best_sentence_id = 0
    for i in range(len(complete_seqs_scores)):
        if complete_seqs_scores[i] > complete_seqs_scores[best_sentence_id]:
            best_sentence_id = i
    best_sentence = complete_seqs[best_sentence_id]
    # Strip start/end tokens and map indices back to tokens.
    best_sentence = [
        tgt_vocab.get_token(index) for index in best_sentence[1:-1]
    ]
    return best_sentence
def translate_transformer(line: str, line_number: int, s2s: transformer.S2S, src_vocab: Vocab, tgt_vocab: Vocab, lang_vec: dict, device: torch.device):
    """Greedily translate one source line with a Transformer seq2seq model.

    The line is wrapped in start/end tokens and embedded manually so that an
    optional language vector (looked up via the second token, e.g. ``<en>``)
    can be added before the encoder layers run. Decoding is greedy: at each
    step the full prefix is re-fed to the decoder and the argmax token of the
    last position is appended, until the end token or 3x the source length.

    :param line: raw source sentence (whitespace-tokenized).
    :param line_number: index of the line, used only for the warning message.
    :param s2s: trained Transformer seq2seq model.
    :param src_vocab: source vocabulary (token <-> index).
    :param tgt_vocab: target vocabulary (token <-> index).
    :param lang_vec: maps language tokens (e.g. "<en>") to d_model-sized vectors.
    :param device: device on which to build input tensors.
    :return: list of predicted target tokens (start token stripped).
    """
    tokens = " ".join([src_vocab.start_token, line, src_vocab.end_token]).split()
    # Cap output at 3x the source length (excluding start/end tokens).
    max_length = (len(tokens) - 2) * 3
    # Token right after the start token; if it looks like "<lang>", a language
    # vector is added to the encoder input.
    lang_token = tokens[1]
    # src: (1, input_length)
    src = torch.tensor([src_vocab.get_index(token) for token in tokens],
                       device=device).view(1, -1)
    src_mask = s2s.make_src_mask(src)
    # Embed manually (token embedding * scale, then positional embedding) so
    # the language vector can be injected before the encoder layers.
    src = s2s.encoder.token_embedding(src) * s2s.encoder.scale
    # src: (1, input_length, d_model)
    src = s2s.encoder.pos_embedding(src)
    if lang_token.startswith("<") and lang_token.endswith(">"):
        # lang_encoding: (d_model, ) — broadcast over batch and length
        lang_encoding = torch.tensor(lang_vec[lang_token], device=device)
        src = src + lang_encoding
    else:
        print("line {} does not add language embedding".format(line_number))
    # Run the encoder stack by hand; attention maps are not needed here.
    for layer in s2s.encoder.layers:
        src, self_attention = layer(src, src_mask)
        del self_attention
    encoder_src = src
    # Greedy decoding: the growing prefix lives both in pred_line (plain
    # indices) and tgt (tensor fed to the decoder).
    pred_line = [tgt_vocab.get_index(tgt_vocab.start_token)]
    tgt = torch.tensor([pred_line], device=device)
    for _ in range(max_length):
        tgt_mask = s2s.make_tgt_mask(tgt)
        # output: (1, tgt_input_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)
        # Greedy pick at the last target position.
        pred = torch.argmax(output, dim=-1)[0, -1]
        if tgt_vocab.get_token(pred.item()) == tgt_vocab.end_token:
            break
        tgt = torch.cat([tgt, pred.unsqueeze(0).unsqueeze(1)], dim=1)
        pred_line.append(pred.item())
    # Drop the start token and map indices back to tokens.
    return [tgt_vocab.get_token(index) for index in pred_line[1:]]