Пример #1
0
    def process_raw_text_into_input(self, raw_text_sentences,
                                    max_conversation_length=5, debug=False):
        """Convert raw conversation text into the model's input tensors.

        The raw sentences are converted by
        ``self.pretrained_prior.process_user_input`` (using
        ``self.rl_config.max_sentence_length``). Sentences reported with
        length 0 are dropped, and only the last ``max_conversation_length``
        of the remaining sentences are kept.

        Args:
            raw_text_sentences: the conversation history as raw text, oldest
                first.
            max_conversation_length: number of trailing sentences to keep.
                Note that 0 keeps every sentence, because ``seq[-0:]`` is the
                whole sequence.
            debug: if True, print the decoded (kept) conversation history.

        Returns:
            Tuple ``(input_sentences, input_sentence_length,
            input_conversation_length)`` of LongTensors wrapped by ``to_var``;
            the last one holds a single element, the number of kept sentences.
        """
        sentences, lens = self.pretrained_prior.process_user_input(
            raw_text_sentences, self.rl_config.max_sentence_length)

        # Remove any sentences of length 0, keeping sentences and lengths paired
        kept = [(sent, length) for sent, length in zip(sentences, lens)
                if length > 0]
        sentences = [sent for sent, _ in kept]
        lens = [length for _, length in kept]

        # Trim conversation to the most recent sentences
        sentences = sentences[-max_conversation_length:]
        lens = lens[-max_conversation_length:]
        convo_length = len(sentences)

        # Convert to torch variables
        input_sentences = to_var(torch.LongTensor(sentences))
        input_sentence_length = to_var(torch.LongTensor(lens))
        input_conversation_length = to_var(torch.LongTensor([convo_length]))

        if debug:
            print('\n**Conversation history:**')
            for sent in sentences:
                print(self.vocab.decode(list(sent)))

        return (input_sentences, input_sentence_length,
                input_conversation_length)
Пример #2
0
    def init_h(self, batch_size=None, hidden=None):
        """Return the initial RNN state.

        An explicitly supplied ``hidden`` is passed through unchanged.
        Otherwise a zero tensor of shape
        ``(num_layers * num_directions, batch_size, hidden_size)`` is built,
        wrapped by ``to_var``; for an LSTM a pair ``(h, c)`` of two separate
        zero tensors of that shape is returned instead.
        """
        if hidden is not None:
            return hidden

        state_shape = (self.num_layers * self.num_directions, batch_size,
                       self.hidden_size)

        def zero_state():
            # Each call allocates a fresh tensor, so h and c never alias.
            return to_var(torch.zeros(*state_shape))

        if not self.use_lstm:
            return zero_state()
        return zero_state(), zero_state()
Пример #3
0
    def init_h(self, batch_size=None, zero=True, hidden=None):
        """Return the initial RNN state.

        An explicitly supplied ``hidden`` is passed through unchanged.
        Otherwise a zero tensor of shape
        ``(num_layers, batch_size, hidden_size)`` is built, wrapped by
        ``to_var``; for an LSTM the pair ``(h, c)`` of two separate zero
        tensors of that shape is returned.

        ``zero`` is accepted for interface compatibility but is not used by
        this implementation.
        """
        if hidden is not None:
            return hidden

        state_shape = (self.num_layers, batch_size, self.hidden_size)

        def zero_state():
            # Fresh allocation per call so h and c are independent tensors.
            return to_var(torch.zeros(*state_shape))

        if not self.use_lstm:
            # h
            return zero_state()
        # (h, c)
        return zero_state(), zero_state()
Пример #4
0
    def compute_rewards(self,
                        conversations,
                        rewards_lst,
                        reward_weights,
                        gamma=0.0):
        """Combine the requested reward signals into one weighted reward matrix.

        For each ``(name, weight)`` pair in ``zip(rewards_lst,
        reward_weights)`` (extra entries in the longer list are ignored), the
        function ``hrl_rewards.<name>`` is applied to ``conversations``. Its
        output is discounted with ``discount(rewards, gamma)``, normalized with
        ``normalizeZ``, scaled by ``float(weight)`` and summed. The mean of each
        raw reward is appended to ``self.rewards_history[name]``.

        Args:
            conversations: passed unchanged to every reward function.
            rewards_lst: names of the reward functions to combine.
            reward_weights: weights matching ``rewards_lst`` position-wise.
            gamma: discount factor forwarded to ``discount``.

        Returns:
            Tensor of shape ``[config.rl_batch_size, config.episode_len]``,
            wrapped by ``to_var``.

        Raises:
            NotImplementedError: if a requested reward name is not supported.
                All pairs are checked before any reward is computed, so
                ``self.rewards_history`` is left untouched on failure.
        """
        supported = {
            'reward_question', 'reward_you', 'reward_toxicity',
            'reward_bot_deepmoji', 'reward_user_deepmoji',
            'reward_conversation_repetition', 'reward_utterance_repetition',
            'reward_infersent_coherence', 'reward_deepmoji_coherence',
            'reward_word2vec_coherence', 'reward_bot_response_length',
            'reward_word_similarity', 'reward_USE_similarity'
        }

        # Materialize the pairs once (the inputs may be one-shot iterables)
        # and validate them all before mutating any state.
        pairs = list(zip(rewards_lst, reward_weights))
        unsupported = [r for r, _ in pairs if r not in supported]
        if unsupported:
            raise NotImplementedError(
                'Unsupported reward(s): {}'.format(unsupported))

        episode_len = self.config.episode_len
        num_convs = self.config.rl_batch_size
        combined_rewards = np.zeros((num_convs, episode_len))

        for r, w in pairs:
            reward_func = getattr(hrl_rewards, r)
            rewards = reward_func(conversations)
            discounted = discount(rewards, gamma)
            normalized = normalizeZ(discounted)
            combined_rewards += float(w) * normalized

            self.rewards_history[r].append(rewards.mean().item())

        # [num_convs, num_actions] = [rl_batch_size, episode_len]
        return to_var(torch.FloatTensor(combined_rewards))
Пример #5
0
    def extract_sentence_data(self, batch):
        """Build the tensors needed to score a batch of (state, action) pairs.

        Each conversation is ``batch['state'][i]`` with the action
        ``batch['action'][i]`` appended as an extra final row, and its sentence
        lengths are ``batch['state_lens'][i]`` followed by
        ``batch['action_lens'][i]``.

        Returns:
            Tuple ``(actions, action_lens, sentences, sent_lens,
            hred_sentences, hred_sent_lens, targets, conv_lens)``. The
            ``hred_*`` entries omit the last sentence of every conversation;
            ``sentences``/``sent_lens`` keep all of them. ``targets`` holds
            every sentence after the first, and ``conv_lens`` is each
            conversation's sentence count minus one. All tensors are built
            under ``torch.no_grad()`` and wrapped by ``to_var``;
            ``action_lens`` is returned as given in ``batch``.
        """
        def flatten(convs):
            # Concatenate the sentences of all conversations into one list.
            return [sentence for conv in convs for sentence in conv]

        with torch.no_grad():
            actions = to_var(torch.LongTensor(batch['action']))
            action_lens = batch['action_lens']

            conversations = []
            for idx, state in enumerate(batch['state']):
                last_row = np.atleast_2d(batch['action'][idx])
                conversations.append(np.concatenate((state, last_row)))

            per_conv_lens = []
            for idx, state_lens in enumerate(batch['state_lens']):
                last_len = np.atleast_1d(batch['action_lens'][idx])
                per_conv_lens.append(np.concatenate((state_lens, last_len)))

            # Targets: every sentence after the first one in a conversation
            targets = to_var(torch.LongTensor(
                flatten(conv[1:] for conv in conversations)))
            conv_lens = to_var(torch.LongTensor(
                [len(conv) - 1 for conv in conversations]))

            # Non-variational (HRED) inputs drop each conversation's last row
            hred_sent_lens = to_var(torch.LongTensor(
                np.concatenate([lens[:-1] for lens in per_conv_lens])))
            hred_sentences = to_var(torch.LongTensor(
                flatten(conv[:-1] for conv in conversations)))

            # Variational inputs keep every sentence
            sent_lens = to_var(torch.LongTensor(np.concatenate(per_conv_lens)))
            sentences = to_var(torch.LongTensor(flatten(conversations)))

            return (actions, action_lens, sentences, sent_lens, hred_sentences,
                    hred_sent_lens, targets, conv_lens)
