from typing import List, Optional, Union

import torch
import torch.nn.functional as F

import S2S_attention
import S2S_basic
import transformer
from vocab import Vocab  # import path assumed; adjust to the project layout


def greedy_decoding_transformer(s2s: transformer.S2S,
                                data_tensor: torch.Tensor,
                                tgt_vocab: Vocab,
                                device: torch.device,
                                tgt_prefix: Optional[List[str]] = None):
    # src: (batch_size, input_length)
    src = data_tensor
    src_mask = s2s.make_src_mask(src)
    batch_size = src.size(0)
    encoder_src = s2s.encoder(src, src_mask)

    tgt = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                       device=device)
    tgt = tgt.expand(batch_size, -1)

    # pred_list: List[List] (tgt_length, batch_size)
    pred_list = []

    if tgt_prefix is not None:
        # tgt_prefix: (batch_size, )
        tgt_prefix = [tgt_vocab.get_index(prefix_token)
                      for prefix_token in tgt_prefix]
        tgt_prefix_tensor = torch.tensor(tgt_prefix, device=device)
        # tgt_prefix_tensor: (batch_size, 1)
        tgt_prefix_tensor = tgt_prefix_tensor.unsqueeze(1)
        # tgt: (batch_size, 2)
        tgt = torch.cat([tgt, tgt_prefix_tensor], dim=1)
        pred_list.append(tgt_prefix)

    max_length = src.size(1) * 3
    end_token_index = tgt_vocab.get_index(tgt_vocab.end_token)

    for i in range(0 if tgt_prefix is None else 1, max_length):
        # tgt: (batch_size, i + 1)
        tgt_mask = s2s.make_tgt_mask(tgt)
        # output: (batch_size, tgt_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)
        # output: (batch_size, vocab_size)
        output = output[:, -1, :]
        # pred: (batch_size, )
        pred = torch.argmax(output, dim=-1)
        # stop once every sentence in the batch has emitted the end token
        if torch.all(pred == end_token_index).item():
            break
        tgt = torch.cat([tgt, pred.unsqueeze(1)], dim=1)
        pred_list.append(pred.tolist())

    return convert_index_to_token(pred_list, tgt_vocab, batch_size,
                                  end_token_index)
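# Both greedy decoders delegate the final index-to-token conversion to
# convert_index_to_token, which is defined outside this section. The sketch
# below is an assumption inferred purely from the call sites (pred_list is a
# step-major List[List[int]] of shape (tgt_length, batch_size)); it is not
# the project's actual implementation, hence the _sketch suffix.
def convert_index_to_token_sketch(pred_list: List[List[int]],
                                  tgt_vocab: Vocab, batch_size: int,
                                  end_token_index: int) -> List[List[str]]:
    sentences = [[] for _ in range(batch_size)]
    finished = [False] * batch_size
    for step_pred in pred_list:
        for i, index in enumerate(step_pred):
            # once a sentence emits the end token, ignore its later steps
            if finished[i]:
                continue
            if index == end_token_index:
                finished[i] = True
            else:
                sentences[i].append(tgt_vocab.get_token(index))
    return sentences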
def translate_rnn(line: str, line_number: int,
                  s2s: Union[S2S_basic.S2S, S2S_attention.S2S],
                  src_vocab: Vocab, tgt_vocab: Vocab, lang_vec: dict,
                  device: torch.device):
    line = " ".join([src_vocab.start_token, line, src_vocab.end_token])
    line = line.split()

    lang_token = line[1]

    # inputs: (input_length, )
    inputs = torch.tensor([src_vocab.get_index(token) for token in line],
                          device=device)
    # inputs: (input_length, 1)
    inputs = inputs.view(-1, 1)

    # note: the original code asserted that lang_token is a <...> tag right
    # before this check, which made the else branch unreachable; the assert
    # is dropped so the fallback can actually run (as in translate_transformer)
    if lang_token.startswith("<") and lang_token.endswith(">"):
        # add the language vector to the source embedding
        # input_embedding: (input_length, 1, embedding_size)
        # lang_encoding: (embedding_size, )
        lang_encoding = torch.tensor(lang_vec[lang_token], device=device)
        input_embedding = s2s.encoder.embedding(inputs) + lang_encoding
    else:
        input_embedding = s2s.encoder.embedding(inputs)
        print("line {} does not add language embedding".format(line_number))

    encoder_output, encoder_hidden_state = s2s.encoder.rnn(input_embedding)

    decoder_hidden_state = combine_bidir_hidden_state(s2s, encoder_hidden_state)

    decoder_input = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                                 device=device)

    # the two extra positions are the start/end tokens added above
    max_length = (inputs.size(0) - 2) * 3

    pred_line = []

    for i in range(max_length):
        # decoder_output: (1, 1, vocab_size)
        # decoder_hidden_state: (num_layers * num_directions, batch_size, hidden_size)
        decoder_output, decoder_hidden_state = decode_batch(
            s2s, decoder_input, decoder_hidden_state, encoder_output)
        # pred: (1, 1)
        pred = torch.argmax(decoder_output, dim=2)
        if tgt_vocab.get_token(pred[0, 0].item()) == tgt_vocab.end_token:
            break
        decoder_input = pred
        pred_line.append(tgt_vocab.get_token(pred[0, 0].item()))

    return pred_line
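# combine_bidir_hidden_state is used by every RNN decoding path here but
# defined elsewhere. A minimal sketch of what it plausibly does, assuming a
# bidirectional encoder whose forward/backward states are concatenated into a
# decoder state of shape (num_layers, batch_size, 2 * hidden_size); the real
# helper may merge the directions differently (e.g. by summing them).
def combine_bidir_hidden_state_sketch(s2s, encoder_hidden_state):
    def merge(state):
        # state: (num_layers * 2, batch_size, hidden_size)
        num_layers = state.size(0) // 2
        state = state.view(num_layers, 2, state.size(1), state.size(2))
        # concatenate the two directions along the feature dimension
        return torch.cat([state[:, 0], state[:, 1]], dim=2)

    if isinstance(encoder_hidden_state, tuple):
        # LSTM: (h, c)
        return tuple(merge(s) for s in encoder_hidden_state)
    # GRU / vanilla RNN: hidden state only
    return merge(encoder_hidden_state)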
def convert_data_to_index(data: List[str], vocab: Vocab):
    data2index = []
    for sentence in data:
        sentence = " ".join([vocab.start_token, sentence, vocab.end_token])
        data2index.append(
            [vocab.get_index(token) for token in sentence.split()])
    return data2index
def greedy_decoding_rnn(s2s: Union[S2S_basic.S2S, S2S_attention.S2S],
                        data_tensor: torch.Tensor, tgt_vocab: Vocab,
                        device: torch.device):
    # inputs: (input_length, batch_size)
    inputs = data_tensor
    batch_size = inputs.size(1)

    encoder_output, encoder_hidden_state = s2s.encoder(inputs)

    decoder_hidden_state = combine_bidir_hidden_state(s2s, encoder_hidden_state)

    decoder_input = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                                 device=device)
    decoder_input = decoder_input.expand(-1, batch_size)

    max_length = inputs.size(0) * 3

    # pred_list: List[List] (tgt_length, batch_size)
    pred_list = []

    end_token_index = tgt_vocab.get_index(tgt_vocab.end_token)

    for i in range(max_length):
        # decoder_output: (1, batch_size, vocab_size)
        # decoder_hidden_state: (num_layers * num_directions, batch_size, hidden_size)
        decoder_output, decoder_hidden_state = decode_batch(
            s2s, decoder_input, decoder_hidden_state, encoder_output)
        # pred: (1, batch_size)
        pred = torch.argmax(decoder_output, dim=2)
        if torch.all(pred == end_token_index).item():
            break
        decoder_input = pred
        pred_list.append(pred.squeeze(0).tolist())

    return convert_index_to_token(pred_list, tgt_vocab, batch_size,
                                  end_token_index)
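# decode_batch (used by translate_rnn, greedy_decoding_rnn and beam_search_rnn)
# runs a single decoder step and is defined outside this section. Below is a
# hedged sketch of the expected contract only: it assumes the attention
# decoder consumes the encoder outputs while the basic decoder conditions
# solely on its recurrent state. The actual S2S_attention/S2S_basic decoder
# signatures may differ.
def decode_batch_sketch(s2s, decoder_input, decoder_hidden_state,
                        encoder_output):
    if isinstance(s2s, S2S_attention.S2S):
        # attention decoder: attends over encoder_output at every step
        return s2s.decoder(decoder_input, decoder_hidden_state, encoder_output)
    # basic decoder: no attention over the source
    return s2s.decoder(decoder_input, decoder_hidden_state)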
def load_corpus_data(data_path, language_name, start_token, end_token,
                     mask_token, vocab_path, rebuild_vocab, unk="UNK",
                     threshold=0):
    if rebuild_vocab:
        v = Vocab(language_name, start_token, end_token, mask_token,
                  threshold=threshold)

    corpus = []

    with open(data_path) as f:
        data = f.read().strip().split("\n")

    for line in data:
        line = line.strip()
        line = " ".join([start_token, line, end_token])
        if rebuild_vocab:
            v.add_sentence(line)
        corpus.append(line)

    data2index = []

    if rebuild_vocab:
        v.add_unk(unk)
        v.save(vocab_path)
    else:
        v = Vocab.load(vocab_path)

    for line in corpus:
        data2index.append([v.get_index(token) for token in line.split()])

    return data2index, v
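# A minimal usage sketch for load_corpus_data. The file paths and special
# tokens below are hypothetical placeholders, not files that ship with this
# project.
def _demo_load_corpus_data():
    data2index, v = load_corpus_data(
        data_path="data/train.en",    # hypothetical corpus, one sentence per line
        language_name="en",
        start_token="<s>",
        end_token="</s>",
        mask_token="<mask>",
        vocab_path="data/en.vocab",   # where the Vocab is saved / loaded from
        rebuild_vocab=True)           # False would load vocab_path instead
    # data2index[i] is sentence i as vocab indices, wrapped in <s> ... </s>
    return data2index, v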
def beam_search_transformer(s2s: transformer.S2S, data_tensor: torch.Tensor,
                            tgt_vocab: Vocab, beam_size: int,
                            device: torch.device):
    # src: (1, input_length)
    src = data_tensor
    src = src.expand(beam_size, -1)
    src_mask = s2s.make_src_mask(src)
    encoder_src = s2s.encoder(src, src_mask)

    max_length = src.size(1) * 3

    # tgt: (1, 1)
    tgt = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                       device=device)
    # tgt: (beam_size, 1)
    tgt = tgt.expand(beam_size, -1)

    scores = torch.zeros(beam_size, device=device)

    complete_seqs = []
    complete_seqs_scores = []

    step = 1

    while True:
        tgt_mask = s2s.make_tgt_mask(tgt)
        # output: (beam_size, tgt_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)
        # output: (beam_size, vocab_size)
        output = output[:, -1, :]
        output = F.log_softmax(output, dim=-1)

        # sub_sentence_scores: (beam_size, vocab_size)
        sub_sentence_scores = output + scores.unsqueeze(1)

        if step == 1:
            # all beams share the same prefix, so only the first row is used
            pred_prob, pred_indices = sub_sentence_scores[0].topk(beam_size,
                                                                  dim=-1)
        else:
            # sub_sentence_scores: (beam_size * vocab_size, )
            sub_sentence_scores = sub_sentence_scores.view(-1)
            pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                               dim=-1)

        # beam_id: (beam_size, )  (torch.div replaces the deprecated floor_divide)
        beam_id = torch.div(pred_indices, len(tgt_vocab), rounding_mode="floor")
        # token_id: (beam_size, )
        token_id = pred_indices % len(tgt_vocab)

        # next_tgt: (beam_size, tgt_length + 1)
        next_tgt = torch.cat([tgt[beam_id], token_id.unsqueeze(1)], dim=1)

        if step == max_length:
            complete_seqs.extend(next_tgt.tolist())
            complete_seqs_scores.extend(pred_prob.tolist())
            break

        complete_indices = []
        for i, indices in enumerate(token_id):
            if tgt_vocab.get_token(indices.item()) == tgt_vocab.end_token:
                complete_indices.append(i)

        if len(complete_indices) > 0:
            complete_seqs.extend(next_tgt[complete_indices].tolist())
            complete_pred_indices = (beam_id[complete_indices] * len(tgt_vocab)
                                     + token_id[complete_indices])

            if step == 1:
                complete_seqs_scores.extend(
                    sub_sentence_scores[0][complete_pred_indices].tolist())
                if len(complete_indices) == beam_size:
                    break
                # mask out finished hypotheses and re-select the top beams
                sub_sentence_scores[0][complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores[0].topk(
                    beam_size, dim=-1)
            else:
                complete_seqs_scores.extend(
                    sub_sentence_scores[complete_pred_indices].tolist())
                if len(complete_indices) == beam_size:
                    break
                sub_sentence_scores[complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                                   dim=-1)

            # beam_id: (beam_size, )
            beam_id = torch.div(pred_indices, len(tgt_vocab),
                                rounding_mode="floor")
            # token_id: (beam_size, )
            token_id = pred_indices % len(tgt_vocab)
            # next_tgt: (beam_size, tgt_length + 1)
            next_tgt = torch.cat([tgt[beam_id], token_id.unsqueeze(1)], dim=1)

        step += 1
        tgt = next_tgt
        scores = pred_prob

    best_sentence_id = 0
    for i in range(len(complete_seqs_scores)):
        if complete_seqs_scores[i] > complete_seqs_scores[best_sentence_id]:
            best_sentence_id = i

    best_sentence = complete_seqs[best_sentence_id]
    # strip the start and end tokens
    best_sentence = [tgt_vocab.get_token(index)
                     for index in best_sentence[1:-1]]

    return best_sentence
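# Usage sketch for beam_search_transformer: it expects a single source
# sentence of shape (1, input_length) and tiles it across the beam
# internally. The sentence and beam size below are hypothetical.
def _demo_beam_search_transformer(s2s: transformer.S2S, src_vocab: Vocab,
                                  tgt_vocab: Vocab, device: torch.device):
    line = " ".join([src_vocab.start_token, "hello world",
                     src_vocab.end_token])
    # src: (1, input_length)
    src = torch.tensor([[src_vocab.get_index(tok) for tok in line.split()]],
                       device=device)
    with torch.no_grad():
        return beam_search_transformer(s2s, src, tgt_vocab, beam_size=5,
                                       device=device)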
def beam_search_rnn(s2s: Union[S2S_attention.S2S, S2S_basic.S2S],
                    data_tensor: torch.Tensor, tgt_vocab: Vocab,
                    beam_size: int, device: torch.device):
    # batch_size == beam_size
    # inputs: (input_length, beam_size)
    inputs = data_tensor
    inputs = inputs.expand(-1, beam_size)

    encoder_output, encoder_hidden_state = s2s.encoder(inputs)

    # decoder_input: (1, beam_size)
    decoder_input = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                                 device=device)
    decoder_input = decoder_input.expand(-1, beam_size)

    # decoder_hidden_state: (num_layers, beam_size, hidden_size)
    decoder_hidden_state = combine_bidir_hidden_state(s2s, encoder_hidden_state)

    max_length = inputs.size(0) * 3

    scores = torch.zeros(beam_size, device=device)

    complete_seqs = []
    complete_seqs_scores = []

    step = 1

    while True:
        # output: (1, beam_size, vocab_size)
        # decoder_hidden_state: (num_layers, beam_size, hidden_size)
        output, decoder_hidden_state = decode_batch(
            s2s, decoder_input[-1].unsqueeze(0), decoder_hidden_state,
            encoder_output)
        output = F.log_softmax(output, dim=-1)

        # sub_sentence_scores: (beam_size, vocab_size)
        sub_sentence_scores = scores.unsqueeze(1) + output.squeeze(0)

        if step == 1:
            pred_prob, pred_indices = sub_sentence_scores[0].topk(beam_size,
                                                                  dim=-1)
        else:
            # sub_sentence_scores: (beam_size * vocab_size, )
            sub_sentence_scores = sub_sentence_scores.view(-1)
            pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                               dim=-1)

        # beam_id: (beam_size, )
        beam_id = torch.div(pred_indices, len(tgt_vocab), rounding_mode="floor")
        # token_id: (beam_size, )
        token_id = pred_indices % len(tgt_vocab)

        # decoder_input: (step, beam_size)
        # next_decoder_input: (step + 1, beam_size)
        next_decoder_input = torch.cat(
            [decoder_input[:, beam_id], token_id.unsqueeze(0)], dim=0)

        if step == max_length:
            complete_seqs.extend(next_decoder_input.t().tolist())
            complete_seqs_scores.extend(pred_prob.tolist())
            break

        complete_indices = []
        for i, indices in enumerate(token_id):
            if tgt_vocab.get_token(indices.item()) == tgt_vocab.end_token:
                complete_indices.append(i)

        if len(complete_indices) > 0:
            complete_seqs.extend(
                next_decoder_input[:, complete_indices].t().tolist())
            complete_pred_indices = (beam_id[complete_indices] * len(tgt_vocab)
                                     + token_id[complete_indices])

            if step == 1:
                complete_seqs_scores.extend(
                    sub_sentence_scores[0][complete_pred_indices].tolist())
                if len(complete_indices) == beam_size:
                    break
                sub_sentence_scores[0][complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores[0].topk(
                    beam_size, dim=-1)
            else:
                complete_seqs_scores.extend(
                    sub_sentence_scores[complete_pred_indices].tolist())
                if len(complete_indices) == beam_size:
                    break
                sub_sentence_scores[complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                                   dim=-1)

            beam_id = torch.div(pred_indices, len(tgt_vocab),
                                rounding_mode="floor")
            token_id = pred_indices % len(tgt_vocab)
            next_decoder_input = torch.cat(
                [decoder_input[:, beam_id], token_id.unsqueeze(0)], dim=0)

        step += 1

        # reorder the recurrent state to follow the selected beams
        if isinstance(decoder_hidden_state, tuple):
            h, c = decoder_hidden_state
            h = h[:, beam_id]
            c = c[:, beam_id]
            decoder_hidden_state = (h, c)
        else:
            decoder_hidden_state = decoder_hidden_state[:, beam_id]

        decoder_input = next_decoder_input
        scores = pred_prob

    best_sentence_id = 0
    for i in range(len(complete_seqs_scores)):
        if complete_seqs_scores[i] > complete_seqs_scores[best_sentence_id]:
            best_sentence_id = i

    best_sentence = complete_seqs[best_sentence_id]
    # strip the start and end tokens
    best_sentence = [tgt_vocab.get_token(index)
                     for index in best_sentence[1:-1]]

    return best_sentence
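# Both beam searches above rank finished hypotheses by the raw sum of token
# log-probabilities, which systematically favors shorter outputs. A common
# variant (not used in this code) length-normalizes the scores before picking
# the best sentence; a sketch, where alpha is a hypothetical length penalty:
def _length_normalized_best(complete_seqs, complete_seqs_scores, alpha=0.7):
    best_id = 0
    best_score = float("-inf")
    for i, (seq, score) in enumerate(zip(complete_seqs,
                                         complete_seqs_scores)):
        # alpha = 0 recovers the raw-score selection used above
        normalized = score / (len(seq) ** alpha)
        if normalized > best_score:
            best_score, best_id = normalized, i
    return complete_seqs[best_id]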
def translate_transformer(line: str, line_number: int, s2s: transformer.S2S,
                          src_vocab: Vocab, tgt_vocab: Vocab, lang_vec: dict,
                          device: torch.device):
    line = " ".join([src_vocab.start_token, line, src_vocab.end_token])
    line = line.split()

    max_length = (len(line) - 2) * 3

    lang_token = line[1]

    # src: (input_length, )
    src = torch.tensor([src_vocab.get_index(token) for token in line],
                       device=device)
    # src: (1, input_length)
    src = src.view(1, -1)

    src_mask = s2s.make_src_mask(src)

    src = s2s.encoder.token_embedding(src) * s2s.encoder.scale
    # src: (1, input_length, d_model)
    src = s2s.encoder.pos_embedding(src)

    if lang_token.startswith("<") and lang_token.endswith(">"):
        # lang_encoding: (d_model, )
        lang_encoding = torch.tensor(lang_vec[lang_token], device=device)
        src = src + lang_encoding
    else:
        print("line {} does not add language embedding".format(line_number))

    # run the encoder layers manually so the language vector can be injected
    # after the embedding step
    for layer in s2s.encoder.layers:
        src, self_attention = layer(src, src_mask)
        del self_attention

    encoder_src = src

    tgt = None
    pred_line = [tgt_vocab.get_index(tgt_vocab.start_token)]

    for i in range(max_length):
        if tgt is None:
            tgt = torch.tensor([pred_line], device=device)

        tgt_mask = s2s.make_tgt_mask(tgt)
        # output: (1, tgt_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)
        # pred: scalar argmax at the last target position
        pred = torch.argmax(output, dim=-1)[0, -1]
        if tgt_vocab.get_token(pred.item()) == tgt_vocab.end_token:
            break
        tgt = torch.cat([tgt, pred.unsqueeze(0).unsqueeze(1)], dim=1)
        pred_line.append(pred.item())

    # drop the start token and convert indices back to tokens
    pred_line = [tgt_vocab.get_token(index) for index in pred_line[1:]]

    return pred_line
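# End-to-end usage sketch for translate_transformer. The input sentence and
# the language-vector dict are hypothetical; the first token of the raw line
# is expected to be a <...> language tag, which becomes lang_token above.
def _demo_translate_transformer(s2s: transformer.S2S, src_vocab: Vocab,
                                tgt_vocab: Vocab, lang_vec: dict,
                                device: torch.device):
    s2s.eval()
    with torch.no_grad():
        tokens = translate_transformer("<en> hello world", 0, s2s,
                                       src_vocab, tgt_vocab, lang_vec,
                                       device)
    return " ".join(tokens)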