Example #1
def generate_answer(sentence, model, inp_lang, targ_lang, max_length_inp,
                    max_length_tar):
    inputs = [inp_lang.word2idx[i] for i in tokenize_sentence(sentence)]
    inputs = tf.keras.preprocessing.sequence.pad_sequences(
        [inputs], maxlen=max_length_inp, padding='post')
    inputs = tf.convert_to_tensor(inputs)

    result = ''

    hidden = [tf.zeros((1, units))]
    enc_out, enc_hidden = model.encoder(inputs, hidden)

    dec_hidden = enc_hidden
    dec_input = tf.expand_dims([targ_lang.word2idx['<go>']], 0)

    for t in range(max_length_tar):
        predictions, dec_hidden = model.decoder(dec_input, dec_hidden)
        predicted_id = tf.argmax(predictions[0]).numpy()

        result += ' ' + targ_lang.idx2word[predicted_id]

        if targ_lang.idx2word[predicted_id] == '<eos>':
            return result, sentence

        # the predicted ID is fed back into the model
        dec_input = tf.expand_dims([predicted_id], 0)

    return result, sentence
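A minimal call-site sketch for the function above, assuming a trained `model` and the `inp_lang`, `targ_lang`, `max_length_inp`, and `max_length_tar` objects produced by the surrounding preprocessing code (none of these are defined in this snippet):

# Hypothetical usage; every name besides generate_answer comes from the training script.
answer, question = generate_answer("how are you", model, inp_lang, targ_lang,
                                   max_length_inp, max_length_tar)
print(question, "->", answer)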
Example #2
def calc_reward(self,
                utterance1: str,
                utterance2: str,
                exclude_tokens=[Constants.EOS, Constants.PAD, Constants.BOS]):
    # calc string distance
    token_seq1 = [self.lang.word2idx[t]
                  for t in tokenize_sentence(utterance1)]
    token_seq2 = [self.lang.word2idx[t]
                  for t in tokenize_sentence(utterance2)]
    seq1 = [t for t in token_seq1 if t not in exclude_tokens]
    seq2 = [t for t in token_seq2 if t not in exclude_tokens]
    r = SequenceMatcher(None, seq1, seq2).ratio()
    # if(r > 0):
    #     print([self.lang.idx2word[idx]
    #            for idx in set(seq2).intersection(set(seq1))])
    return r
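As a standalone illustration of the similarity score this reward is based on, here is a small sketch using only Python's difflib with toy token-id sequences (no project code involved):

from difflib import SequenceMatcher

# ratio() is 2*M/T, where M is the number of matching elements and
# T is the combined length of both sequences.
seq1 = [4, 8, 15, 16]
seq2 = [4, 8, 23, 16]
print(SequenceMatcher(None, seq1, seq2).ratio())  # 0.75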
Example #3
def step(self, action: str):
    action = action.lower()
    if self.by_word:
        reward = self.calc_reward_w(action)
    else:
        # TODO: better sentence level sentiment analysis
        reward = np.mean(
            [self.calc_reward_w(w) for w in tokenize_sentence(action)])
    done = len(self.history) > CONVO_LEN
    self.history.append(action)
    state = random.sample(questions, 1)[0]
    return state, reward, done
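A minimal interaction sketch, assuming the project's `Environment` class (with `reset` and the `step` method above) is available; the replies are fixed strings only to keep the sketch short:

env = Environment()
env.reset()
state, reward, done = env.step("hello")
while not done:
    # a real agent would generate this reply from the current state
    state, reward, done = env.step("that sounds good")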
Example #4
def sentence_to_idxs(sentence: str):
    return [
        env.lang.word2idx[token] for token in tokenize_sentence(sentence)
    ]
Example #5
def main():

    device = torch.device("cuda:0" if USE_CUDA else "cpu")

    env = Environment()

    END_TAG_IDX = env.lang.word2idx[END_TAG]

    SAY_HI = "hello"

    targ_lang = env.lang

    vocab_inp_size = len(env.lang.word2idx)
    vocab_tar_size = len(targ_lang.word2idx)

    print("vocab_inp_size", vocab_inp_size)
    print("vocab_tar_size", vocab_tar_size)

    model = Transformer(
        vocab_inp_size,
        vocab_tar_size,
        MAX_TARGET_LEN,
        d_word_vec=32,
        d_model=32,
        d_inner=32,
        n_layers=3,
        n_head=4,
        d_k=32,
        d_v=32,
        dropout=0.1,
    ).to(device)

    # baseline = Baseline(UNITS)

    history = []

    l_optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    batch = None

    def maybe_pad_sentence(s):
        return tf.keras.preprocessing.sequence.pad_sequences(
            s, maxlen=MAX_TARGET_LEN, padding='post')

    def get_returns(r: float, seq_len: int):
        return list(reversed([r * (GAMMA**t) for t in range(seq_len)]))

    def sentence_to_idxs(sentence: str):
        return [
            env.lang.word2idx[token] for token in tokenize_sentence(sentence)
        ]

    for episode in range(EPISODES):

        # Start of Episode
        env.reset()
        model.eval()

        # get first state from the env
        state, _, done = env.step(SAY_HI)

        while not done:

            src_seq = [
                env.lang.word2idx[token] for token in tokenize_sentence(state)
            ]
            src_seq, src_pos = collate_fn([src_seq])
            src_seq, src_pos = src_seq.to(device), src_pos.to(device)
            enc_output, *_ = model.encoder(src_seq, src_pos)
            actions_t = []
            actions = []
            actions_idx = []

            while len(actions) == 0 or (actions_idx[-1] != END_TAG_IDX
                                        and len(actions) < MAX_TARGET_LEN):
                # construct a new tgt_seq based on what has been output so far
                if len(actions_t) == 0:
                    tgt_seq = [env.lang.word2idx[Constants.UNK_WORD]]
                else:
                    tgt_seq = actions_idx
                tgt_seq, tgt_pos = collate_fn([tgt_seq])
                tgt_seq, tgt_pos = tgt_seq.to(device), tgt_pos.to(device)
                # dec_output dims: [1, pos, hidden]
                dec_output, *_ = model.decoder(tgt_seq, tgt_pos, src_seq,
                                               enc_output)
                # pick last step
                dec_output = dec_output[:, -1, :]
                # w_logits dims: [1, vocab_size]
                w_logits = model.tgt_word_prj(dec_output)
                # w_probs dims: [1, vocab_size]
                w_probs = torch.nn.functional.softmax(w_logits, dim=1)
                w_dist = torch.distributions.categorical.Categorical(
                    probs=w_probs)
                w_idx_t = w_dist.sample()
                w_idx = w_idx_t.cpu().numpy()[0]
                actions_t.append(w_idx_t)
                actions_idx.append(w_idx)
                actions.append(env.lang.idx2word[w_idx])

            # action is a sentence (string)
            action_str = ' '.join(actions)
            next_state, reward, done = env.step(action_str)
            # print(reward)
            history.append((state, actions_t, action_str, reward))
            state = next_state

            # record history (to be used for gradient updating after the episode is done)
        # End of Episode
        # Update policy
        model.train()
        while len(history) >= BATCH_SIZE:
            batch = history[:BATCH_SIZE]
            state_inp_b, action_inp_b, reward_b, ret_seq_b = zip(*[[
                sentence_to_idxs(state), actions_b, reward,
                get_returns(reward, MAX_TARGET_LEN)
            ] for state, actions_b, _, reward in batch])
            action_inp_b = [torch.stack(sent) for sent in action_inp_b]
            action_inp_b = torch.stack(action_inp_b)

            ret_seq_b = np.asarray(ret_seq_b)

            # ret_mean = np.mean(ret_seq_b)
            # ret_std = np.std(ret_seq_b)
            # ret_seq_b = (ret_seq_b - ret_mean) / ret_std
            ret_seq_b = np.exp((ret_seq_b - 0.5) * 5)

            ret_seq_b = torch.tensor(ret_seq_b, dtype=torch.float32).to(device)

            loss = 0
            # loss_bl=0
            l_optimizer.zero_grad()
            # accumulate gradients over the whole batch (GradientTape in the original TF version)
            src_seq, src_pos = collate_fn(list(state_inp_b))
            src_seq, src_pos = src_seq.to(device), src_pos.to(device)
            enc_output_b, *_ = model.encoder(src_seq, src_pos)
            max_sentence_len = action_inp_b.shape[1]
            tgt_seq = [[Constants.BOS] for i in range(BATCH_SIZE)]
            for t in range(max_sentence_len):
                # _b stands for batch
                prev_w_idx_b, tgt_pos = collate_fn(tgt_seq)
                prev_w_idx_b, tgt_pos = prev_w_idx_b.to(device), tgt_pos.to(
                    device)
                # dec_output_b dims: [batch, pos, hidden]
                dec_output_b, *_ = \
                    model.decoder(prev_w_idx_b, tgt_pos, src_seq, enc_output_b)
                # pick last step
                dec_output_b = dec_output_b[:, -1, :]
                # w_logits_b dims: [batch, vocab_size]
                w_logits_b = model.tgt_word_prj(dec_output_b)
                # w_probs dims: [batch, vocab_size]
                w_probs_b = torch.nn.functional.softmax(w_logits_b, dim=1)

                dist_b = torch.distributions.categorical.Categorical(
                    probs=w_probs_b)
                curr_w_idx_b = action_inp_b[:, t, :]
                log_probs_b = torch.transpose(
                    dist_b.log_prob(torch.transpose(curr_w_idx_b, 0, 1)), 0, 1)

                # bl_val_b = baseline(tf.cast(dec_hidden_b, 'float32'))
                # delta_b = ret_b - bl_val_b

                # cost_b = -tf.math.multiply(log_probs_b, delta_b)
                # cost_b = -tf.math.multiply(log_probs_b, ret_b)
                ret_b = torch.reshape(ret_seq_b[:, t],
                                      (BATCH_SIZE, 1)).to(device)
                # alternatively, use torch.mul() but it is overloaded. Might need to try log_probs_b*vec.expand_as(A)
                cost_b = -torch.mul(log_probs_b, ret_b)
                #  log_probs_b*vec.expand_as(A)
                # cost_b = -torch.bmm()   #if we are doing batch multiplication

                loss += cost_b
                # loss_bl += -tf.math.multiply(delta_b, bl_val_b)

                prev_w_idx_b = curr_w_idx_b
                tgt_seq = np.append(tgt_seq,
                                    prev_w_idx_b.data.cpu().numpy(),
                                    axis=1).tolist()

            # calculate cumulative gradients

            # model_vars = encoder.variables + decoder.variables
            loss = loss.mean()
            loss.backward()
            # loss_bl.backward()

            # finally, apply gradient

            l_optimizer.step()
            # bl_optimizer.step()

            # Reset everything for the next episode
            history = history[BATCH_SIZE:]

        if episode % max(BATCH_SIZE, 32) == 0 and batch is not None:
            print(">>>>>>>>>>>>>>>>>>>>>>>>>>")
            print("Episode # ", episode)
            print("Samples from episode with rewards > 0: ")
            good_rewards = [(s, a_str, r) for s, _, a_str, r in batch if r > 0]
            for s, a, r in random.sample(good_rewards,
                                         min(len(good_rewards), 3)):
                print("prev_state: ", s)
                print("actions: ", a)
                print("reward: ", r)
                # print("return: ", get_returns(r, MAX_TARGET_LEN))
            ret_seq_b_np = ret_seq_b.cpu().numpy()
            print("all returns: min=%f, max=%f, median=%f" %
                  (np.min(ret_seq_b_np), np.max(ret_seq_b_np),
                   np.median(ret_seq_b_np)))
            print("avg reward: ", sum(reward_b) / len(reward_b))
            print("avg loss: ", np.mean(loss.cpu().detach().numpy()))
Example #6
    BATCH_SIZE = 32

    EMBEDDING_DIM = get_embedding_dim(USE_GLOVE)

    units = 128

    print("Vocab size: ", len(inp_lang.vocab), len(targ_lang.vocab))

    vocab_inp_size = len(inp_lang.word2idx)
    vocab_tar_size = len(targ_lang.word2idx)

    optimizer = tf.train.AdamOptimizer()
    EPOCHS = 1000000

    input_tensor = [[
        inp_lang.word2idx[token] for token in tokenize_sentence(question)
    ] for question in questions]
    target_tensor = [[
        targ_lang.word2idx[token] for token in tokenize_sentence(answer)
    ] for answer in answers]
    # Calculate max_length of input and output tensor
    # Here, we'll set those to the longest sentence in the dataset
    max_length_inp, max_length_tar = utils.max_length(
        input_tensor), utils.max_length(target_tensor)

    # Padding the input and output tensor to the maximum length
    input_tensor = tf.keras.preprocessing.sequence.pad_sequences(
        input_tensor, maxlen=max_length_inp, padding='post', value=EMPTY_IDX)

    target_tensor = tf.keras.preprocessing.sequence.pad_sequences(
        target_tensor, maxlen=max_length_tar, padding='post', value=EMPTY_IDX)
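A minimal sketch of the post-padding step in isolation, with toy index sequences and `EMPTY_IDX` assumed to be 0:

import tensorflow as tf

input_tensor = [[5, 9, 2], [7, 3]]  # two tokenized sentences of unequal length
padded = tf.keras.preprocessing.sequence.pad_sequences(
    input_tensor, maxlen=4, padding='post', value=0)
print(padded)
# [[5 9 2 0]
#  [7 3 0 0]]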
Example #7
def main():
    tf.enable_eager_execution()

    questions1, answers1 = data.load_conv_text()
    # questions2, answers2 = data.load_opensubtitles_text()

    questions = list(questions1)
    answers = list(answers1)

    inp_lang, targ_lang = LanguageIndex(questions), LanguageIndex(answers)

    input_tensor = [[inp_lang.word2idx[token]
                     for token in tokenize_sentence(question)] for question in questions]
    target_tensor = [[targ_lang.word2idx[token]
                      for token in tokenize_sentence(answer)] for answer in answers]
    max_length_inp, max_length_tar = max_length(
        input_tensor), max_length(target_tensor)
    input_tensor = tf.keras.preprocessing.sequence.pad_sequences(input_tensor,
                                                                 maxlen=max_length_inp,
                                                                 padding='post')
    target_tensor = tf.keras.preprocessing.sequence.pad_sequences(target_tensor,
                                                                  maxlen=max_length_tar,
                                                                  padding='post')
    BUFFER_SIZE = len(input_tensor)
    dataset = tf.data.Dataset.from_tensor_slices(
        (input_tensor, target_tensor)).shuffle(BUFFER_SIZE)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

    model: encoder_decoder.Seq2Seq = load_trained_model(
        BATCH_SIZE, EMBEDDING_DIM, UNITS, tf.train.AdamOptimizer())

    # sentimental_words = ["absolutely","abundant","accept","acclaimed","accomplishment","achievement","action","active","activist","acumen","adjust","admire","adopt","adorable","adored","adventure","affirmation","affirmative","affluent","agree","airy","alive","alliance","ally","alter","amaze","amity","animated","answer","appreciation","approve","aptitude","artistic","assertive","astonish","astounding","astute","attractive","authentic","basic","beaming","beautiful","believe","benefactor","benefit","bighearted","blessed","bliss","bloom","bountiful","bounty","brave","bright","brilliant","bubbly","bunch","burgeon","calm","care","celebrate","certain","change","character","charitable","charming","cheer","cherish","clarity","classy","clean","clever","closeness","commend","companionship","complete","comradeship","confident","connect","connected","constant","content","conviction","copious","core","coupled","courageous","creative","cuddle","cultivate","cure","curious","cute","dazzling","delight","direct","discover","distinguished","divine","donate","each","day","eager","earnest","easy","ecstasy","effervescent","efficient","effortless","electrifying","elegance","embrace","encompassing","encourage","endorse","energized","energy","enjoy","enormous","enthuse","enthusiastic","entirely","essence","established","esteem","everyday","everyone","excited","exciting","exhilarating","expand","explore","express","exquisite","exultant","faith","familiar","family","famous","feat","fit","flourish","fortunate","fortune","freedom","fresh","friendship","full","funny","gather","generous","genius","genuine","give","glad","glow","good","gorgeous","grace","graceful","gratitude","green","grin","group","grow","handsome","happy","harmony","healed","healing","healthful","healthy","heart","hearty","heavenly","helpful","here","highest","good","hold","holy","honest","honor","hug","i","affirm","i","allow","i","am","willing","i","am.","i","can","i","choose","i","create","i","follow","i","know","i","know,","without","a","doubt","i","make","i","realize","i","take","action","i","trust","idea","ideal","imaginative","increase","incredible","independent","ingenious","innate","innovate","inspire","instantaneous","instinct","intellectual","intelligence","intuitive","inventive","joined","jovial","joy","jubilation","keen","key","kind","kiss","knowledge","laugh","leader","learn","legendary","let","go","light","lively","love","loveliness","lucidity","lucrative","luminous","maintain","marvelous","master","meaningful","meditate","mend","metamorphosis","mind-blowing","miracle","mission","modify","motivate","moving","natural","nature","nourish","nourished","novel","now","nurture","nutritious","one","open","openhanded","optimistic","paradise","party","peace","perfect","phenomenon","pleasure","plenteous","plentiful","plenty","plethora","poise","polish","popular","positive","powerful","prepared","pretty","principle","productive","project","prominent","prosperous","protect","proud","purpose","quest","quick","quiet","ready","recognize","refinement","refresh","rejoice","rejuvenate","relax","reliance","rely","remarkable","renew","renowned","replenish","resolution","resound","resources","respect","restore","revere","revolutionize","rewarding","rich","robust","rousing","safe","secure","see","sensation","serenity","shift","shine","show","silence","simple","sincerity","smart","smile","smooth","solution","soul","sparkling","spirit","spirited","spiritual","splendid","spontaneous","still","stir","strong","style","success","sunny","support","sure","surprise","su
stain","synchronized","team","thankful","therapeutic","thorough","thrilled","thrive","today","together","tranquil","transform","triumph","trust","truth","unity","unusual","unwavering","upbeat","value","vary","venerate","venture","very","vibrant","victory","vigorous","vision","visualize","vital","vivacious","voyage","wealthy","welcome","well","whole","wholesome","willing","wonder","wonderful","wondrous","xanadu","yes","yippee","young","youth","youthful","zeal","zest","zing","zip"]
    
    sentimental_words = ["good", "excellent", "well"]
    targ_lang_embd = get_GloVe_embeddings(targ_lang.vocab, EMBEDDING_DIM)
    sentimental_words_embd = get_GloVe_embeddings(
        sentimental_words, EMBEDDING_DIM)
    sim_scores = np.dot(sentimental_words_embd, np.transpose(targ_lang_embd))
    print(sim_scores.shape)
    #max_prob_ids = np.argmax(sim_scores, axis=1)
    # print(max_prob_ids)
    # print(targ_lang.word2idx)
    # print(targ_lang.idx2word(max_prob_ids[1]))

    optimizer = tf.train.AdamOptimizer()

    checkpoint_dir = './training_checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, seq2seq=model)

    for episode in range(EPISODES):

        # Start of Episode
        start = time.time()
        total_loss = 0
        for (batch, (inp, targ)) in enumerate(dataset):
            with tf.GradientTape() as tape:

                hidden = tf.zeros((BATCH_SIZE, UNITS))
                enc_hidden = model.encoder(inp, hidden)
                dec_hidden = enc_hidden
                dec_input = tf.expand_dims(
                    [targ_lang.word2idx[BEGIN_TAG]] * BATCH_SIZE, 1)

                loss = 0  # loss for decoder
                pg_loss = 0  # loss for semantic

                result = ''
                for t in range(1, targ.shape[1]):
                    actions = []
                    probs = []
                    rewards = []
                    predictions, dec_hidden = model.decoder(
                        dec_input, dec_hidden)
                    '''
                    predicted_id = tf.argmax(predictions[0]).numpy()
                    if targ_lang.idx2word[predicted_id] == END_TAG:
                        print("result: ", result)
                    else:
                        result += ' ' + targ_lang.idx2word[predicted_id]
                    '''
                    # using teacher forcing
                    dec_input = tf.expand_dims(targ[:, t], 1)
                    for ps in predictions:
                        # action = tf.distributions.Categorical(ps).sample(1)[0]
                        top_k_indices = tf.nn.top_k(ps, TOP_K).indices.numpy()
                        action = np.random.choice(top_k_indices, 1)[0]
                        actions.append(action)
                        prob = ps.numpy()[action]
                        probs.append(prob)
                        reward = np.max(sim_scores[1:, action])
                        print(targ_lang.idx2word[action], reward)
                        # print(targ_lang.idx2word[action], reward)
                        rewards.append(reward)

                        # normalize reward
                        reward_mean = np.mean(rewards)
                        reward_std = np.std(rewards)
                        norm_rewards = [(r - reward_mean) /
                                        reward_std for r in rewards]

                    if targ_lang.idx2word[actions[0]] == END_TAG:
                        print("result: ", result)
                    else:
                        result += ' ' + targ_lang.idx2word[actions[0]]

                    onehot_labels = tf.keras.utils.to_categorical(
                        y=actions, num_classes=len(targ_lang.word2idx))

                    norm_rewards = tf.convert_to_tensor(
                        norm_rewards, dtype="float32")
                    # print(onehot_labels.shape)
                    # print(predictions.shape)
                    loss += model.loss_function(targ[:, t], predictions)
                    # print("------")
                    # print(loss)
                    # print(probs)
                    # pg_loss_cross = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=onehot_labels, logits=targ[:, t]))
                    pg_loss_cross = model.loss_function(
                        targ[:, t], predictions)
                    # pg_loss_cross = tf.reduce_mean(
                    #     pg_loss_cross * norm_rewards)
                    pg_loss_cross = tf.reduce_mean(
                        pg_loss_cross * rewards)
                    # print(pg_loss_cross)
                    # print("------")
                    # print(pg_loss_cross)
                    pg_loss += pg_loss_cross
                # End of Episode
                # Update policy
                batch_loss = ((loss + pg_loss) / int(targ.shape[1]))
                total_loss += batch_loss
                variables = model.encoder.variables + model.decoder.variables
                gradients = tape.gradient(batch_loss, variables)
                optimizer.apply_gradients(zip(gradients, variables))
                if batch % 10 == 0:
                    print('batch {} training loss {:.4f}'.format(
                        batch, total_loss.numpy()))

        # saving (checkpoint) the model every 100 epochs
        #if (episode + 1) % 100 == 0:
            #checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time taken for episode {}: {:.2f} sec\n'.format(
            episode, time.time() - start))
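A standalone sketch of the per-token action/reward step used above: sample a word from the top-k of the decoder's distribution, then reward it by its best similarity to a sentiment word. The toy `sim_scores` matrix stands in for the GloVe dot products computed earlier, and `TOP_K` is a placeholder:

import numpy as np
import tensorflow as tf

TOP_K = 2
ps = tf.constant([0.05, 0.6, 0.25, 0.1])      # toy decoder distribution over a 4-word vocab
sim_scores = np.array([[0.0, 0.1, 0.9, 0.3],  # toy sentiment-word x vocab similarities
                       [0.2, 0.4, 0.1, 0.8]])

top_k_indices = tf.nn.top_k(ps, TOP_K).indices.numpy()  # indices of the k most likely words
action = np.random.choice(top_k_indices, 1)[0]          # sample one of them
reward = np.max(sim_scores[1:, action])                 # best similarity, skipping row 0 as above
print(action, reward)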
Example #8
def encode_sentence(sentence, lang: LanguageIndex):
    return [
        lang.word2idx.get(w, lang.word2idx[lang._unknown_token])
        for w in tokenize_sentence(sentence)
    ]
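A hypothetical usage sketch, assuming the project's `LanguageIndex` builds its vocabulary from a list of sentences and exposes the unknown token that `encode_sentence` falls back to:

lang = LanguageIndex(["hello how are you"])
print(encode_sentence("hello stranger", lang))  # "stranger" maps to the unknown-token index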
Example #9
def main():

    env = Environment()
    # print(env.lang.word2idx)

    SAY_HI = "hello"

    targ_lang = env.lang

    vocab_inp_size = len(env.lang.word2idx)
    vocab_tar_size = len(targ_lang.word2idx)

    # GET WORD SCORES
    # sentimental_words = ["absolutely","abundant","accept","acclaimed","accomplishment","achievement","action","active","activist","acumen","adjust","admire","adopt","adorable","adored","adventure","affirmation","affirmative","affluent","agree","airy","alive","alliance","ally","alter","amaze","amity","animated","answer","appreciation","approve","aptitude","artistic","assertive","astonish","astounding","astute","attractive","authentic","basic","beaming","beautiful","believe","benefactor","benefit","bighearted","blessed","bliss","bloom","bountiful","bounty","brave","bright","brilliant","bubbly","bunch","burgeon","calm","care","celebrate","certain","change","character","charitable","charming","cheer","cherish","clarity","classy","clean","clever","closeness","commend","companionship","complete","comradeship","confident","connect","connected","constant","content","conviction","copious","core","coupled","courageous","creative","cuddle","cultivate","cure","curious","cute","dazzling","delight","direct","discover","distinguished","divine","donate","each","day","eager","earnest","easy","ecstasy","effervescent","efficient","effortless","electrifying","elegance","embrace","encompassing","encourage","endorse","energized","energy","enjoy","enormous","enthuse","enthusiastic","entirely","essence","established","esteem","everyday","everyone","excited","exciting","exhilarating","expand","explore","express","exquisite","exultant","faith","familiar","family","famous","feat","fit","flourish","fortunate","fortune","freedom","fresh","friendship","full","funny","gather","generous","genius","genuine","give","glad","glow","good","gorgeous","grace","graceful","gratitude","green","grin","group","grow","handsome","happy","harmony","healed","healing","healthful","healthy","heart","hearty","heavenly","helpful","here","highest","good","hold","holy","honest","honor","hug","i","affirm","i","allow","i","am","willing","i","am.","i","can","i","choose","i","create","i","follow","i","know","i","know,","without","a","doubt","i","make","i","realize","i","take","action","i","trust","idea","ideal","imaginative","increase","incredible","independent","ingenious","innate","innovate","inspire","instantaneous","instinct","intellectual","intelligence","intuitive","inventive","joined","jovial","joy","jubilation","keen","key","kind","kiss","knowledge","laugh","leader","learn","legendary","let","go","light","lively","love","loveliness","lucidity","lucrative","luminous","maintain","marvelous","master","meaningful","meditate","mend","metamorphosis","mind-blowing","miracle","mission","modify","motivate","moving","natural","nature","nourish","nourished","novel","now","nurture","nutritious","one","open","openhanded","optimistic","paradise","party","peace","perfect","phenomenon","pleasure","plenteous","plentiful","plenty","plethora","poise","polish","popular","positive","powerful","prepared","pretty","principle","productive","project","prominent","prosperous","protect","proud","purpose","quest","quick","quiet","ready","recognize","refinement","refresh","rejoice","rejuvenate","relax","reliance","rely","remarkable","renew","renowned","replenish","resolution","resound","resources","respect","restore","revere","revolutionize","rewarding","rich","robust","rousing","safe","secure","see","sensation","serenity","shift","shine","show","silence","simple","sincerity","smart","smile","smooth","solution","soul","sparkling","spirit","spirited","spiritual","splendid","spontaneous","still","stir","strong","style","success","sunny","support","sure","surprise","su
stain","synchronized","team","thankful","therapeutic","thorough","thrilled","thrive","today","together","tranquil","transform","triumph","trust","truth","unity","unusual","unwavering","upbeat","value","vary","venerate","venture","very","vibrant","victory","vigorous","vision","visualize","vital","vivacious","voyage","wealthy","welcome","well","whole","wholesome","willing","wonder","wonderful","wondrous","xanadu","yes","yippee","young","youth","youthful","zeal","zest","zing","zip"]
    # sentimental_words = ["good", "excellent", "well"]
    # targ_lang_embd = get_GloVe_embeddings(targ_lang.vocab, EMBEDDING_DIM)
    # sentimental_words_embd = get_GloVe_embeddings(
    #     sentimental_words, EMBEDDING_DIM)
    # sim_scores = np.dot(sentimental_words_embd, np.transpose(targ_lang_embd))
    # print(sim_scores.shape)

    l_optimizer = tf.train.RMSPropOptimizer(0.001)
    bl_optimizer = tf.train.RMSPropOptimizer(0.001)

    # LOAD PRETRAINED MODEL HERE
    # For now...
    # model = load_trained_model(
    #     BATCH_SIZE, EMBEDDING_DIM, UNITS, tf.train.AdamOptimizer())
    '''
    encoder = Encoder(vocab_inp_size, EMBEDDING_DIM,
                      UNITS, batch_sz=BATCH_SIZE, inp_lang=env.lang.vocab)
    decoder = Decoder(vocab_tar_size, EMBEDDING_DIM,
                      UNITS, batch_sz=BATCH_SIZE, targ_lang=targ_lang.vocab)
    '''

    model = Seq2Seq(vocab_inp_size,
                    vocab_tar_size,
                    EMBEDDING_DIM,
                    UNITS,
                    BATCH_SIZE,
                    inp_lang=env.lang,
                    targ_lang=targ_lang,
                    max_length_tar=MAX_TARGET_LEN,
                    use_GloVe=USE_GLOVE,
                    display_result=True,
                    use_beam_search=False)

    import os

    checkpoint_dir = './training_checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt-2")
    checkpoint = tf.train.Checkpoint(optimizer=l_optimizer, seq2seq=model)
    checkpoint.restore(checkpoint_prefix)

    encoder = model.encoder
    decoder = model.decoder

    # LOAD PRETRAINED MODEL HERE
    # For now...
    # encoder = Encoder(vocab_inp_size, EMBEDDING_DIM,
    #                   UNITS, batch_sz=BATCH_SIZE, inp_lang=env.lang.vocab)
    # decoder = Decoder(vocab_tar_size, EMBEDDING_DIM,
    #                   UNITS, batch_sz=BATCH_SIZE, targ_lang=targ_lang.vocab)

    baseline = Baseline(UNITS)

    history = []

    l_optimizer = tf.train.AdamOptimizer()
    bl_optimizer = tf.train.RMSPropOptimizer(0.01)
    batch = None
    avg_rewards = []
    avg_losses = []

    for episode in range(EPISODES):

        # Start of Episode
        env.reset()

        # get first state from the env
        state, _, done = env.step(SAY_HI)

        while not done:  # NOT REALLY USING DONE (Conv_length=1)

            # Run an episode using the TRAINED ENCODER-DECODER model #TODO: test this!!

            init_hidden = initialize_hidden_state(1, UNITS)
            state_inp = [
                env.lang.word2idx[token] for token in tokenize_sentence(state)
            ]
            enc_hidden = encoder(tf.convert_to_tensor([state_inp]),
                                 init_hidden)

            dec_hidden = enc_hidden

            w = BEGIN_TAG
            curr_w_enc = tf.expand_dims(
                [targ_lang.word2idx[tokenize_sentence(w)[0]]], 0)

            #pdb.set_trace() ######################################################################################

            outputs = []
            actions = []
            # words_score = 0
            while w != END_TAG and len(outputs) < MAX_TARGET_LEN:
                w_probs_b, dec_hidden = decoder(curr_w_enc, dec_hidden)
                w_dist = tf.distributions.Categorical(probs=w_probs_b[0])
                w_idx = w_dist.sample(1)
                #pdb.set_trace() ######################################################################################
                actions.append(w_idx)
                # w_idx = tf.argmax(w_probs[0]).numpy()
                w = targ_lang.idx2word[w_idx.numpy()[0]]
                #pdb.set_trace() ######################################################################################

                # NEW: accumulate score of words in full response
                # words_score += np.max(sim_scores[1:, w_idx.numpy()[0]])

                curr_w_enc = tf.expand_dims([targ_lang.word2idx[w]] * 1, 1)
                outputs.append(w)

            # action is a sentence (string)
            action_str = ' '.join(outputs)
            next_state, reward, done = env.step(action_str)
            #pdb.set_trace() ######################################################################################
            # Reward is sentence score + words score. For now, words score is NOT USED
            history.append((state, actions, action_str, reward))
            state = next_state

            #pdb.set_trace() ######################################################################################

            # record history (to be used for gradient updating after the episode is done)
        # End of Episode
        # Update policy
        while len(history) >= BATCH_SIZE:
            batch = history[:BATCH_SIZE]

            state_inp_b, action_encs_b, reward_b, ret_seq_b = zip(*[[
                sentence_to_idxs(state, env.lang), actions_enc_b, reward,
                get_returns(reward, MAX_TARGET_LEN)
            ] for state, actions_enc_b, _, reward in batch])

            #pdb.set_trace() ######################################################################################

            action_encs_b = list(action_encs_b)
            action_encs_b = maybe_pad_sentence(action_encs_b)
            action_encs_b = tf.expand_dims(tf.convert_to_tensor(action_encs_b),
                                           -1)

            ret_mean = np.mean(ret_seq_b)
            ret_std = np.std(ret_seq_b)
            if ret_std == 0:
                ret_seq_b = ret_seq_b - ret_mean
            else:
                ret_seq_b = (ret_seq_b - ret_mean) / ret_std

            ret_seq_b = tf.cast(tf.convert_to_tensor(ret_seq_b), 'float32')

            loss = 0
            loss_bl = 0

            with tf.GradientTape() as l_tape, tf.GradientTape() as bl_tape:
                # accumulate gradient with GradientTape
                init_hidden_b = initialize_hidden_state(BATCH_SIZE, UNITS)

                state_inp_b = maybe_pad_sentence(state_inp_b)
                state_inp_b = tf.convert_to_tensor(state_inp_b)

                enc_hidden_b = encoder(state_inp_b, init_hidden_b)
                dec_hidden_b = enc_hidden_b
                max_sentence_len = action_encs_b.numpy().shape[1]
                prev_w_idx_b = tf.expand_dims(
                    tf.cast(
                        tf.convert_to_tensor([
                            env.lang.word2idx[tokenize_sentence(BEGIN_TAG)[0]]
                        ] * BATCH_SIZE), 'float32'), -1)

                #pdb.set_trace() ######################################################################################

                for t in range(max_sentence_len):

                    bl_val_b = baseline(tf.cast(dec_hidden_b, 'float32'))
                    ret_b = tf.reshape(ret_seq_b[:, t], (BATCH_SIZE, 1))
                    delta_b = ret_b - bl_val_b
                    # print(prev_w_idx_b.shape)
                    w_probs_b, dec_hidden_b = decoder(
                        tf.cast(prev_w_idx_b, dtype='int32'), dec_hidden_b)
                    curr_w_idx_b = action_encs_b[:, t]
                    # w_probs_b = tf.nn.softmax(w_logits_b)
                    dist = tf.distributions.Categorical(probs=w_probs_b)
                    loss_bl += -tf.multiply(delta_b, bl_val_b)
                    # cost_b = -tf.multiply(
                    #     tf.transpose(dist.log_prob(
                    #         tf.transpose(curr_w_idx_b))), delta_b
                    # )

                    #pdb.set_trace() ######################################################################################
                    cost_b = -tf.multiply(
                        tf.transpose(dist.log_prob(
                            tf.transpose(curr_w_idx_b))), ret_b)
                    # print(cost_b.shape)
                    loss += cost_b

                    prev_w_idx_b = curr_w_idx_b

                    #pdb.set_trace() ######################################################################################

            # calculate cumulative gradients

            #pdb.set_trace() ######################################################################################

            model_vars = encoder.variables + decoder.variables
            grads = l_tape.gradient(loss, model_vars)

            grads_bl = bl_tape.gradient(loss_bl, baseline.variables)

            # finally, apply gradient
            l_optimizer.apply_gradients(zip(grads, model_vars))
            bl_optimizer.apply_gradients(zip(grads_bl, baseline.variables))

            # Reset everything for the next episode
            history = history[BATCH_SIZE:]

        if episode % 20 == 0 and batch is not None:
            print(">>>>>>>>>>>>>>>>>>>>>>>>>>")
            print("Episode # ", episode)
            print("Samples from episode with rewards > 0: ")
            good_rewards = [(s, a_str, r) for s, _, a_str, r in batch if r > 0]
            for s, a, r in random.sample(good_rewards,
                                         min(len(good_rewards), 3)):
                print("prev_state: ", s)
                print("action: ", a)
                print("reward: ", r)
                # print("return: ", get_returns(r, MAX_TARGET_LEN))
            print("all returns: min=%f, max=%f, median=%f" %
                  (np.min(ret_seq_b), np.max(ret_seq_b), np.median(ret_seq_b)))
            avg_reward = sum(reward_b) / len(reward_b)
            avg_rewards.append(avg_reward)
            print("avg reward: ", avg_reward)
            avg_loss = tf.reduce_mean(loss).numpy()
            print("avg loss: ", avg_loss)
            avg_losses.append(avg_loss)
            print("avg grad: ", np.mean(grads[1].numpy()))
            # print("<<<<<<<<<<<<<<<<<<<<<<<<<<")
        if episode % 200 == 0 and batch != None:
            print("Avg rewards: ", avg_rewards)
            print("Avg losses:", avg_losses)