Пример #6
0
def dynamically_assess_context_inputs(gen_response, botmoji, botsent, vocab, 
                                      config):
    """Score generated responses with the DeepMoji and InferSent encoders.

    Args:
        gen_response: tensor of generated token ids. It is reshaped with
            ``view(-1, 30)`` and moved to NumPy, so each row is taken to be
            one 30-token response. NOTE(review): 30 is hard-coded and is
            presumably the maximum sentence length — confirm.
        botmoji: object whose ``encode(str)`` gives a DeepMoji vector.
        botsent: object whose ``encode(str)`` gives an InferSent embedding.
        vocab: vocabulary providing ``decode(list_of_ids)``.
        config: must provide ``emo_output_size``.

    Returns:
        ``torch.cat((emoji_sentences, infersent_sentences), 1)``: one row
        per response, DeepMoji columns first.

    Raises:
        Any exception raised while encoding is printed and then re-raised.
    """
    gen_response = gen_response.view(-1, 30).cpu().numpy()

    # Translate tokens to words and detokenize
    decoded_response = [vocab.decode(list(g)) for g in gen_response]  # needs to be higher dimensional?
    decoded_response = [detokenize(d) for d in decoded_response]

    # Assess DeepMoji and Infersent on text
    try:
        infersent_sentences = to_var(torch.FloatTensor(
            [botsent.encode(s) for s in decoded_response]))

        # Empty responses get a uniform emotion vector instead of being
        # passed to DeepMoji.
        blank_deepmoji = [1.0 / config.emo_output_size] * config.emo_output_size
        emoji_sentences = to_var(torch.FloatTensor(
            [botmoji.encode(s) if s != '' else blank_deepmoji
             for s in decoded_response]))
    except Exception as e:
        print("Error in dynamic context inputs:")
        print(str(e))
        # Re-raise: continuing would reach the return below with unbound
        # locals and hide the real failure behind an UnboundLocalError.
        raise
    return torch.cat((emoji_sentences, infersent_sentences), 1)
Пример #7
0
    def run_seq2seq_model(self, q_net, input_conversations, sent_lens,
                          target_conversations, conv_lens):
        """Run ``q_net.model`` in RL mode on a batch of conversations.

        Input and target conversations are flattened into single lists of
        sentences. If any input is non-finite a warning is printed, but
        execution continues. Inputs are converted to LongTensors via
        ``to_var``; ``conv_lens`` is passed through as given.

        Returns:
            The first element of the model output, expected to be
            ``[num_sentences, max_sentence_len, vocab_size]``.
        """
        flat_inputs = []
        for conversation in input_conversations:
            flat_inputs.extend(conversation)
        flat_targets = []
        for conversation in target_conversations:
            flat_targets.extend(conversation)

        all_finite = all(np.all(np.isfinite(arr))
                         for arr in (flat_inputs, flat_targets, sent_lens))
        if not all_finite:
            print("Input isn't finite")

        input_tensor = to_var(torch.LongTensor(flat_inputs))
        target_tensor = to_var(torch.LongTensor(flat_targets))
        length_tensor = to_var(torch.LongTensor(sent_lens))

        model_outputs = q_net.model(input_tensor, length_tensor, conv_lens,
                                    target_tensor, rl_mode=True)
        return model_outputs[0]
Пример #8
0
    def embed(self, x):
        """word index: [batch_size] => word vectors: [batch_size, hidden_size]

        While training with ``word_drop > 0``, with probability ``word_drop``
        every index in the batch is replaced by ``UNK_ID`` before embedding.
        """
        # Same short-circuit order as a nested check: random.random() is only
        # consumed when training with a positive drop rate.
        drop_whole_batch = (self.training and self.word_drop > 0.0
                            and random.random() < self.word_drop)
        if drop_whole_batch:
            x = to_var(x.data.new([UNK_ID] * x.size(0)))
        return self.embedding(x)
Пример #9
0
    def run_model_on_sentences(self, bot, batch_tensors):
        """Return the bot model's probability for every word of each action.

        Args:
            bot: object exposing ``config.model`` and ``solver.model``.
            batch_tensors: 8-tuple ``(actions, action_lens, sentences,
                sent_lens, hred_sentences, hred_sent_lens, targets,
                conv_lens)``. Models not in ``VariationalModels`` are fed the
                ``hred_*`` inputs instead of ``sentences``/``sent_lens``.

        Returns:
            1-D tensor ``[total words]``: for each action (the last sentence of
            each conversation), truncated to its length in ``action_lens``, the
            softmax probability the model assigns to the word actually taken.
        """
        with torch.no_grad():
            (actions, action_lens, sentences, sent_lens, hred_sentences,
             hred_sent_lens, targets, conv_lens) = batch_tensors
            if bot.config.model not in VariationalModels:
                sentences = hred_sentences
                sent_lens = hred_sent_lens

            # Run model
            outputs = bot.solver.model(sentences,
                                       sent_lens,
                                       conv_lens,
                                       targets,
                                       rl_mode=True)
            logits = outputs[0]

            # The action is the last sentence of each conversation; its flat
            # row index is the running total of conv_lens minus one.
            last_rows = (torch.cumsum(conv_lens, 0) - 1).data.tolist()
            action_logits = torch.stack(
                [logits[row] for row in last_rows],
                0)  # [num_sentences, max_sent_len, vocab_size]

            # Limit by actual sentence length (remove padding) and flatten into
            # long list of words
            word_logits = torch.cat(
                [action_logits[i, :l, :] for i, l in enumerate(action_lens)],
                0)  # [total words, vocab_size]
            word_actions = torch.cat(
                [actions[i, :l] for i, l in enumerate(action_lens)],
                0)  # [total words]

            # [total_words, vocab_size]
            word_probs = torch.nn.functional.softmax(word_logits, 1)

            # squeeze(1) rather than squeeze(): a batch with a single word
            # must still yield a 1-D tensor, not a 0-d scalar.
            return word_probs.gather(
                1, word_actions.unsqueeze(1)).squeeze(1)  # [total words]
Пример #10
0
    def duplicate_context_for_beams(self, sentences, sent_lens, conv_lens,
                                    beams):
        """Tile one conversation context once per beam hypothesis.

        Args:
            sentences: 2-D tensor of context sentences, one per row.
            sent_lens: lengths of the context sentences.
            conv_lens: conversation length(s); repeated once per beam.
            beams: 2-D tensor, one candidate sentence per row.

        Returns:
            ``(sentences, sent_lens, conv_lens, targets)``. ``targets`` is,
            for every beam, the context sentences after the first followed by
            that beam, all concatenated along dim 0. With at most one context
            sentence this is simply ``beams``. For models not in
            ``VariationalModels`` the sentences are returned unchanged and
            ``sent_lens`` is repeated per beam. NOTE(review): the sentences
            themselves are not tiled there even though the lengths are;
            confirm that is what the model expects. For variational models
            each beam is appended to the full context, and
            ``rl_config.max_sentence_length`` is used as the beam's length.
        """
        num_beams = len(beams)
        conv_lens = conv_lens.repeat(num_beams)

        def append_beam(prefix, beam_id):
            # Stack one beam hypothesis as an extra row under `prefix`.
            return torch.cat([prefix, beams[beam_id, :].unsqueeze(0)], 0)

        # [beam_size * sentences, sentence_len]
        if len(sentences) <= 1:
            targets = beams
        else:
            targets = torch.cat(
                [append_beam(sentences[1:, :], b) for b in range(num_beams)],
                0)

        # HRED (non-variational models)
        if self.rl_config.model not in VariationalModels:
            return sentences, sent_lens.repeat(num_beams), conv_lens, targets

        # VHRED, VHCR
        new_sentences = torch.cat(
            [append_beam(sentences, b) for b in range(num_beams)], 0)
        new_len = to_var(torch.LongTensor([self.rl_config.max_sentence_length]))
        sent_lens = torch.cat(
            [torch.cat([sent_lens, new_len], 0) for _ in range(num_beams)])
        return new_sentences, sent_lens, conv_lens, targets
