Code example #1
def predict_xl(text, model: Transformer, device, is_beam_search=False):
    input_id = get_input_char_index(text)
    # input_length = torch.LongTensor([len(input_id)]).to(device)
    input_tensor = torch.LongTensor([input_id]).to(device)
    batch_size, src_len = input_tensor.shape
    trg = input_tensor.new_full((batch_size, 1), model.sos_idx)
    src_mask, trg_mask = model.make_masks(input_tensor, trg)

    if not is_beam_search:
        encoder_output = model.encoder(input_tensor, src_mask)
        step = 0
        result = []
        while step < 200:
            output = model.decoder(trg, encoder_output, trg_mask, src_mask)
            output = torch.argmax(output[:, -1], dim=1)
            result.append(output.item())
            if output.item() == EOS_IDX:
                break
            output = output.unsqueeze(1)
            trg = torch.cat((trg, output), dim=1)
            src_mask, trg_mask = model.make_masks(input_tensor, trg)
            step += 1
        output_str = get_output_char(result)
        return output_str
    else:
        target = beam_search.beam_decode(input_tensor, model, beam_with=5)
        print(target)
        print(len(target[0][0]))
        # target[0][0] is the best hypothesis; drop the leading SOS token
        output_str = get_output_char(target[0][0][1:])
        return output_str
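
A minimal calling sketch for predict_xl (illustration only, not part of the project): it assumes a trained Transformer instance, the module-level helpers get_input_char_index / get_output_char / EOS_IDX, and the beam_search module are already importable; build_model() and "checkpoint.pt" are hypothetical placeholders.

# Hypothetical usage of predict_xl; build_model() and "checkpoint.pt" are placeholders,
# and the model must share the vocabulary used by get_input_char_index.
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = build_model()  # hypothetical factory that returns a Transformer
model.load_state_dict(torch.load("checkpoint.pt", map_location=device))
model.to(device)
model.eval()

with torch.no_grad():
    greedy_str = predict_xl("example input", model, device)
    beam_str = predict_xl("example input", model, device, is_beam_search=True)
print(greedy_str)
print(beam_str)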
Code example #2
File: play.py  Project: settinghead/rl-chat
def main():

    device = torch.device("cuda:0" if USE_CUDA else "cpu")

    env = Environment()

    END_TAG_IDX = env.lang.word2idx[END_TAG]

    SAY_HI = "hello"

    targ_lang = env.lang

    vocab_inp_size = len(env.lang.word2idx)
    vocab_tar_size = len(targ_lang.word2idx)

    print("vocab_inp_size", vocab_inp_size)
    print("vocab_tar_size", vocab_tar_size)

    model = Transformer(
        vocab_inp_size,
        vocab_tar_size,
        MAX_TARGET_LEN,
        d_word_vec=32,
        d_model=32,
        d_inner=32,
        n_layers=3,
        n_head=4,
        d_k=32,
        d_v=32,
        dropout=0.1,
    ).to(device)

    # baseline = Baseline(UNITS)

    history = []

    l_optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    batch = None

    def maybe_pad_sentence(s):
        return tf.keras.preprocessing.sequence.pad_sequences(
            s, maxlen=MAX_TARGET_LEN, padding='post')

    def get_returns(r: float, seq_len: int):
        return list(reversed([r * (GAMMA**t) for t in range(seq_len)]))
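        # example: with GAMMA == 0.9, get_returns(1.0, 3) gives approximately [0.81, 0.9, 1.0]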

    def sentence_to_idxs(sentence: str):
        return [
            env.lang.word2idx[token] for token in tokenize_sentence(sentence)
        ]

    for episode in range(EPISODES):

        # Start of Episode
        env.reset()
        model.eval()

        # get first state from the env
        state, _, done = env.step(SAY_HI)

        while not done:

            src_seq = [
                env.lang.word2idx[token] for token in tokenize_sentence(state)
            ]
            src_seq, src_pos = collate_fn([src_seq])
            src_seq, src_pos = src_seq.to(device), src_pos.to(device)
            enc_output, *_ = model.encoder(src_seq, src_pos)
            actions_t = []
            actions = []
            actions_idx = []

            # sample tokens until the END tag is produced or MAX_TARGET_LEN is reached
            while len(actions_idx) == 0 or (actions_idx[-1] != END_TAG_IDX
                                            and len(actions_idx) < MAX_TARGET_LEN):
                # construct a new tgt_seq from what has been output so far
                if len(actions_t) == 0:
                    tgt_seq = [env.lang.word2idx[Constants.UNK_WORD]]
                else:
                    tgt_seq = actions_idx
                tgt_seq, tgt_pos = collate_fn([tgt_seq])
                tgt_seq, tgt_pos = tgt_seq.to(device), tgt_pos.to(device)
                # dec_output dims: [1, pos, hidden]
                dec_output, *_ = model.decoder(tgt_seq, tgt_pos, src_seq, enc_output)
                # pick last step
                dec_output = dec_output[:, -1, :]
                # w_logits dims: [1, vocab_size]
                w_logits = model.tgt_word_prj(dec_output)
                # w_probs dims: [1, vocab_size]
                w_probs = torch.nn.functional.softmax(w_logits, dim=1)
                w_dist = torch.distributions.categorical.Categorical(
                    probs=w_probs)
                w_idx_t = w_dist.sample()
                w_idx = w_idx_t.cpu().numpy()[0]
                actions_t.append(w_idx_t)
                actions_idx.append(w_idx)
                actions.append(env.lang.idx2word[w_idx])

            # action is a sentence (string)
            action_str = ' '.join(actions)
            next_state, reward, done = env.step(action_str)
            # print(reward)
            history.append((state, actions_t, action_str, reward))
            state = next_state

            # record history (to be used for gradient updating after the episode is done)
        # End of Episode
        # Update policy
        model.train()
        while len(history) >= BATCH_SIZE:
            batch = history[:BATCH_SIZE]
            state_inp_b, action_inp_b, reward_b, ret_seq_b = zip(*[[
                sentence_to_idxs(state), actions_b, reward,
                get_returns(reward, MAX_TARGET_LEN)
            ] for state, actions_b, _, reward in batch])
            action_inp_b = [torch.stack(sent) for sent in action_inp_b]
            action_inp_b = torch.stack(action_inp_b)

            ret_seq_b = np.asarray(ret_seq_b)

            # ret_mean = np.mean(ret_seq_b)
            # ret_std = np.std(ret_seq_b)
            # ret_seq_b = (ret_seq_b - ret_mean) / ret_std
            ret_seq_b = np.exp((ret_seq_b - 0.5) * 5)

            ret_seq_b = torch.tensor(ret_seq_b, dtype=torch.float32).to(device)

            loss = 0
            # loss_bl=0
            l_optimizer.zero_grad()
            # accumulate gradient with GradientTape
            src_seq, src_pos = collate_fn(list(state_inp_b))
            src_seq, src_pos = src_seq.to(device), src_pos.to(device)
            enc_output_b, *_ = model.encoder(src_seq, src_pos)
            max_sentence_len = action_inp_b.shape[1]
            tgt_seq = [[Constants.BOS] for i in range(BATCH_SIZE)]
            for t in range(max_sentence_len):
                # _b stands for batch
                prev_w_idx_b, tgt_pos = collate_fn(tgt_seq)
                prev_w_idx_b, tgt_pos = prev_w_idx_b.to(device), tgt_pos.to(
                    device)
                # dec_output_b dims: [batch, pos, hidden]
                dec_output_b, *_ = model.decoder(prev_w_idx_b, tgt_pos, src_seq, enc_output_b)
                # pick last step
                dec_output_b = dec_output_b[:, -1, :]
                # w_logits_b dims: [batch, vocab_size]
                w_logits_b = model.tgt_word_prj(dec_output_b)
                # w_probs dims: [batch, vocab_size]
                w_probs_b = torch.nn.functional.softmax(w_logits_b, dim=1)

                dist_b = torch.distributions.categorical.Categorical(
                    probs=w_probs_b)
                curr_w_idx_b = action_inp_b[:, t, :]
                log_probs_b = torch.transpose(
                    dist_b.log_prob(torch.transpose(curr_w_idx_b, 0, 1)), 0, 1)

                # bl_val_b = baseline(tf.cast(dec_hidden_b, 'float32'))
                # delta_b = ret_b - bl_val_b

                # cost_b = -tf.math.multiply(log_probs_b, delta_b)
                # cost_b = -tf.math.multiply(log_probs_b, ret_b)
                ret_b = torch.reshape(ret_seq_b[:, t],
                                      (BATCH_SIZE, 1)).to(device)
                # per-token REINFORCE cost: negative log-prob weighted by the return
                cost_b = -torch.mul(log_probs_b, ret_b)

                loss += cost_b
                # loss_bl += -tf.math.multiply(delta_b, bl_val_b)

                prev_w_idx_b = curr_w_idx_b
                tgt_seq = np.append(tgt_seq,
                                    prev_w_idx_b.data.cpu().numpy(),
                                    axis=1).tolist()

            # calculate cumulative gradients

            # model_vars = encoder.variables + decoder.variables
            loss = loss.mean()
            loss.backward()
            # loss_bl.backward()

            # finally, apply gradient

            l_optimizer.step()
            # bl_optimizer.step()

            # Reset everything for the next episode
            history = history[BATCH_SIZE:]

        if episode % max(BATCH_SIZE, 32) == 0 and batch is not None:
            print(">>>>>>>>>>>>>>>>>>>>>>>>>>")
            print("Episode # ", episode)
            print("Samples from episode with rewards > 0: ")
            good_rewards = [(s, a_str, r) for s, _, a_str, r in batch if r > 0]
            for s, a, r in random.sample(good_rewards,
                                         min(len(good_rewards), 3)):
                print("prev_state: ", s)
                print("actions: ", a)
                print("reward: ", r)
                # print("return: ", get_returns(r, MAX_TARGET_LEN))
            ret_seq_b_np = ret_seq_b.cpu().numpy()
            print("all returns: min=%f, max=%f, median=%f" %
                  (np.min(ret_seq_b_np), np.max(ret_seq_b_np),
                   np.median(ret_seq_b_np)))
            print("avg reward: ", sum(reward_b) / len(reward_b))
            print("avg loss: ", np.mean(loss.cpu().detach().numpy()))
Code example #3
def beam_decode(src_input, model: Transformer, beam_with=3, topk=1):
    '''
    :param src_input: input char ids, shape [batch_size, seq_pad_length]
    :param model: the trained Transformer (Seq2Seq) model
    :param beam_with: beam search width
    :param topk: number of candidate sentences to return per input
    :return: decoded token-id sequences, topk hypotheses per batch element
    '''
    batch_size, src_len = src_input.shape
    trg = src_input.new_full((batch_size, 1), model.sos_idx)

    src_mask, trg_mask = model.make_masks(src_input, trg)
    encoder_outputs = model.encoder(src_input, src_mask)
    decode_result = []
    for batch_index in range(batch_size):
        # encoder output for the current sentence
        encoder_output_current = encoder_outputs[batch_index, :, :].unsqueeze(0)
        input_tensor = src_input[batch_index, :].unsqueeze(0)
        trg_current = trg[batch_index, :].unsqueeze(0)
        word_idx = trg_current[0, -1]

        endnodes = []
        number_required = min((topk + 1), topk - len(endnodes))
        # priority queue ordered by negative score (best node first)
        nodes_queue = PriorityQueue()
        node = BeamSearchNode(trg_current, None, word_idx, 0, 1)

        # push the root node onto the priority queue
        nodes_queue.put((-node.eval(), node))
        q_size = 1

        # start beam search
        while True:
            if q_size > 200:
                break
            # pop the best node so far
            score, n = nodes_queue.get()
            decoder_input = n.word_index
            trg = n.decoder_hidden
            src_mask, trg_mask = model.make_masks(input_tensor, trg)
            if n.word_index.item() == model.eos_idx and n.previous_node is not None:
                endnodes.append((score, n))
                if len(endnodes) > number_required:
                    break
                else:
                    continue

            # decode with the target sequence accumulated so far
            output = model.decoder(trg, encoder_output_current, trg_mask,
                                   src_mask)

            # take the beam_with most likely next tokens; log-softmax turns the logits
            # into log-probabilities so the accumulated beam scores are comparable
            log_prob, indexs = torch.topk(torch.log_softmax(output, dim=-1), beam_with)

            next_nodes = []

            for new_k in range(beam_with):
                decoded_t = indexs[0][-1][new_k].view(-1)
                log_p = log_prob[0][-1][new_k].item()
                output = decoded_t.unsqueeze(1)
                trg_tmp = torch.cat((trg, output), dim=1)
                node = BeamSearchNode(trg_tmp, n, decoded_t,
                                      log_p + n.log_prob, n.length + 1)
                score = -node.eval()
                next_nodes.append((score, node))

            for i in range(len(next_nodes)):
                score, nn = next_nodes[i]
                nodes_queue.put((score, nn))

            q_size += len(next_nodes) - 1

        if len(endnodes) == 0:
            endnodes = [nodes_queue.get() for _ in range(topk)]

        utterances = []
        i = 0
        for score, n in sorted(endnodes, key=operator.itemgetter(0)):
            if i >= topk:
                break
            utterance = []
            utterance.append(n.word_index.item())

            # backtrack through parent nodes to recover the full sequence
            while n.previous_node is not None:
                n = n.previous_node
                utterance.append(n.word_index.item())

            utterance = utterance[::-1]
            utterances.append(utterance)
            i += 1
        decode_result.append(utterances)

    # return topk decoded sequences for every element in the batch
    return decode_result
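
beam_decode relies on a BeamSearchNode class (plus queue.PriorityQueue and operator) that is not shown on this page. Below is a minimal sketch of the node type the calls above imply; the attribute names come from how the function uses the object, while the length-normalized eval() score and the __lt__ tie-breaker are assumptions, not the project's original code.

import operator
from queue import PriorityQueue


class BeamSearchNode:
    """Minimal sketch of the node beam_decode expects (assumed, not the original)."""

    def __init__(self, decoder_hidden, previous_node, word_index, log_prob, length):
        self.decoder_hidden = decoder_hidden  # target ids decoded so far, shape [1, length]
        self.previous_node = previous_node    # parent node, None for the root
        self.word_index = word_index          # last emitted token id (tensor)
        self.log_prob = log_prob              # accumulated log-probability
        self.length = length                  # number of decoded tokens

    def eval(self):
        # length-normalized score (assumption); beam_decode pushes (-eval(), node),
        # so higher-scoring hypotheses are popped from the PriorityQueue first
        return self.log_prob / float(self.length - 1 + 1e-6)

    def __lt__(self, other):
        # tie-breaker so PriorityQueue can compare entries with equal scores
        return self.length < other.length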