def predict_xl(text, model: Transformer, device, is_beam_search=False):
    input_id = get_input_char_index(text)
    # input_length = torch.LongTensor([len(input_id)]).to(device)
    input_tensor = torch.LongTensor([input_id]).to(device)
    batch_size, src_len = input_tensor.shape
    trg = input_tensor.new_full((batch_size, 1), model.sos_idx)
    src_mask, trg_mask = model.make_masks(input_tensor, trg)
    if not is_beam_search:
        # Greedy decoding: encode once, then extend trg one token at a time.
        encoder_output = model.encoder(input_tensor, src_mask)
        step = 0
        result = []
        while step < 200:
            output = model.decoder(trg, encoder_output, trg_mask, src_mask)
            output = torch.argmax(output[:, -1], dim=1)
            result.append(output.item())
            if result[-1] == EOS_IDX:
                break
            output = output.unsqueeze(1)
            trg = torch.cat((trg, output), dim=1)
            src_mask, trg_mask = model.make_masks(input_tensor, trg)
            step += 1
        output_str = get_output_char(result)
        return output_str
    else:
        target = beam_search.beam_decode(input_tensor, model, beam_width=5)
        print(target)
        print(len(target[0][0]))
        # drop the leading SOS token from the best hypothesis
        output_str = get_output_char(target[0][0][1:])
        return output_str
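
# A minimal usage sketch for predict_xl (not part of the original file). The
# checkpoint path "model.pt" is a placeholder assumption; it presumes the whole
# Transformer was saved with torch.save(model, ...).
def _predict_xl_demo():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = torch.load("model.pt", map_location=device)  # assumed checkpoint path
    model.eval()
    with torch.no_grad():
        print(predict_xl("hello", model, device, is_beam_search=False))  # greedy
        print(predict_xl("hello", model, device, is_beam_search=True))   # beam search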
def main():
    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    env = Environment()
    END_TAG_IDX = env.lang.word2idx[END_TAG]
    SAY_HI = "hello"
    targ_lang = env.lang
    vocab_inp_size = len(env.lang.word2idx)
    vocab_tar_size = len(targ_lang.word2idx)
    print("vocab_inp_size", vocab_inp_size)
    print("vocab_tar_size", vocab_tar_size)
    model = Transformer(
        vocab_inp_size,
        vocab_tar_size,
        MAX_TARGET_LEN,
        d_word_vec=32,
        d_model=32,
        d_inner=32,
        n_layers=3,
        n_head=4,
        d_k=32,
        d_v=32,
        dropout=0.1,
    ).to(device)
    # baseline = Baseline(UNITS)
    history = []
    l_optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    batch = None

    def maybe_pad_sentence(s):
        return tf.keras.preprocessing.sequence.pad_sequences(
            s, maxlen=MAX_TARGET_LEN, padding='post')

    def get_returns(r: float, seq_len: int):
        return list(reversed([r * (GAMMA**t) for t in range(seq_len)]))

    def sentence_to_idxs(sentence: str):
        return [
            env.lang.word2idx[token] for token in tokenize_sentence(sentence)
        ]

    for episode in range(EPISODES):
        # Start of episode
        env.reset()
        model.eval()
        # get first state from the env
        state, _, done = env.step(SAY_HI)
        while not done:
            src_seq = [
                env.lang.word2idx[token] for token in tokenize_sentence(state)
            ]
            src_seq, src_pos = collate_fn([src_seq])
            src_seq, src_pos = src_seq.to(device), src_pos.to(device)
            enc_output, *_ = model.encoder(src_seq, src_pos)
            actions_t = []
            actions = []
            actions_idx = []
            # Sample tokens until END_TAG is produced or MAX_TARGET_LEN is hit.
            while len(actions_idx) == 0 or (actions_idx[-1] != END_TAG_IDX
                                            and len(actions) < MAX_TARGET_LEN):
                # construct new tgt_seq based on what's been output so far
                if len(actions_t) == 0:
                    tgt_seq = [env.lang.word2idx[Constants.UNK_WORD]]
                else:
                    tgt_seq = actions_idx
                tgt_seq, tgt_pos = collate_fn([tgt_seq])
                tgt_seq, tgt_pos = tgt_seq.to(device), tgt_pos.to(device)
                # dec_output dims: [1, pos, hidden]
                dec_output, *_ = model.decoder(tgt_seq, tgt_pos, src_seq,
                                               enc_output)
                # pick last step
                dec_output = dec_output[:, -1, :]
                # w_logits dims: [1, vocab_size]
                w_logits = model.tgt_word_prj(dec_output)
                # w_probs dims: [1, vocab_size]
                w_probs = torch.nn.functional.softmax(w_logits, dim=1)
                w_dist = torch.distributions.categorical.Categorical(
                    probs=w_probs)
                w_idx_t = w_dist.sample()
                w_idx = w_idx_t.cpu().numpy()[0]
                actions_t.append(w_idx_t)
                actions_idx.append(w_idx)
                actions.append(env.lang.idx2word[w_idx])
            # action is a sentence (string)
            action_str = ' '.join(actions)
            next_state, reward, done = env.step(action_str)
            # print(reward)
            # record history (to be used for gradient updates after the episode)
            history.append((state, actions_t, action_str, reward))
            state = next_state
        # End of episode

        # Update policy
        model.train()
        while len(history) >= BATCH_SIZE:
            batch = history[:BATCH_SIZE]
            state_inp_b, action_inp_b, reward_b, ret_seq_b = zip(*[[
                sentence_to_idxs(state), actions_b, reward,
                get_returns(reward, MAX_TARGET_LEN)
            ] for state, actions_b, _, reward in batch])
            action_inp_b = [torch.stack(sent) for sent in action_inp_b]
            action_inp_b = torch.stack(action_inp_b)
            ret_seq_b = np.asarray(ret_seq_b)
            # ret_mean = np.mean(ret_seq_b)
            # ret_std = np.std(ret_seq_b)
            # ret_seq_b = (ret_seq_b - ret_mean) / ret_std
            ret_seq_b = np.exp((ret_seq_b - 0.5) * 5)
            ret_seq_b = torch.tensor(ret_seq_b, dtype=torch.float32).to(device)

            loss = 0
            # loss_bl = 0
            l_optimizer.zero_grad()
            # accumulate gradients over the decoding steps below
            src_seq, src_pos = collate_fn(list(state_inp_b))
            src_seq, src_pos = src_seq.to(device), src_pos.to(device)
            enc_output_b, *_ = model.encoder(src_seq, src_pos)
            max_sentence_len = action_inp_b.shape[1]
            tgt_seq = [[Constants.BOS] for i in range(BATCH_SIZE)]
            for t in range(max_sentence_len):
                # _b stands for batch
                prev_w_idx_b, tgt_pos = collate_fn(tgt_seq)
                prev_w_idx_b, tgt_pos = prev_w_idx_b.to(device), tgt_pos.to(
                    device)
                # dec_output_b dims: [batch, pos, hidden]
                dec_output_b, *_ = model.decoder(prev_w_idx_b, tgt_pos,
                                                 src_seq, enc_output_b)
                # pick last step
                dec_output_b = dec_output_b[:, -1, :]
                # w_logits_b dims: [batch, vocab_size]
                w_logits_b = model.tgt_word_prj(dec_output_b)
                # w_probs_b dims: [batch, vocab_size]
                w_probs_b = torch.nn.functional.softmax(w_logits_b, dim=1)
                dist_b = torch.distributions.categorical.Categorical(
                    probs=w_probs_b)
                curr_w_idx_b = action_inp_b[:, t, :]
                log_probs_b = torch.transpose(
                    dist_b.log_prob(torch.transpose(curr_w_idx_b, 0, 1)), 0, 1)
                # bl_val_b = baseline(tf.cast(dec_hidden_b, 'float32'))
                # delta_b = ret_b - bl_val_b
                # cost_b = -tf.math.multiply(log_probs_b, delta_b)
                ret_b = torch.reshape(ret_seq_b[:, t],
                                      (BATCH_SIZE, 1)).to(device)
                # REINFORCE step cost: -log pi(a_t | s) * return_t
                cost_b = -torch.mul(log_probs_b, ret_b)
                loss += cost_b
                # loss_bl += -tf.math.multiply(delta_b, bl_val_b)
                prev_w_idx_b = curr_w_idx_b
                tgt_seq = np.append(tgt_seq,
                                    prev_w_idx_b.data.cpu().numpy(),
                                    axis=1).tolist()
            # average the accumulated per-step costs, then apply the gradient
            loss = loss.mean()
            loss.backward()
            # loss_bl.backward()
            l_optimizer.step()
            # bl_optimizer.step()
            # Drop the consumed transitions before the next batch.
            history = history[BATCH_SIZE:]

        if episode % max(BATCH_SIZE, 32) == 0 and batch is not None:
            print(">>>>>>>>>>>>>>>>>>>>>>>>>>")
            print("Episode # ", episode)
            print("Samples from episode with rewards > 0: ")
            good_rewards = [(s, a_str, r) for s, _, a_str, r in batch if r > 0]
            for s, a, r in random.sample(good_rewards,
                                         min(len(good_rewards), 3)):
                print("prev_state: ", s)
                print("actions: ", a)
                print("reward: ", r)
                # print("return: ", get_returns(r, MAX_TARGET_LEN))
            ret_seq_b_np = ret_seq_b.cpu().numpy()
            print("all returns: min=%f, max=%f, median=%f" %
                  (np.min(ret_seq_b_np), np.max(ret_seq_b_np),
                   np.median(ret_seq_b_np)))
            print("avg reward: ", sum(reward_b) / len(reward_b))
            print("avg loss: ", np.mean(loss.cpu().detach().numpy()))
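
# Worked example (an illustrative sketch, not part of the original file):
# get_returns spreads a terminal reward backwards with discount GAMMA, and
# main() then rescales the returns with exp((ret - 0.5) * 5), which maps 0.5
# to 1.0, amplifies returns above 0.5, and squashes those below toward zero.
def _returns_demo(gamma: float = 0.9):
    # discounted returns for a terminal reward of 1.0 over 4 steps
    returns = list(reversed([1.0 * (gamma ** t) for t in range(4)]))
    print(returns)  # [0.729, 0.81, 0.9, 1.0]
    # the exponential rescaling used in main(); note how the spread widens
    print(np.exp((np.asarray(returns) - 0.5) * 5))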
def beam_decode(src_input, model: Transformer, beam_width=3, topk=1):
    '''
    :param src_input: input char ids, shape: [batch_size, src_len]
    :param model: Seq2Seq model
    :param beam_width: beam search width
    :param topk: number of hypotheses to return per sentence
    :return: decoded index sequences, one list of topk hypotheses per batch item
    '''
    batch_size, src_len = src_input.shape
    trg = src_input.new_full((batch_size, 1), model.sos_idx)
    src_mask, trg_mask = model.make_masks(src_input, trg)
    encoder_outputs = model.encoder(src_input, src_mask)
    decode_result = []
    for batch_index in range(batch_size):
        # encoder output for the current sentence
        encoder_output_current = encoder_outputs[batch_index, :, :].unsqueeze(0)
        input_tensor = src_input[batch_index, :].unsqueeze(0)
        trg_current = trg[batch_index, :].unsqueeze(0)
        word_idx = trg_current[0, -1]
        endnodes = []
        number_required = min((topk + 1), topk - len(endnodes))
        # priority queue ordered by negative score, so the best node comes first
        nodes_queue = PriorityQueue()
        node = BeamSearchNode(trg_current, None, word_idx, 0, 1)
        # push the start node onto the queue
        nodes_queue.put((-node.eval(), node))
        q_size = 1
        # start beam search
        while True:
            if q_size > 200:
                break
            # pop the best node
            score, n = nodes_queue.get()
            trg = n.decoder_hidden
            src_mask, trg_mask = model.make_masks(input_tensor, trg)
            if n.word_index.item() == model.eos_idx and n.previous_node is not None:
                endnodes.append((score, n))
                if len(endnodes) > number_required:
                    break
                else:
                    continue
            # decode one step
            output = model.decoder(trg, encoder_output_current, trg_mask,
                                   src_mask)
            # keep the beam_width most likely next tokens
            log_prob, indexes = torch.topk(output, beam_width)
            next_nodes = []
            for new_k in range(beam_width):
                decoded_t = indexes[0][-1][new_k].view(-1)
                log_p = log_prob[0][-1][new_k].item()
                output = decoded_t.unsqueeze(1)
                trg_tmp = torch.cat((trg, output), dim=1)
                node = BeamSearchNode(trg_tmp, n, decoded_t,
                                      log_p + n.log_prob, n.length + 1)
                score = -node.eval()
                next_nodes.append((score, node))
            for i in range(len(next_nodes)):
                score, nn = next_nodes[i]
                nodes_queue.put((score, nn))
            q_size += len(next_nodes) - 1
        if len(endnodes) == 0:
            endnodes = [nodes_queue.get() for _ in range(topk)]
        utterances = []
        i = 0
        for score, n in sorted(endnodes, key=operator.itemgetter(0)):
            if i >= topk:
                break
            utterance = [n.word_index.item()]
            # trace back through the parent pointers to recover the sequence
            while n.previous_node is not None:
                n = n.previous_node
                utterance.append(n.word_index.item())
            utterance = utterance[::-1]
            utterances.append(utterance)
            i += 1
        decode_result.append(utterances)
    return decode_result
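
# beam_decode relies on a BeamSearchNode helper that is not defined in this
# file. The sketch below matches the interface used above (decoder_hidden,
# previous_node, word_index, log_prob, length, and an eval() score); the
# length-normalisation inside eval() is an assumption, not the original code.
class BeamSearchNode:
    def __init__(self, decoder_hidden, previous_node, word_index, log_prob,
                 length):
        self.decoder_hidden = decoder_hidden  # the trg prefix decoded so far
        self.previous_node = previous_node    # parent node, for backtracking
        self.word_index = word_index          # last decoded token id (tensor)
        self.log_prob = log_prob              # cumulative log-probability
        self.length = length                  # number of decoded tokens

    def eval(self):
        # length-normalised score; higher is better (callers negate it for
        # the min-ordered PriorityQueue)
        return self.log_prob / float(self.length - 1 + 1e-6)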