Пример #11
0
    def generate(self, context, sentence_length, n_context, 
                 extra_context_inputs=None, botmoji=None, botsent=None, 
                 vocab=None):
        """Sample ``config.n_sample_step`` responses conditioned on a context.

        A conversation latent ``z_conv`` is drawn from ``self.conv_posterior``
        over the encoded context sentences. The context encoder is then
        stepped through the context. Each generation step samples a sentence
        latent from ``self.sent_prior`` and decodes a response with
        ``self.decoder`` when ``config.sample`` is set, otherwise with the top
        beam of ``self.decoder.beam_decode``. The response is fed back into
        the context encoder.

        Args:
            context: token ids, ``[batch_size, n_context, seq_len]``.
            sentence_length: indexed as ``sentence_length[:, i]`` per context
                sentence; presumably ``[batch_size, n_context]`` — confirm.
            n_context: number of context sentences to use. Must be >= 1,
                since ``torch.stack`` on an empty list fails.
            extra_context_inputs: concatenated onto every context sentence's
                encoder state when ``config.context_input_only`` is set.
            botmoji, botsent, vocab: only used when
                ``config.context_input_only`` is set, to score each generated
                response via ``dynamically_assess_context_inputs``.
                NOTE(review): the other ``generate`` in this file takes
                ``botsent`` before ``botmoji``; check positional callers.

        Returns:
            ``torch.stack(samples, 1)``: the per-step predictions stacked on
            dim 1, presumably ``[batch_size, n_sample_step, seq_len]``.
        """
        # context: [batch_size, n_context, seq_len]
        batch_size = context.size(0)
        # n_context = context.size(1)
        samples = []

        # Run for context

        conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size]))
        # The prior path below is disabled; z_conv is drawn from the
        # posterior after the first encoding pass.
        # conv_mu_prior, conv_var_prior = self.conv_prior()
        # z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps

        # First pass: encode every context sentence to feed the posterior.
        context_inputs_list = []
        for i in range(n_context):
            # encoder_outputs: [batch_size, seq_len, hidden_size * direction]
            # encoder_hidden: [num_layers * direction, batch_size, hidden_size]
            encoder_outputs, raw_encoder_hidden = self.encoder(context[:, i, :],
                                                               sentence_length[:, i])

            # encoder_hidden: [batch_size, num_layers * direction * hidden_size]
            context_inputs_2d = raw_encoder_hidden.transpose(
                1, 0).contiguous().view(batch_size, -1)

            if self.config.context_input_only:
                context_inputs_2d = torch.cat(
                    (context_inputs_2d, extra_context_inputs), 1)
    
            context_inputs_list.append(context_inputs_2d)

        context_inputs = torch.stack(context_inputs_list, 1)
        (context_inference_outputs, 
         context_inference_hidden) = self.context_inference(
             context_inputs, to_var(torch.LongTensor([n_context] * batch_size)))
        context_inference_hidden = context_inference_hidden.transpose(
            1, 0).contiguous().view(batch_size, -1)
        # Reparameterized sample: mu + sqrt(var) * eps
        conv_mu_posterior, conv_var_posterior = self.conv_posterior(context_inference_hidden)
        z_conv = conv_mu_posterior + torch.sqrt(conv_var_posterior) * conv_eps

        context_init = self.z_conv2context(z_conv).view(
            self.config.num_layers, batch_size, self.config.context_size)

        # Second pass: re-encode each context sentence (same encoder calls as
        # the first pass) and step the context encoder on [encoding, z_conv].
        context_hidden = context_init
        for i in range(n_context):
            # encoder_outputs: [batch_size, seq_len, hidden_size * direction]
            # encoder_hidden: [num_layers * direction, batch_size, hidden_size]
            encoder_outputs, raw_encoder_hidden = self.encoder(context[:, i, :],
                                                               sentence_length[:, i])

            # encoder_hidden: [batch_size, num_layers * direction * hidden_size]
            context_inputs_2d = raw_encoder_hidden.transpose(
                1, 0).contiguous().view(batch_size, -1)

            if self.config.context_input_only:
                context_inputs_2d = torch.cat(
                    (context_inputs_2d, extra_context_inputs), 1)

            # NOTE(review): context_inputs_list is not read after this point.
            context_inputs_list.append(context_inputs_2d)
            # context_outputs: [batch_size, 1, context_hidden_size * direction]
            # context_hidden: [num_layers * direction, batch_size, context_hidden_size]
            context_outputs, context_hidden = self.context_encoder.step(
                torch.cat([context_inputs_2d, z_conv], 1), context_hidden)

        # Run for generation
        for j in range(self.config.n_sample_step):
            # context_outputs: [batch_size, context_hidden_size * direction]
            context_outputs = context_outputs.squeeze(1)

            # Sentence latent from the prior: mu + sqrt(var) * eps
            mu_prior, var_prior = self.sent_prior(context_outputs, z_conv)
            eps = to_var(torch.randn((batch_size, self.config.z_sent_size)))
            z_sent = mu_prior + torch.sqrt(var_prior) * eps

            latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)
            decoder_init = self.context2decoder(latent_context)
            decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size)

            if self.config.sample:
                prediction = self.decoder(None, decoder_init, decode=True)
                p = prediction.data.cpu().numpy()
                # NOTE(review): this yields one entry per EOS occurrence, so it
                # only lines up with the batch if every row contains exactly
                # one EOS — confirm.
                length = torch.from_numpy(np.where(p == EOS_ID)[1])
            else:
                prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)
                # prediction: [batch_size, seq_len]
                prediction = prediction[:, 0, :]
                # length: [batch_size]
                length = [l[0] for l in length]
                length = to_var(torch.LongTensor(length))

            samples.append(prediction)

            # Feed the generated response back as the next context sentence.
            encoder_outputs, raw_encoder_hidden = self.encoder(prediction,
                                                               length)

            context_inputs_2d = raw_encoder_hidden.transpose(
                1, 0).contiguous().view(batch_size, -1)

            # Dynamically assess the DeepMoji and Infersent predictions on
            # generated text
            if self.config.context_input_only:
                dynamic_context_inputs = dynamically_assess_context_inputs(
                    prediction, botmoji, botsent, vocab, self.config)
                context_inputs_2d = torch.cat(
                    (context_inputs_2d, dynamic_context_inputs), 1)

            context_outputs, context_hidden = self.context_encoder.step(
                torch.cat([context_inputs_2d, z_conv], 1), context_hidden)

        samples = torch.stack(samples, 1)
        return samples
Пример #12
0
    def get_target_q_values(self, batch, prior_rewards=None):
        """Compute per-word Bellman targets with the target Q network.

        Each target is ``r + gamma * (1 - done) * V``. ``V`` is the max over
        the vocabulary of the target network's Q values, or their log-sum-exp
        when ``config.psi_learning`` is set. The Q values come from the
        model's own action sentence in the next state, shifted by one word.
        The first-word values of a blank sentence appended to the next state
        are concatenated at the end, and the result is truncated to each
        action's length.

        The target network is run ``config.monte_carlo_count`` times and the
        element-wise minimum over the runs is returned.

        Args:
            batch: dict with ``'rewards'``, ``'done'``, ``'next_state'``,
                ``'next_state_lens'`` and ``'action_lens'``.
            prior_rewards: per-word rewards added to the targets when
                ``config.kl_control`` is set.

        Returns:
            1-D tensor ``[total words]`` of target Q values.
        """
        rewards = to_var(torch.FloatTensor(batch['rewards']))  # [batch_size]
        # NOTE(review): `1 - batch['done']` requires an array-like that
        # supports subtraction (e.g. numpy); a plain list would raise.
        not_done = to_var(torch.FloatTensor(1 - batch['done']))  # [batch_size]
        self.sampled_reward_batch_history.append(torch.sum(rewards).item())

        # Prepare inputs to target Q network. Append a blank sentence to get
        # best response at next utterance to user input. (Next state
        # includes user input).
        blank_sentence = np.zeros((1, self.config.max_sentence_length))
        next_state_convs = [
            np.concatenate((conv, blank_sentence))
            for conv in batch['next_state']
        ]
        next_state_lens = [
            np.concatenate((lens, [1])) for lens in batch['next_state_lens']
        ]
        next_targets = [conv[1:] for conv in next_state_convs]
        next_conv_lens = [len(c) - 1 for c in next_state_convs]
        if self.config.model not in VariationalModels:
            next_state_convs = [conv[:-1] for conv in next_state_convs]
            next_state_lens = np.concatenate([l[:-1] for l in next_state_lens])
        else:
            next_state_lens = np.concatenate(next_state_lens)
        next_conv_lens = to_var(torch.LongTensor(next_conv_lens))

        # Start row of each conversation in the flattened sentence dimension
        # (exclusive prefix sum). It does not change between Monte Carlo
        # samples, so compute it once.
        start_t = torch.cumsum(
            torch.cat((to_var(next_conv_lens.data.new(1).zero_()),
                       next_conv_lens[:-1])), 0)
        starts_and_lens = list(zip(start_t.data.tolist(),
                                   next_conv_lens.data.tolist()))

        mc_target_q_values = []  # one [total words] tensor per sample
        for _ in range(self.config.monte_carlo_count):
            # Run target Q network. Output is size:
            # [num_sentences, max sent len, vocab size]
            # NOTE(review): with monte_carlo_count > 1 the samples can only
            # differ if the network is stochastic (e.g. dropout active).
            all_target_q_values = self.run_seq2seq_model(
                self.target_q_net, next_state_convs, next_state_lens,
                next_targets, next_conv_lens)

            # Target indexing: last sentence is a blank to get value of next
            # response. Second last is the user response. 3rd last is models own
            # actions. Note that targets begin at the 2nd word of each sentence.
            conv_target_q_values = torch.stack([
                all_target_q_values[s + l - 3, 1:, :]
                for s, l in starts_and_lens
            ], 0)  # [num_sentences, max_sent_len - 1, vocab_size]

            # First word of the appended blank sentence: value of starting a
            # new response after the user's response. unsqueeze(1) (rather
            # than reshaping to config.rl_batch_size) works for any batch size.
            next_response_targets = torch.stack([
                all_target_q_values[s + l - 1, 0, :]
                for s, l in starts_and_lens
            ], 0).unsqueeze(1)  # [num_sentences, 1, vocab_size]
            conv_target_q_values = torch.cat(
                [conv_target_q_values, next_response_targets],
                1)  # [num_sentences, max_sent_len, vocab_size]

            # Limit target Q values by each action's length
            limit_conv_targets = [
                conv_target_q_values[i, :l, :]
                for i, l in enumerate(batch['action_lens'])
            ]

            if self.config.psi_learning:
                # Target is r + gamma * log sum_a' exp(Q_target(s', a')).
                # torch.logsumexp over the vocab keeps a [words] shape even for
                # one-word actions (a squeezed [words, 1] would become 0-d).
                target_q_values = torch.cat([
                    rewards[i] + not_done[i] * self.config.gamma
                    * torch.logsumexp(c, dim=-1)
                    for i, c in enumerate(limit_conv_targets)
                ], 0)  # [total words]
            else:
                # Target is r + gamma * max_a' Q_target(s',a'). Reward and done are
                # at the level of conversation, so add and multiply in before
                # flattening and taking max.
                word_target_q_values = torch.cat([
                    rewards[i] + not_done[i] * self.config.gamma * c
                    for i, c in enumerate(limit_conv_targets)
                ], 0)  # [total words, vocab_size]
                target_q_values, _ = word_target_q_values.max(1)

            mc_target_q_values.append(target_q_values)

        # Element-wise minimum over the Monte Carlo samples
        min_target_q_values, _ = torch.stack(mc_target_q_values, 0).min(0)

        if self.config.kl_control:
            min_target_q_values += prior_rewards

        return min_target_q_values
Пример #13
0
    def get_q_values(self, batch):
        """Compute Q values of the actions taken, plus optional KL rewards.

        States are whole conversations which each have several sentences,
        and actions are a sentence (series of words). Q values are per word.
        Target Q values are over the next word in the sentence, or, if at the
        end of the sentence, the first word in a new sentence after the user
        response.

        Args:
            batch: dict with ``'action'``, ``'action_lens'``, ``'state'`` and
                ``'state_lens'``; ``'model_averaged_probs'`` is also read when
                ``config.model_averaging`` is set.

        Returns:
            ``(q_values, prior_rewards)``. ``q_values`` is a 1-D tensor
            ``[total words]`` holding the Q value of every word actually taken.
            ``prior_rewards`` is a per-word tensor when
            ``config.model_averaging`` or ``config.kl_control`` is set, and
            ``None`` otherwise.
        """
        # 2-D: indexed as actions[i, :l] below
        actions = to_var(torch.LongTensor(batch['action']))
        action_lens = batch['action_lens']

        # Prepare inputs to Q network: each conversation is the state with the
        # action sentence appended as its last row.
        conversations = [
            np.concatenate((conv, np.atleast_2d(batch['action'][i])))
            for i, conv in enumerate(batch['state'])
        ]
        sent_lens = [
            np.concatenate((lens, np.atleast_1d(batch['action_lens'][i])))
            for i, lens in enumerate(batch['state_lens'])
        ]
        target_conversations = [conv[1:] for conv in conversations]
        conv_lens = [len(c) - 1 for c in conversations]
        if self.config.model not in VariationalModels:
            conversations = [conv[:-1] for conv in conversations]
            sent_lens = np.concatenate([l[:-1] for l in sent_lens])
        else:
            sent_lens = np.concatenate(sent_lens)
        conv_lens = to_var(torch.LongTensor(conv_lens))

        # Flat row index of the last sentence (the action) of each
        # conversation: running total of conv_lens minus one.
        last_rows = (torch.cumsum(conv_lens, 0) - 1).data.tolist()

        def select_action_words(all_logits):
            # [num_sentences, max_sent_len, vocab] -> [total words, vocab]:
            # keep each conversation's last sentence and drop its padding.
            per_conv = torch.stack([all_logits[r] for r in last_rows], 0)
            return torch.cat(
                [per_conv[i, :l, :] for i, l in enumerate(action_lens)], 0)

        # Run Q network. Will produce [num_sentences, max sent len, vocab size]
        all_q_values = self.run_seq2seq_model(self.q_net, conversations,
                                              sent_lens, target_conversations,
                                              conv_lens)

        word_q_values = select_action_words(all_q_values)  # [total words, vocab]
        word_actions = torch.cat(
            [actions[i, :l] for i, l in enumerate(action_lens)],
            0)  # [total words]
        action_index = word_actions.unsqueeze(1)

        # Extract q values corresponding to actions taken. squeeze(1) keeps a
        # 1-D result even when there is a single word.
        q_values = word_q_values.gather(1, action_index).squeeze(1)

        # Compute KL metrics
        prior_rewards = None

        with torch.no_grad():
            # Policy log-probabilities. log_softmax avoids the -inf that
            # softmax(...).log() produces when a probability underflows to 0.
            q_log_dists = F.log_softmax(word_q_values, 1)
            q_log_probs = q_log_dists.gather(1, action_index).squeeze(1)

            # Run pretrained prior network.
            # [num_sentences, max sent len, vocab size]
            all_prior_logits = self.run_seq2seq_model(self.pretrained_prior,
                                                      conversations, sent_lens,
                                                      target_conversations,
                                                      conv_lens)
            word_prior_logits = select_action_words(all_prior_logits)

            prior_dists = torch.nn.functional.softmax(word_prior_logits, 1)
            prior_log_dists = F.log_softmax(word_prior_logits, 1)

            # Element-wise KL terms, [total words, vocab_size]
            kl_div = F.kl_div(q_log_dists, prior_dists, reduction='none')

            # [total words]
            prior_log_probs = prior_log_dists.gather(1, action_index).squeeze(1)
            logp_logq = prior_log_probs - q_log_probs

            if self.config.model_averaging:
                # Convert to tensors and flatten into [num_words]
                word_model_avg = torch.cat([
                    to_var(torch.FloatTensor(m))
                    for m in batch['model_averaged_probs']
                ], 0)

                # Compute KL from model-averaged prior. Clip because KL should
                # never be negative, so because we are subtracting KL, rewards
                # should never be positive.
                prior_rewards = torch.clamp(
                    word_model_avg.log() - q_log_probs, max=0.0)

            elif self.config.kl_control and self.config.kl_calc == 'integral':
                # Note: we reward the negative KL divergence to ensure the
                # RL model stays close to the prior
                prior_rewards = -1.0 * torch.sum(kl_div, dim=1)
            elif self.config.kl_control:
                prior_rewards = logp_logq

            if self.config.kl_control:
                prior_rewards = prior_rewards * self.config.kl_weight_c
                self.kl_reward_batch_history.append(
                    torch.sum(prior_rewards).item())

            # Track all metrics
            self.kl_div_batch_history.append(torch.mean(kl_div).item())
            self.logp_batch_history.append(
                torch.mean(prior_log_probs).item())
            self.logp_logq_batch_history.append(torch.mean(logp_logq).item())

        return q_values, prior_rewards
Пример #14
0
 def init_token(self, batch_size, SOS_ID=SOS_ID):
     """Return a LongTensor holding ``batch_size`` copies of ``SOS_ID``.

     The tensor is wrapped by ``to_var``. The ``SOS_ID`` default is the
     module-level ``SOS_ID`` captured when the method is defined.
     """
     sos_column = [SOS_ID for _ in range(batch_size)]
     return to_var(torch.LongTensor(sos_column))
Пример #15
0
    def beam_decode(self,
                    init_h=None,
                    encoder_outputs=None,
                    input_valid_length=None,
                    decode=False):
        """Run beam search from the decoder's initial state.

        Args:
            init_h (Variable, FloatTensor): initial state handed to
                ``self.init_h`` (optional). It is handled below as
                ``[num_layers, batch_size, hidden_size]``; a 2-D
                ``[batch_size, hidden_size]`` state is treated as one layer.
            encoder_outputs (Variable, FloatTensor): [batch_size, source_length, hidden_size]
                (optional), passed unchanged to ``forward_step``.
            input_valid_length (Variable, LongTensor): [batch_size] (optional),
                passed unchanged to ``forward_step``.
            decode: not used by this method; kept for interface compatibility.

        Returns:
            ``(prediction, final_score, length)`` from ``Beam.backtrack()``:
            prediction [batch, k, max_unroll], final_score [batch, k],
            length [batch, k].
        """
        batch_size = self.batch_size(h=init_h)

        # [batch_size x beam_size]
        x = self.init_token(batch_size * self.beam_size, SOS_ID)

        # Expand the initial state so every beam of a batch element starts
        # from that element's state. The layout must be batch-major (all
        # beams of element 0, then all beams of element 1, ...) to match
        # batch_position below. Tensor.repeat would tile the whole batch
        # beam-major and give most positions another element's state.
        h = self.init_h(batch_size, hidden=init_h)
        if h.dim() == 2:
            h = h.unsqueeze(0)
        batch_major_idx = to_var(
            torch.arange(0, batch_size).long().view(-1, 1)
            .repeat(1, self.beam_size).view(-1))
        # [num_layers, batch_size x beam_size, hidden_size]
        h = h.index_select(1, batch_major_idx)

        # batch_position [batch_size]
        #   [0, beam_size, beam_size * 2, .., beam_size * (batch_size-1)]
        #   Points where batch starts in [batch_size x beam_size] tensors
        batch_position = to_var(
            torch.arange(0, batch_size).long() * self.beam_size)

        # Initialize scores of sequence: only the first beam of each batch
        # element is live at the start.
        # [batch_size x beam_size]
        # Ex. batch_size: 2, beam_size: 3 -> [0, -inf, -inf, 0, -inf, -inf]
        score = torch.ones(batch_size * self.beam_size) * -float('inf')
        score.index_fill_(0,
                          torch.arange(0, batch_size).long() * self.beam_size,
                          0.0)
        score = to_var(score)

        # Initialize Beam that stores decisions for backtracking
        beam = Beam(batch_size, self.hidden_size, self.vocab_size,
                    self.beam_size, self.max_unroll, batch_position)

        for _ in range(self.max_unroll):
            # x: [batch_size x beam_size] (token index)
            # => out: [batch_size x beam_size, vocab_size]
            #    h: [num_layers, batch_size x beam_size, hidden_size]
            out, h = self.forward_step(x,
                                       h,
                                       encoder_outputs=encoder_outputs,
                                       input_valid_length=input_valid_length)
            # log_prob: [batch_size x beam_size, vocab_size]
            log_prob = F.log_softmax(out, dim=1)

            # [batch_size x beam_size, vocab_size]
            score = score.view(-1, 1) + log_prob

            # Keep the top `beam_size` of the beam_size x vocab_size candidates
            # per batch element.
            # score, top_k_idx: [batch_size, beam_size];
            # each top_k_idx element is in [0, beam_size * vocab_size)
            score, top_k_idx = score.view(batch_size, -1).topk(self.beam_size,
                                                               dim=1)

            # Token id is the remainder: [batch_size x beam_size]
            x = (top_k_idx % self.vocab_size).view(-1)

            # Source beam is the quotient. Floor division is required: '/' on
            # integer tensors is true (float) division in current PyTorch, and
            # index_select below rejects float indices.
            beam_idx = top_k_idx // self.vocab_size  # [batch_size, beam_size]
            # Back-pointer into the flat [batch_size x beam_size] layout
            top_k_pointer = (beam_idx + batch_position.unsqueeze(1)).view(-1)

            # Select next h (size doesn't change)
            # [num_layers, batch_size * beam_size, hidden_size]
            h = h.index_select(1, top_k_pointer)

            # Update sequence scores at beam
            beam.update(score.clone(), top_k_pointer, x)

            # Erase scores for EOS so that they are not expanded. Filling with
            # an all-False mask is a no-op, so no emptiness check is needed.
            # [batch_size, beam_size]
            eos_idx = x.data.eq(EOS_ID).view(batch_size, self.beam_size)
            score.data.masked_fill_(eos_idx, -float('inf'))

        # prediction [batch, k, max_unroll]: predicted sequences
        # final_score [batch, k]: final scores of the top-k sequences
        # length [batch, k]: length of each top-k sequence
        prediction, final_score, length = beam.backtrack()

        return prediction, final_score, length
Пример #16
0
    def generate(self, context, sentence_length, n_context, 
                 extra_context_inputs=None, botsent=None, botmoji=None, 
                 vocab=None):
        """Feed a context through the model, then sample follow-up sentences.

        The first ``n_context`` sentences of ``context`` are encoded one at a
        time and stepped through the context encoder. Afterwards, for
        ``config.n_sample_step`` steps, the top beam-search hypothesis is
        decoded, recorded, re-encoded and fed back into the context encoder.

        Args:
            context: [batch_size, n_context, seq_len] token ids.
            sentence_length: per-sentence lengths, indexed as
                ``sentence_length[:, i]`` for the i-th context sentence.
            n_context: number of context sentences to consume.
            extra_context_inputs: concatenated onto each context-sentence
                encoding when ``config.context_input_only`` is set.
            botsent, botmoji, vocab: forwarded to
                ``dynamically_assess_context_inputs`` for generated sentences
                when ``config.context_input_only`` is set.

        Returns:
            Tensor stacking the top-beam predictions along dim 1, one entry per
            sample step.
        """
        batch_size = context.size(0)

        def flatten_hidden(hidden):
            # [num_layers * direction, batch_size, hidden_size]
            # -> [batch_size, num_layers * direction * hidden_size]
            return hidden.transpose(1, 0).contiguous().view(batch_size, -1)

        # Consume the given context sentence by sentence.
        hidden_state = None
        for turn in range(n_context):
            _, turn_hidden = self.encoder(context[:, turn, :],
                                          sentence_length[:, turn])
            step_input = flatten_hidden(turn_hidden)
            if self.config.context_input_only:
                step_input = torch.cat((step_input, extra_context_inputs), 1)
            # context_outputs: [batch_size, 1, context_hidden_size * direction]
            context_outputs, hidden_state = self.context_encoder.step(
                step_input, hidden_state)

        # Generate: decode, keep the best beam, and feed it back as context.
        generated = []
        for _ in range(self.config.n_sample_step):
            context_outputs = context_outputs.squeeze(1)
            decoder_init = self.context2decoder(context_outputs).view(
                self.decoder.num_layers, -1, self.decoder.hidden_size)

            beams, _, beam_lengths = self.decoder.beam_decode(
                init_h=decoder_init)
            # Top hypothesis of each batch item: [batch_size, seq_len]
            best = beams[:, 0, :]
            best_lengths = to_var(torch.LongTensor(
                [per_item[0] for per_item in beam_lengths]))
            generated.append(best)

            _, best_hidden = self.encoder(best, best_lengths)
            step_input = flatten_hidden(best_hidden)

            # Recompute the DeepMoji / InferSent features on generated text.
            if self.config.context_input_only:
                assessed = dynamically_assess_context_inputs(
                    best, botmoji, botsent, vocab, self.config)
                step_input = torch.cat((step_input, assessed), 1)

            context_outputs, hidden_state = self.context_encoder.step(
                step_input, hidden_state)

        return torch.stack(generated, 1)
Пример #17
0
    def generate_response_to_input(self, raw_text_sentences, 
                                   max_conversation_length=5,
                                   sample_by='priority', emojize=True,
                                   debug=True):
        """Generate a reply to a raw-text conversation with BCQ-style decoding.

        ``rl_config.beam_size`` candidate responses are built one position at
        a time. At each position, the pretrained prior's next-word
        distribution is sampled ``rl_config.bcq_n`` times per beam. The sample
        with the highest Q-network logit is written into that beam. One
        candidate is then picked by
        ``pretrained_prior.select_best_generated_response``.

        Args:
            raw_text_sentences: list of raw-text conversation turns.
            max_conversation_length: number of most recent (non-empty) turns
                kept as context.
            sample_by: selection strategy forwarded to
                ``select_best_generated_response``.
            emojize: if True, prefix the reply with emojis inferred from the
                last input turn.
            debug: if True, print the context and every candidate response.

        Returns:
            The decoded, detokenized reply string.
        """
        beam_size = self.rl_config.beam_size
        max_sentence_length = self.rl_config.max_sentence_length
        bcq_n = self.rl_config.bcq_n

        with torch.no_grad():
            (input_sentences, input_sent_lens, 
             input_conv_lens) = self.process_raw_text_into_input(
                raw_text_sentences, debug=debug,
                max_conversation_length=max_conversation_length)

            # Token buffer holding every beam; filled with ones initially.
            beams = to_var(torch.LongTensor(
                np.ones((beam_size, max_sentence_length))))

            # Create a batch with the context duplicated for each beam
            (sentences, sent_lens, 
             conv_lens, targets) = self.duplicate_context_for_beams(
                input_sentences, input_sent_lens, input_conv_lens, beams)

            # Feed the beams through both networks, choose the next word for
            # every beam, write it into the buffer, and repeat.
            for i in range(max_sentence_length):
                prior_output = self.pretrained_prior.model(
                    sentences, sent_lens, conv_lens, targets, rl_mode=True)
                all_prior_logits = prior_output[0]

                q_output = self.q_net.model(
                    sentences, sent_lens, conv_lens, targets, rl_mode=True)
                all_q_logits = q_output[0]

                # Logits for position i: [beam_size, vocab_size]. No
                # .squeeze() here: with beam_size == 1 it would drop the beam
                # dimension and break softmax(dim=1) and per-beam indexing.
                q_logits = all_q_logits[:, i, :]
                prior_logits = all_prior_logits[:, i, :]

                # Prior distribution over the next word for each beam.
                prior_dists = torch.nn.functional.softmax(prior_logits, 1)

                for b in range(beam_size):
                    # Sample bcq_n candidate words from the prior
                    # (sample((n,)) replaces the deprecated sample_n(n)).
                    dist = torch.distributions.Categorical(prior_dists[b, :])
                    sampled_idxs = dist.sample((bcq_n,))

                    # Keep the candidate the Q-network scores highest.
                    q_vals = q_logits[b, sampled_idxs]
                    best_word = sampled_idxs[torch.argmax(q_vals)]

                    beams[b, i] = best_word

                (sentences, sent_lens, 
                 conv_lens, targets) = self.duplicate_context_for_beams(
                    input_sentences, input_sent_lens, input_conv_lens, beams)

            generated_sentences = beams.cpu().numpy()

        if debug:
            print('\n**All generated responses:**')
            for gen in generated_sentences:
                print(detokenize(self.vocab.decode(list(gen))))

        gen_response = self.pretrained_prior.select_best_generated_response(
            generated_sentences, sample_by, beam_size=beam_size)

        decoded_response = detokenize(self.vocab.decode(list(gen_response)))

        if emojize:
            inferred_emojis = self.pretrained_prior.botmoji.emojize_text(
                raw_text_sentences[-1], 5, 0.07)
            decoded_response = inferred_emojis + " " + decoded_response

        return decoded_response
Пример #18
0
 def compute_bow_loss(self, target_conversations):
     """Bag-of-words auxiliary loss predicted from ``self.z_sent``.

     Every sentence of every conversation is turned into a bag-of-words
     vector of size ``config.vocab_size``. The vectors are stacked row-wise
     and scored against logits from ``bow_predict(bow_h(self.z_sent))``.
     ``self.z_sent`` is presumably set by a preceding ``forward`` call;
     confirm before calling standalone.

     Args:
         target_conversations: iterable of conversations, each an iterable
             of target sentences.

     Returns:
         The value returned by ``bag_of_words_loss``.
     """
     vocab_size = self.config.vocab_size
     bow_rows = [to_bow(sentence, vocab_size)
                 for conversation in target_conversations
                 for sentence in conversation]
     target_bow = to_var(torch.FloatTensor(np.stack(bow_rows, axis=0)))
     bow_logits = self.bow_predict(self.bow_h(self.z_sent))
     return bag_of_words_loss(bow_logits, target_bow)
Пример #19
0
    def forward(self, input_sentences, input_sentence_length,
                input_conversation_length, target_sentences, decode=False,
                extra_context_inputs=None, rl_mode=False):
        """Encode a batch of conversations and decode responses (HRED).

        Sentences of all conversations are stacked along dim 0.
        ``input_conversation_length`` gives how many consecutive rows belong
        to each conversation.

        Args:
            input_sentences: (Variable, LongTensor) [num_sentences, seq_len]
            input_sentence_length: token length of each input sentence.
            input_conversation_length: number of sentences per conversation.
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
            decode: when True and ``rl_mode`` is False, run beam search
                instead of the teacher-forced decoder call.
            extra_context_inputs: concatenated onto each sentence encoding
                when ``config.context_input_only`` is set.
            rl_mode: always use the teacher-forced decoder call.
        Return:
            ``(outputs, emoji_preds, infersent_preds)``.
            ``outputs`` is one of:
                - the decoder output (train: [batch_size, seq_len, vocab_size];
                  eval: [batch_size, seq_len]);
                - for beam search, the prediction tensor
                  [batch_size, beam_size, max_unroll].
            Each discriminator prediction is None when its config flag is off.
        """
        lengths = input_conversation_length.data.tolist()
        longest = max(lengths)

        # Sentence encoder hidden state:
        # [num_layers * direction, num_sentences, hidden_size]
        _, sentence_states = self.encoder(input_sentences,
                                          input_sentence_length)

        # One flat row per sentence:
        # [num_sentences, num_layers * direction * hidden_size]
        per_sentence = sentence_states.transpose(1, 0).contiguous().view(
            input_sentences.size(0), -1)
        if self.config.context_input_only:
            per_sentence = torch.cat((per_sentence, extra_context_inputs), 1)

        # Regroup the flat rows into padded conversations:
        # [batch_size, max_len, num_layers * direction * hidden_size]
        padded = []
        offset = 0
        for n in lengths:
            padded.append(pad(per_sentence.narrow(0, offset, n), longest))
            offset += n
        context_inputs = torch.stack(padded, 0)

        # [batch_size, max_len, context_size] -> flattened back to
        # [num_sentences, context_size] by dropping each conversation's padding
        padded_outputs, _ = self.context_encoder(context_inputs,
                                                 input_conversation_length)
        context_outputs = torch.cat([padded_outputs[row, :n, :]
                                     for row, n in enumerate(lengths)])

        # With context_input_only the discriminators must not backprop into
        # the context encoder.
        discriminator_input = (context_outputs.detach()
                               if self.config.context_input_only
                               else context_outputs)
        emoji_preds = (self.context2emoji(discriminator_input)
                       if self.config.emotion else None)
        infersent_preds = (self.context2infersent(discriminator_input)
                           if self.config.infersent else None)

        # [num_layers, batch_size, hidden_size]
        decoder_init = self.context2decoder(context_outputs).view(
            self.decoder.num_layers, -1, self.decoder.hidden_size)

        if decode and not rl_mode:
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, _, _ = self.decoder.beam_decode(init_h=decoder_init)
            return prediction, emoji_preds, infersent_preds

        decoder_outputs = self.decoder(target_sentences,
                                       init_h=decoder_init,
                                       decode=decode)
        return decoder_outputs, emoji_preds, infersent_preds
Пример #20
0
    def forward(self, sentences, sentence_length, input_conversation_length, 
                target_sentences, decode=False, extra_context_inputs=None,
                rl_mode=False):
        """VHCR forward pass: conversation latent z_conv plus per-sentence z_sent.

        Args:
            sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
                Each conversation occupies l + 1 consecutive rows: its l input
                sentences plus one extra sentence (see the ``l + 1`` narrowing
                below).
            sentence_length: per-sentence token lengths passed to the encoder.
            input_conversation_length: number of input sentences per
                conversation (batch_size entries).
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
            decode: when True and not ``rl_mode``, sample both latents from
                their priors and run beam search.
            extra_context_inputs: concatenated onto every sentence encoding
                when ``config.context_input_only`` is set.
            rl_mode: sample both latents from their priors but still call the
                teacher-forced decoder.
        Return:
            Tuple ``(decoder_outputs_or_prediction, kl_div, log_p_z, log_q_zx,
            emoji_preds, infersent_preds)``. ``kl_div`` and ``log_q_zx`` are
            None unless in training mode (``not rl_mode and not decode``).
            decoder_outputs: (Variable, FloatTensor)
                - train: [batch_size, seq_len, vocab_size]
                - eval: [batch_size, seq_len]
        """
        batch_size = input_conversation_length.size(0)
        # One extra sentence per conversation, hence the subtraction.
        num_sentences = sentences.size(0) - batch_size
        max_len = input_conversation_length.data.max().item()

        # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]
        # raw_encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]
        encoder_outputs, raw_encoder_hidden = self.encoder(sentences,
                                                           sentence_length)

        # context_inputs_2d: [num_sentences + batch_size, num_layers * direction * hidden_size]
        context_inputs_2d = raw_encoder_hidden.transpose(
            1, 0).contiguous().view(num_sentences + batch_size, -1)
        
        if self.config.context_input_only:
            context_inputs_2d = torch.cat(
                (context_inputs_2d, extra_context_inputs), 1)

        # Start row of each conversation in the flat sentence axis. Each
        # conversation spans l + 1 rows, so offsets accumulate (l + 1).
        start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                                        input_conversation_length[:-1] + 1)), 0)
        # context_inputs: [batch_size, max_len + 1, num_layers * direction * hidden_size]
        context_inputs = torch.stack([pad(context_inputs_2d.narrow(0, s, l + 1), max_len + 1)
                                      for s, l in zip(start.data.tolist(),
                                                      input_conversation_length.data.tolist())], 0)

        # Sequence shifted by one: position k holds sentence k + 1. This is
        # flattened (copied by torch.cat) as the sentence posterior's input.
        # context_inputs_inference: [batch_size, max_len, num_layers * direction * hidden_size]
        context_inputs_inference = context_inputs[:, 1:, :]
        context_inputs_inference_flat = torch.cat(
            [context_inputs_inference[i, :l, :] for i, l in enumerate(input_conversation_length.data)])

        # Unshifted sequence (drops the last padded slot), input to the
        # context encoder. Note: this is a view of context_inputs.
        # context_inputs_input: [batch_size, max_len, num_layers * direction * hidden_size]
        context_inputs_input = context_inputs[:, :-1, :]

        # Standard Gaussian prior
        # RNG draw happens here, before any branch, in every mode.
        conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size]))
        conv_mu_prior, conv_var_prior = self.conv_prior()

        if not rl_mode and not decode:
            # Training mode: z_conv and z_sent come from their posteriors.
            if self.config.sentence_drop > 0.0:
                # Randomly replace whole sentence positions with self.unk_sent.
                indices = np.where(np.random.rand(max_len) < self.config.sentence_drop)[0]
                if len(indices) > 0:
                    # NOTE(review): context_inputs_input is a view of
                    # context_inputs, so this in-place write also changes
                    # context_inputs, which is passed to context_inference
                    # just below. If the conversation posterior is meant to
                    # see the un-dropped sentences, a .clone() is needed
                    # before masking — confirm intent.
                    context_inputs_input[:, indices, :] = self.unk_sent

            # context_inference_outputs: [batch_size, max_len, num_directions * context_size]
            # context_inference_hidden: [num_layers * num_directions, batch_size, hidden_size]
            context_inference_outputs, context_inference_hidden = self.context_inference(
                context_inputs, input_conversation_length + 1)

            # context_inference_hidden: [batch_size, num_layers * num_directions * hidden_size]
            context_inference_hidden = context_inference_hidden.transpose(
                1, 0).contiguous().view(batch_size, -1)
            # Reparameterised sample from the conversation posterior.
            conv_mu_posterior, conv_var_posterior = self.conv_posterior(context_inference_hidden)
            z_conv = conv_mu_posterior + torch.sqrt(conv_var_posterior) * conv_eps
            log_q_zx_conv = normal_logpdf(z_conv, conv_mu_posterior, conv_var_posterior).sum()

            log_p_z_conv = normal_logpdf(z_conv, conv_mu_prior, conv_var_prior).sum()
            kl_div_conv = normal_kl_div(conv_mu_posterior, conv_var_posterior,
                                            conv_mu_prior, conv_var_prior).sum()

            # z_conv initialises the context encoder's hidden state.
            context_init = self.z_conv2context(z_conv).view(
                self.config.num_layers, batch_size, self.config.context_size)

            # z_conv is also appended to every time step of the context input.
            z_conv_expand = z_conv.view(z_conv.size(0), 1, z_conv.size(
                1)).expand(z_conv.size(0), max_len, z_conv.size(1))
            context_outputs, context_last_hidden = self.context_encoder(
                torch.cat([context_inputs_input, z_conv_expand], 2),
                input_conversation_length,
                hidden=context_init)

            # flatten outputs
            # context_outputs: [num_sentences, context_size]
            context_outputs = torch.cat([context_outputs[i, :l, :]
                                         for i, l in enumerate(input_conversation_length.data)])

            # z_conv_flat: [num_sentences, z_conv_size]
            z_conv_flat = torch.cat(
                [z_conv_expand[i, :l, :] for i, l in enumerate(input_conversation_length.data)])
            sent_mu_prior, sent_var_prior = self.sent_prior(context_outputs, z_conv_flat)
            eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))

            # Reparameterised sample from the sentence posterior.
            sent_mu_posterior, sent_var_posterior = self.sent_posterior(
                context_outputs, context_inputs_inference_flat, z_conv_flat)
            z_sent = sent_mu_posterior + torch.sqrt(sent_var_posterior) * eps
            log_q_zx_sent = normal_logpdf(z_sent, sent_mu_posterior, sent_var_posterior).sum()

            log_p_z_sent = normal_logpdf(z_sent, sent_mu_prior, sent_var_prior).sum()
            # kl_div: [num_sentences]
            kl_div_sent = normal_kl_div(sent_mu_posterior, sent_var_posterior,
                                        sent_mu_prior, sent_var_prior).sum()

            kl_div = kl_div_conv + kl_div_sent
            log_q_zx = log_q_zx_conv + log_q_zx_sent
            log_p_z = log_p_z_conv + log_p_z_sent
        else:
            # rl_mode / decode: both latents are sampled from their priors;
            # no posterior is computed, so kl_div and log_q_zx are None.
            z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps
            context_init = self.z_conv2context(z_conv).view(
                self.config.num_layers, batch_size, self.config.context_size)

            z_conv_expand = z_conv.view(z_conv.size(0), 1, z_conv.size(
                1)).expand(z_conv.size(0), max_len, z_conv.size(1))
            # context_outputs: [batch_size, max_len, context_size]
            context_outputs, context_last_hidden = self.context_encoder(
                torch.cat([context_inputs_input, z_conv_expand], 2),
                input_conversation_length,
                hidden=context_init)
            # flatten outputs
            # context_outputs: [num_sentences, context_size]
            context_outputs = torch.cat([context_outputs[i, :l, :]
                                         for i, l in enumerate(input_conversation_length.data)])


            # z_conv_flat: [num_sentences, z_conv_size]
            z_conv_flat = torch.cat(
                [z_conv_expand[i, :l, :] for i, l in enumerate(input_conversation_length.data)])
            sent_mu_prior, sent_var_prior = self.sent_prior(context_outputs, z_conv_flat)
            eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))

            z_sent = sent_mu_prior + torch.sqrt(sent_var_prior) * eps
            kl_div = None
            log_p_z = normal_logpdf(z_sent, sent_mu_prior, sent_var_prior).sum()
            log_p_z += normal_logpdf(z_conv, conv_mu_prior, conv_var_prior).sum()
            log_q_zx = None

        # Predict emojis using discriminator.
        emoji_preds = None
        if self.config.emotion:
            emoji_preds = self.context2emoji(context_outputs)
        
        # Predict sentence embeddings using discriminator.
        infersent_preds = None
        if self.config.infersent:
            infersent_preds = self.context2infersent(context_outputs)

        # expand z_conv to all associated sentences
        # Rebinds z_conv from [batch_size, z_conv_size] to one row per input
        # sentence: [num_sentences, z_conv_size].
        z_conv = torch.cat([z.view(1, -1).expand(m.item(), self.config.z_conv_size)
                             for z, m in zip(z_conv, input_conversation_length)])

        # latent_context: [num_sentences, context_size + z_sent_size +
        # z_conv_size]
        latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)
        decoder_init = self.context2decoder(latent_context)
        # [num_sentences, num_layers, hidden_size] -> [num_layers, num_sentences, hidden_size]
        decoder_init = decoder_init.view(-1,
                                         self.decoder.num_layers,
                                         self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        # train: [batch_size, seq_len, vocab_size]
        # eval: [batch_size, seq_len]
        if rl_mode or not decode:
            decoder_outputs = self.decoder(target_sentences,
                                            init_h=decoder_init,
                                            decode=decode)
            return (decoder_outputs, kl_div, log_p_z, log_q_zx, 
                    emoji_preds, infersent_preds)

        else:
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, final_score, length = self.decoder.beam_decode(
                init_h=decoder_init)
            return (prediction, kl_div, log_p_z, log_q_zx, 
                    emoji_preds, infersent_preds)
Пример #21
0
 def conv_prior(self):
     """Return ``(mean, variance)`` of a standard Gaussian prior.

     Both values are one-element float tensors (0.0 and 1.0) wrapped with
     ``to_var``. Callers presumably rely on broadcasting them against
     batched latent tensors.
     """
     mean = to_var(torch.FloatTensor([0.0]))
     variance = to_var(torch.FloatTensor([1.0]))
     return mean, variance
Пример #22
0
    def forward(self, sentences, sentence_length, input_conversation_length, 
                target_sentences, decode=False, extra_context_inputs=None, 
                rl_mode=False):
        """VHRED forward pass: a per-sentence latent z_sent conditions the decoder.

        Args:
            sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
                Each conversation occupies l + 1 consecutive rows: its l input
                sentences plus one extra sentence (see the ``l + 1`` narrowing
                below).
            sentence_length: per-sentence token lengths passed to the encoder.
            input_conversation_length: number of input sentences per
                conversation (batch_size entries).
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
            decode: when True and not ``rl_mode``, sample z_sent from the
                prior and run beam search.
            extra_context_inputs: concatenated onto every sentence encoding
                when ``config.context_input_only`` is set.
            rl_mode: sample z_sent from the prior but still call the
                teacher-forced decoder.
        Return:
            Tuple ``(decoder_outputs_or_prediction, kl_div, log_p_z, log_q_zx,
            emoji_preds, infersent_preds)``. ``kl_div`` and ``log_q_zx`` are
            None unless in training mode (``not rl_mode and not decode``).
            decoder_outputs: (Variable, FloatTensor)
                - train: [batch_size, seq_len, vocab_size]
                - eval: [batch_size, seq_len]
        Side effects:
            Stores the sampled latent in ``self.z_sent``.
        """
        batch_size = input_conversation_length.size(0)
        # One extra sentence per conversation, hence the subtraction.
        num_sentences = sentences.size(0) - batch_size
        max_len = input_conversation_length.data.max().item()

        # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]
        # raw_encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]
        encoder_outputs, raw_encoder_hidden = self.encoder(sentences,
                                                           sentence_length)

        # context_inputs_2d: [num_sentences + batch_size, num_layers * direction * hidden_size]
        context_inputs_2d = raw_encoder_hidden.transpose(
            1, 0).contiguous().view(num_sentences + batch_size, -1)

        if self.config.context_input_only:
            context_inputs_2d = torch.cat(
                (context_inputs_2d, extra_context_inputs), 1)

        # Start row of each conversation in the flat sentence axis. Each
        # conversation spans l + 1 rows, so offsets accumulate (l + 1).
        start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                                        input_conversation_length[:-1] + 1)), 0)

        # context_inputs: [batch_size, max_len + 1, num_layers * direction * hidden_size]
        context_inputs = torch.stack([pad(context_inputs_2d.narrow(0, s, l + 1), max_len + 1)
                                      for s, l in zip(start.data.tolist(),
                                                      input_conversation_length.data.tolist())], 0)

        # Sequence shifted by one (position k holds sentence k + 1), flattened
        # as the posterior's inference input.
        # context_inputs_inference: [batch_size, max_len, num_layers * direction * hidden_size]
        context_inputs_inference = context_inputs[:, 1:, :]
        context_inputs_inference_flat = torch.cat(
            [context_inputs_inference[i, :l, :] for i, l in enumerate(input_conversation_length.data)])

        # Unshifted sequence (last padded slot dropped), fed to the context encoder.
        # context_inputs_input: [batch_size, max_len, num_layers * direction * hidden_size]
        context_inputs_input = context_inputs[:, :-1, :]

        # context_outputs_with_targets: [batch_size, max_len, context_size]
        context_outputs_with_targets, context_last_hidden = self.context_encoder(
            context_inputs_input, input_conversation_length)
        # flatten outputs
        # context_outputs: [num_sentences, context_size]
        context_outputs = torch.cat([context_outputs_with_targets[i, :l, :]
                                     for i, l in enumerate(input_conversation_length.data)])

        # Stop gradients from flowing from discriminator if only using input
        if self.config.context_input_only:
            discriminator_input = context_outputs.detach()
        else:
            discriminator_input = context_outputs
        
        # Predict emojis using discriminator.
        emoji_preds = None
        if self.config.emotion:
            emoji_preds = self.context2emoji(discriminator_input)

        # Predict sentence embeddings using discriminator
        infersent_preds = None
        if self.config.infersent:
            infersent_preds = self.context2infersent(discriminator_input)

        mu_prior, var_prior = self.prior(context_outputs)
        # Noise is drawn before the branch, so RNG consumption is the same in
        # every mode.
        eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))
        if not rl_mode and not decode:
            # Training mode: reparameterised sample from the posterior.
            mu_posterior, var_posterior = self.posterior(
                context_outputs, context_inputs_inference_flat)
            z_sent = mu_posterior + torch.sqrt(var_posterior) * eps
            log_q_zx = normal_logpdf(z_sent, mu_posterior, var_posterior).sum()

            log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
            # kl_div: [num_sentences]
            kl_div = normal_kl_div(mu_posterior, var_posterior,
                                    mu_prior, var_prior)
            kl_div = torch.sum(kl_div)
        else:
            # rl_mode / decode: sample from the prior; no posterior terms.
            z_sent = mu_prior + torch.sqrt(var_prior) * eps
            kl_div = None
            log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
            log_q_zx = None

        # Kept on the module so later calls can read it (compute_bow_loss
        # reads self.z_sent).
        self.z_sent = z_sent
        # latent_context: [num_sentences, context_size + z_sent_size]
        latent_context = torch.cat([context_outputs, z_sent], 1)
        decoder_init = self.context2decoder(latent_context)
        # [num_sentences, num_layers, hidden_size] -> [num_layers, num_sentences, hidden_size]
        decoder_init = decoder_init.view(-1,
                                         self.decoder.num_layers,
                                         self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        # train: [batch_size, seq_len, vocab_size]
        # eval: [batch_size, seq_len]
        if rl_mode or not decode:
            decoder_outputs = self.decoder(target_sentences,
                                           init_h=decoder_init,
                                           decode=decode)

            return (decoder_outputs, kl_div, log_p_z, log_q_zx, 
                    emoji_preds, infersent_preds)

        else:
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, final_score, length = self.decoder.beam_decode(
                init_h=decoder_init)

            return (prediction, kl_div, log_p_z, log_q_zx, 
                    emoji_preds, infersent_preds)