def process_raw_text_into_input(self, raw_text_sentences,
                                max_conversation_length=5, debug=False):
    """Turn raw user utterances into the tensors the dialog models consume.

    Args:
        raw_text_sentences: utterances of the conversation so far; passed
            unchanged to ``self.pretrained_prior.process_user_input``.
        max_conversation_length: number of most recent non-empty sentences
            to keep (applied as a ``[-n:]`` slice).
        debug: when True, print the decoded conversation history.

    Returns:
        Tuple ``(input_sentences, input_sentence_length,
        input_conversation_length)`` of LongTensors wrapped with ``to_var``.
        The last one holds a single element: the number of kept sentences.
    """
    sentences, lens = self.pretrained_prior.process_user_input(
        raw_text_sentences, self.rl_config.max_sentence_length)

    # Remove any sentences of length 0. Tokens and lengths are filtered
    # together so they cannot drift out of alignment.
    kept = [(sent, n) for sent, n in zip(sentences, lens) if n > 0]

    # Trim conversation to max length (the most recent sentences are kept).
    kept = kept[-max_conversation_length:]
    sentences = [sent for sent, _ in kept]
    lens = [n for _, n in kept]

    convo_length = len(sentences)

    # Convert to torch variables
    input_sentences = to_var(torch.LongTensor(sentences))
    input_sentence_length = to_var(torch.LongTensor(lens))
    input_conversation_length = to_var(torch.LongTensor([convo_length]))

    if debug:
        print('\n**Conversation history:**')
        for sent in sentences:
            print(self.vocab.decode(list(sent)))

    return (input_sentences, input_sentence_length, input_conversation_length)
def init_h(self, batch_size=None, hidden=None):
    """Return the RNN initial state.

    A caller-supplied ``hidden`` is handed back untouched. Otherwise a zero
    state of shape [num_layers * num_directions, batch_size, hidden_size]
    is created; when ``self.use_lstm`` is set a pair of such tensors is
    returned.
    """
    if hidden is not None:
        return hidden

    def _zero_state():
        return to_var(torch.zeros(self.num_layers * self.num_directions,
                                  batch_size,
                                  self.hidden_size))

    if self.use_lstm:
        return (_zero_state(), _zero_state())
    return _zero_state()
def init_h(self, batch_size=None, zero=True, hidden=None):
    """Return the RNN initial state.

    If ``hidden`` is given it is returned as-is. Otherwise zero tensors of
    shape [num_layers, batch_size, hidden_size] are created: a pair (h, c)
    when ``self.use_lstm`` is set, a single tensor h otherwise.

    ``zero`` is accepted but not read by this implementation.
    """
    if hidden is None:
        is_lstm = self.use_lstm
        states = [
            to_var(torch.zeros(self.num_layers, batch_size, self.hidden_size))
            for _ in range(2 if is_lstm else 1)
        ]
        # LSTM needs (h, c); other cells only h
        hidden = tuple(states) if is_lstm else states[0]
    return hidden
def compute_rewards(self, conversations, rewards_lst, reward_weights, gamma=0.0):
    """Combine several reward functions into one weighted reward tensor.

    Every reward name is resolved on ``hrl_rewards`` and evaluated on
    ``conversations``. The result is passed through ``discount(., gamma)``
    and ``normalizeZ``, scaled by its weight and summed. The mean of each
    raw (un-discounted) reward is appended to ``self.rewards_history[name]``.

    Args:
        conversations: passed unchanged to every reward function.
        rewards_lst: reward function names; paired positionally with
            ``reward_weights``.
        reward_weights: one weight per reward (converted with ``float``).
        gamma: discount factor forwarded to ``discount``.

    Returns:
        ``to_var``-wrapped FloatTensor of shape
        [rl_batch_size, episode_len].

    Raises:
        NotImplementedError: if a requested reward name is not supported.
            This is raised before any reward is computed, so
            ``self.rewards_history`` is left untouched.
    """
    supported = {
        'reward_question', 'reward_you', 'reward_toxicity',
        'reward_bot_deepmoji', 'reward_user_deepmoji',
        'reward_conversation_repetition', 'reward_utterance_repetition',
        'reward_infersent_coherence', 'reward_deepmoji_coherence',
        'reward_word2vec_coherence', 'reward_bot_response_length',
        'reward_word_similarity', 'reward_USE_similarity'
    }

    weighted_rewards = list(zip(rewards_lst, reward_weights))

    # Validate up front so a bad name cannot leave rewards_history
    # half-updated.
    unsupported = [name for name, _ in weighted_rewards
                   if name not in supported]
    if unsupported:
        raise NotImplementedError(
            'Unsupported reward function(s): '
            + ', '.join(str(name) for name in unsupported))

    episode_len = self.config.episode_len
    num_convs = self.config.rl_batch_size
    combined_rewards = np.zeros((num_convs, episode_len))

    for name, weight in weighted_rewards:
        reward_func = getattr(hrl_rewards, name)
        rewards = reward_func(conversations)
        normalized = normalizeZ(discount(rewards, gamma))
        combined_rewards += float(weight) * normalized

        self.rewards_history[name].append(rewards.mean().item())

    # [num_convs, num_actions] = [rl_batch_size, episode_len]
    return to_var(torch.FloatTensor(combined_rewards))
def extract_sentence_data(self, batch):
    """Assemble the tensors needed to run a dialog model on a batch.

    Reads ``batch['state']``, ``batch['state_lens']``, ``batch['action']``
    and ``batch['action_lens']``. Each state conversation is extended with
    the action taken in it, and two input variants are derived: one that
    drops the final sentence of each conversation (``hred_*``) and one that
    keeps it.

    Returns:
        ``(actions, action_lens, sentences, sent_lens, hred_sentences,
        hred_sent_lens, targets, conv_lens)``. All but ``action_lens`` are
        LongTensors wrapped with ``to_var``.
    """
    def _as_long(values):
        return to_var(torch.LongTensor(values))

    def _flatten(convs):
        return [sent for conv in convs for sent in conv]

    with torch.no_grad():
        # Extract batch info
        actions = _as_long(batch['action'])
        action_lens = batch['action_lens']

        # State followed by the action taken in that state
        conversations = []
        for idx, state in enumerate(batch['state']):
            action = np.atleast_2d(batch['action'][idx])
            conversations.append(np.concatenate((state, action)))

        per_conv_lens = []
        for idx, state_lens in enumerate(batch['state_lens']):
            action_len = np.atleast_1d(batch['action_lens'][idx])
            per_conv_lens.append(np.concatenate((state_lens, action_len)))

        # Targets are every sentence but the first of each conversation
        targets = _as_long(_flatten([conv[1:] for conv in conversations]))
        conv_lens = _as_long([len(conv) - 1 for conv in conversations])

        # Non-variational inputs: final sentence of each conversation removed
        hred_sent_lens = _as_long(
            np.concatenate([lens[:-1] for lens in per_conv_lens]))
        hred_sentences = _as_long(
            _flatten([conv[:-1] for conv in conversations]))

        # Variational inputs: full conversations
        sent_lens = _as_long(np.concatenate(per_conv_lens))
        sentences = _as_long(_flatten(conversations))

        return (actions, action_lens, sentences, sent_lens,
                hred_sentences, hred_sent_lens, targets, conv_lens)
def dynamically_assess_context_inputs(gen_response, botmoji, botsent, vocab, config):
    """Encode generated responses with the DeepMoji and InferSent encoders.

    Args:
        gen_response: tensor of generated token ids; reshaped to rows of 30
            tokens.
        botmoji: object whose ``encode(text)`` gives the DeepMoji features.
        botsent: object whose ``encode(text)`` gives the InferSent features.
        vocab: object whose ``decode(token_list)`` maps ids back to tokens.
        config: provides ``emo_output_size``.

    Returns:
        The DeepMoji features concatenated with the InferSent features
        along dim 1.

    Raises:
        Exception: any error raised while encoding is printed and then
            re-raised.
    """
    # NOTE(review): the row width 30 is hard-coded; presumably the decoder's
    # maximum unroll length -- confirm against the model configuration.
    gen_response = gen_response.view(-1, 30).cpu().numpy()

    # Translate tokens to words and detokenize
    decoded_response = [vocab.decode(list(g)) for g in gen_response]
    decoded_response = [detokenize(d) for d in decoded_response]

    # Assess DeepMoji and Infersent on text
    try:
        infersent_sentences = to_var(torch.FloatTensor(
            [botsent.encode(s) for s in decoded_response]))

        # Empty responses get a uniform distribution over emoji classes.
        blank_deepmoji = [1.0 / config.emo_output_size] * config.emo_output_size
        emoji_sentences = to_var(torch.FloatTensor(
            [botmoji.encode(s) if s != '' else blank_deepmoji
             for s in decoded_response]))
    except Exception as e:
        print("Error in dynamic context iputs:")
        print(str(e))
        # Without re-raising, the return below would fail with an unrelated
        # UnboundLocalError and hide the real cause.
        raise

    return torch.cat((emoji_sentences, infersent_sentences), 1)
def run_seq2seq_model(self, q_net, input_conversations, sent_lens,
                      target_conversations, conv_lens):
    """Flatten a batch of conversations and run ``q_net.model`` on it.

    Conversations are flattened into one list of sentences, converted to
    LongTensors, and fed to ``q_net.model`` with ``rl_mode=True``. A warning
    is printed (nothing is raised) if any input is non-finite.

    Returns:
        The first element of the model output:
        [num_sentences, max_sentence_len, vocab_size].
    """
    # Prepare the batch
    flat_inputs = [sent for conv in input_conversations for sent in conv]
    flat_targets = [sent for conv in target_conversations for sent in conv]

    inputs_are_finite = all(
        np.all(np.isfinite(values))
        for values in (flat_inputs, flat_targets, sent_lens))
    if not inputs_are_finite:
        print("Input isn't finite")

    sentence_tensor = to_var(torch.LongTensor(flat_inputs))
    target_tensor = to_var(torch.LongTensor(flat_targets))
    length_tensor = to_var(torch.LongTensor(sent_lens))

    # Run Q network
    model_outputs = q_net.model(sentence_tensor, length_tensor, conv_lens,
                                target_tensor, rl_mode=True)
    return model_outputs[0]
def embed(self, x):
    """word index: [batch_size] => word vectors: [batch_size, hidden_size]

    While training with ``word_drop > 0``, with probability ``word_drop``
    the whole batch of indices is replaced by ``UNK_ID`` before the
    embedding lookup.
    """
    drop_words = (self.training
                  and self.word_drop > 0.0
                  and random.random() < self.word_drop)
    if not drop_words:
        return self.embedding(x)

    unk_indices = to_var(x.data.new([UNK_ID] * x.size(0)))
    return self.embedding(unk_indices)
def run_model_on_sentences(self, bot, batch_tensors):
    """Return the probability ``bot`` assigns to each action word.

    Args:
        bot: object exposing ``config.model`` and ``solver.model``.
        batch_tensors: the 8-tuple ``(actions, action_lens, sentences,
            sent_lens, hred_sentences, hred_sent_lens, targets, conv_lens)``.

    Returns:
        1-D tensor [total words]: for every non-padding word of every
        action, the softmax probability the model gives to that word.
    """
    with torch.no_grad():
        (actions, action_lens, sentences, sent_lens,
         hred_sentences, hred_sent_lens, targets, conv_lens) = batch_tensors

        # Non-variational models take the inputs without the final sentence
        if bot.config.model not in VariationalModels:
            sentences = hred_sentences
            sent_lens = hred_sent_lens

        # Run model
        outputs = bot.solver.model(sentences, sent_lens, conv_lens, targets,
                                   rl_mode=True)
        logits = outputs[0]

        # Index to get only output values for actions taken (last sentence
        # in each conversation)
        start = torch.cumsum(
            torch.cat((to_var(conv_lens.data.new(1).zero_()),
                       conv_lens[:-1])), 0)
        action_logits = torch.stack([
            logits[s + l - 1, :, :]
            for s, l in zip(start.data.tolist(), conv_lens.data.tolist())
        ], 0)  # [num_sentences, max_sent_len, vocab_size]

        # Limit by actual sentence length (remove padding) and flatten into
        # long list of words
        word_logits = torch.cat(
            [action_logits[i, :l, :] for i, l in enumerate(action_lens)],
            0)  # [total words, vocab_size]
        word_actions = torch.cat(
            [actions[i, :l] for i, l in enumerate(action_lens)],
            0)  # [total words]

        # Take softmax to get probability distribution
        # [total_words, vocab_size]
        word_probs = torch.nn.functional.softmax(word_logits, 1)

        # Extract probabilities of the actions taken. Squeeze only the
        # gathered column so a one-word batch stays 1-D.
        relevant_words = word_probs.gather(
            1, word_actions.unsqueeze(1)).squeeze(1)  # [total words]

        return relevant_words
def duplicate_context_for_beams(self, sentences, sent_lens, conv_lens, beams):
    """Build one model batch in which the context is paired with every beam.

    Returns:
        ``(sentences, sent_lens, conv_lens, targets)``. ``conv_lens`` is
        repeated once per beam. ``targets`` holds, per beam, the context
        without its first sentence followed by the beam (or just ``beams``
        when the context has at most one sentence). For variational models
        the beam is also appended to the input sentences, with length
        ``rl_config.max_sentence_length``.
    """
    n_beams = len(beams)
    conv_lens = conv_lens.repeat(n_beams)

    def _beam_row(b):
        return beams[b, :].unsqueeze(0)

    # [beam_size * sentences, sentence_len]
    if len(sentences) > 1:
        shifted_context = sentences[1:, :]
        targets = torch.cat(
            [torch.cat([shifted_context, _beam_row(b)], 0)
             for b in range(n_beams)], 0)
    else:
        targets = beams

    if self.rl_config.model in VariationalModels:
        # VHRED, VHCR
        new_sentences = torch.cat(
            [torch.cat([sentences, _beam_row(b)], 0)
             for b in range(n_beams)], 0)
        new_len = to_var(
            torch.LongTensor([self.rl_config.max_sentence_length]))
        extended_lens = torch.cat([sent_lens, new_len], 0)
        sent_lens = torch.cat([extended_lens for _ in range(n_beams)])
        return new_sentences, sent_lens, conv_lens, targets

    # HRED
    # NOTE(review): lengths are repeated per beam here while `sentences`
    # itself is returned un-duplicated -- confirm this is what the HRED
    # model expects.
    sent_lens = sent_lens.repeat(n_beams)
    return sentences, sent_lens, conv_lens, targets
def generate(self, context, sentence_length, n_context,
             extra_context_inputs=None, botmoji=None, botsent=None,
             vocab=None):
    """Generate ``config.n_sample_step`` follow-up sentences for a context.

    The context sentences are encoded, a conversation latent ``z_conv`` is
    drawn from the posterior inferred over them, and the context encoder is
    stepped through the context. Each generation step then samples a
    sentence latent, decodes a sentence (sampling or beam search, per
    ``config.sample``) and feeds that sentence back into the context
    encoder.

    Args:
        context: [batch_size, n_context, seq_len]
        sentence_length: indexed as ``sentence_length[:, i]`` per context
            sentence.
        n_context: number of context sentences to read.
        extra_context_inputs: concatenated to every sentence encoding when
            ``config.context_input_only`` is set.
        botmoji, botsent, vocab: forwarded to
            ``dynamically_assess_context_inputs`` when
            ``config.context_input_only`` is set.

    Returns:
        The per-step predictions stacked along dim 1.
    """
    # context: [batch_size, n_context, seq_len]
    batch_size = context.size(0)

    samples = []

    # Run for context
    conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size]))

    # Encode every context sentence once. The encodings feed both the
    # posterior inference network and the context encoder below.
    context_inputs_list = []
    for i in range(n_context):
        # raw_encoder_hidden: [num_layers * direction, batch_size, hidden_size]
        _, raw_encoder_hidden = self.encoder(context[:, i, :],
                                             sentence_length[:, i])
        # [batch_size, num_layers * direction * hidden_size]
        context_inputs_2d = raw_encoder_hidden.transpose(
            1, 0).contiguous().view(batch_size, -1)
        if self.config.context_input_only:
            context_inputs_2d = torch.cat(
                (context_inputs_2d, extra_context_inputs), 1)
        context_inputs_list.append(context_inputs_2d)

    context_inputs = torch.stack(context_inputs_list, 1)
    _, context_inference_hidden = self.context_inference(
        context_inputs, to_var(torch.LongTensor([n_context] * batch_size)))
    context_inference_hidden = context_inference_hidden.transpose(
        1, 0).contiguous().view(batch_size, -1)
    conv_mu_posterior, conv_var_posterior = self.conv_posterior(
        context_inference_hidden)
    z_conv = conv_mu_posterior + torch.sqrt(conv_var_posterior) * conv_eps

    context_hidden = self.z_conv2context(z_conv).view(
        self.config.num_layers, batch_size, self.config.context_size)

    # Step the context encoder through the cached sentence encodings.
    for context_inputs_2d in context_inputs_list:
        # context_outputs: [batch_size, 1, context_hidden_size * direction]
        # context_hidden: [num_layers * direction, batch_size, context_hidden_size]
        context_outputs, context_hidden = self.context_encoder.step(
            torch.cat([context_inputs_2d, z_conv], 1), context_hidden)

    # Run for generation
    for j in range(self.config.n_sample_step):
        # context_outputs: [batch_size, context_hidden_size * direction]
        context_outputs = context_outputs.squeeze(1)

        mu_prior, var_prior = self.sent_prior(context_outputs, z_conv)
        eps = to_var(torch.randn((batch_size, self.config.z_sent_size)))
        z_sent = mu_prior + torch.sqrt(var_prior) * eps

        latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)
        decoder_init = self.context2decoder(latent_context)
        decoder_init = decoder_init.view(self.decoder.num_layers, -1,
                                         self.decoder.hidden_size)

        if self.config.sample:
            prediction = self.decoder(None, decoder_init, decode=True)
            p = prediction.data.cpu().numpy()
            # NOTE(review): this yields one entry per EOS token found, so it
            # only lines up with the batch if every row contains exactly one
            # EOS -- confirm the decoder guarantees that.
            length = torch.from_numpy(np.where(p == EOS_ID)[1])
        else:
            prediction, final_score, length = self.decoder.beam_decode(
                init_h=decoder_init)
            # Keep only the top beam.
            # prediction: [batch_size, seq_len]
            prediction = prediction[:, 0, :]
            # length: [batch_size]
            length = [l[0] for l in length]
            length = to_var(torch.LongTensor(length))

        samples.append(prediction)

        # Feed the generated sentence back in as the next context step.
        _, raw_encoder_hidden = self.encoder(prediction, length)
        context_inputs_2d = raw_encoder_hidden.transpose(
            1, 0).contiguous().view(batch_size, -1)

        # Dynamically assess the DeepMoji and Infersent predictions on
        # generated text
        if self.config.context_input_only:
            dynamic_context_inputs = dynamically_assess_context_inputs(
                prediction, botmoji, botsent, vocab, self.config)
            context_inputs_2d = torch.cat(
                (context_inputs_2d, dynamic_context_inputs), 1)

        context_outputs, context_hidden = self.context_encoder.step(
            torch.cat([context_inputs_2d, z_conv], 1), context_hidden)

    samples = torch.stack(samples, 1)
    return samples
def get_target_q_values(self, batch, prior_rewards=None):
    """Compute per-word target Q values from the target network.

    The next state of every transition is extended with a blank sentence
    and run through ``self.target_q_net`` ``config.monte_carlo_count``
    times; the element-wise minimum over the runs is returned.

    Args:
        batch: dict with ``rewards``, ``done``, ``next_state``,
            ``next_state_lens`` and ``action_lens``.
        prior_rewards: added to the targets when ``config.kl_control`` is
            set (must not be None in that case).

    Returns:
        Tensor [total words] of target Q values.
    """
    rewards = to_var(torch.FloatTensor(batch['rewards']))  # [batch_size]
    not_done = to_var(torch.FloatTensor(1 - batch['done']))  # [batch_size]
    self.sampled_reward_batch_history.append(torch.sum(rewards).item())

    # Prepare inputs to target Q network. Append a blank sentence to get
    # best response at next utterance to user input. (Next state
    # includes user input).
    blank_sentence = np.zeros((1, self.config.max_sentence_length))
    next_state_convs = [
        np.concatenate((conv, blank_sentence))
        for conv in batch['next_state']
    ]
    next_state_lens = [
        np.concatenate((lens, [1])) for lens in batch['next_state_lens']
    ]
    next_targets = [conv[1:] for conv in next_state_convs]
    next_conv_lens = [len(c) - 1 for c in next_state_convs]
    if self.config.model not in VariationalModels:
        next_state_convs = [conv[:-1] for conv in next_state_convs]
        next_state_lens = np.concatenate([l[:-1] for l in next_state_lens])
    else:
        next_state_lens = np.concatenate(next_state_lens)
    next_conv_lens = to_var(torch.LongTensor(next_conv_lens))

    # Offset of each conversation's first sentence in the flattened model
    # output. These do not depend on the Monte Carlo draw, so compute once.
    start_t = torch.cumsum(
        torch.cat((to_var(next_conv_lens.data.new(1).zero_()),
                   next_conv_lens[:-1])), 0)
    conv_spans = list(zip(start_t.data.tolist(),
                          next_conv_lens.data.tolist()))

    # One entry per Monte Carlo run, each [total words]
    mc_samples = []
    for t in range(self.config.monte_carlo_count):
        # Run target Q network. Output is size:
        # [num_sentences, max sent len, vocab size]
        # NOTE(review): the original had two textually identical calls for
        # monte_carlo_count == 1 and > 1. Any dropout toggling between the
        # two settings must happen outside this function.
        all_target_q_values = self.run_seq2seq_model(
            self.target_q_net, next_state_convs, next_state_lens,
            next_targets, next_conv_lens)

        # Target indexing: last sentence is a blank to get value of next
        # response. Second last is the user response. 3rd last is models own
        # actions. Note that targets begin at the 2nd word of each sentence.
        conv_target_q_values = torch.stack([
            all_target_q_values[s + l - 3, 1:, :] for s, l in conv_spans
        ], 0)  # Dimension [num_sentences, max_sent_len - 1, vocab_size]

        # At the end of a sentence, want value of starting a new response
        # after user's response. So index into first word of last blank
        # sentence that was appended to the end of the conversation.
        next_response_targets = torch.stack([
            all_target_q_values[s + l - 1, 0, :] for s, l in conv_spans
        ], 0)
        next_response_targets = torch.reshape(
            next_response_targets,
            [self.config.rl_batch_size, 1, -1])  # [num_sentences, 1, vocab_size]
        conv_target_q_values = torch.cat(
            [conv_target_q_values, next_response_targets],
            1)  # [num_sentences, max_sent_len, vocab_size]

        # Limit target Q values by action length (remove padding)
        limit_conv_targets = [
            conv_target_q_values[i, :l, :]
            for i, l in enumerate(batch['action_lens'])
        ]

        if self.config.psi_learning:
            # Target is r + gamma * log sum_a' exp(Q_target(s', a'))
            conv_max_targets = [
                torch.distributions.utils.log_sum_exp(c)
                for c in limit_conv_targets
            ]
            # Squeeze only the reduced last dim. A bare squeeze() would turn
            # a one-word action into a 0-dim tensor that torch.cat rejects.
            target_q_values = torch.cat([
                rewards[i] + not_done[i] * self.config.gamma * c.squeeze(-1)
                for i, c in enumerate(conv_max_targets)
            ], 0)  # [total words]
        else:
            # Target is r + gamma * max_a' Q_target(s',a'). Reward and done are
            # at the level of conversation, so add and multiply in before
            # flattening and taking max.
            word_target_q_values = torch.cat([
                rewards[i] + not_done[i] * self.config.gamma * c
                for i, c in enumerate(limit_conv_targets)
            ], 0)  # [total words, vocab_size]
            target_q_values, _ = word_target_q_values.max(1)

        mc_samples.append(target_q_values)

    mc_target_q_values = torch.stack(mc_samples, 0)
    min_target_q_values, _ = mc_target_q_values.min(0)

    if self.config.kl_control:
        min_target_q_values += prior_rewards

    return min_target_q_values
def get_q_values(self, batch):
    """Compute per-word Q values of the actions taken, plus KL-based rewards.

    States are whole conversations which each have several sentences, and
    actions are a sentence (series of words). Q values are per word. Target
    Q values are over the next word in the sentence, or, if at the end of
    the sentence, the first word in a new sentence after the user response.

    Side effects: appends to ``self.kl_div_batch_history``,
    ``self.logp_batch_history``, ``self.logp_logq_batch_history`` and, when
    ``config.kl_control`` is set, ``self.kl_reward_batch_history``.

    Returns:
        ``(q_values, prior_rewards)``. ``q_values`` is [total words].
        ``prior_rewards`` is [total words] when model averaging or KL
        control is enabled, otherwise None.
    """
    actions = to_var(torch.LongTensor(batch['action']))

    # Prepare inputs to Q network
    conversations = [
        np.concatenate((conv, np.atleast_2d(batch['action'][i])))
        for i, conv in enumerate(batch['state'])
    ]
    sent_lens = [
        np.concatenate((lens, np.atleast_1d(batch['action_lens'][i])))
        for i, lens in enumerate(batch['state_lens'])
    ]
    target_conversations = [conv[1:] for conv in conversations]
    conv_lens = [len(c) - 1 for c in conversations]
    if self.config.model not in VariationalModels:
        conversations = [conv[:-1] for conv in conversations]
        sent_lens = np.concatenate([l[:-1] for l in sent_lens])
    else:
        sent_lens = np.concatenate(sent_lens)
    conv_lens = to_var(torch.LongTensor(conv_lens))

    # Run Q network. Will produce [num_sentences, max sent len, vocab size]
    all_q_values = self.run_seq2seq_model(self.q_net, conversations,
                                          sent_lens, target_conversations,
                                          conv_lens)

    # Index to get only q values for actions taken (last sentence in each
    # conversation)
    start_q = torch.cumsum(
        torch.cat((to_var(conv_lens.data.new(1).zero_()),
                   conv_lens[:-1])), 0)
    conv_q_values = torch.stack([
        all_q_values[s + l - 1, :, :]
        for s, l in zip(start_q.data.tolist(), conv_lens.data.tolist())
    ], 0)  # [num_sentences, max_sent_len, vocab_size]

    # Limit by actual sentence length (remove padding) and flatten into
    # long list of words
    word_q_values = torch.cat([
        conv_q_values[i, :l, :]
        for i, l in enumerate(batch['action_lens'])
    ], 0)  # [total words, vocab_size]
    word_actions = torch.cat(
        [actions[i, :l] for i, l in enumerate(batch['action_lens'])],
        0)  # [total words]

    # Extract q values corresponding to actions taken. squeeze(1) keeps the
    # result 1-D even when the batch contains a single word.
    q_values = word_q_values.gather(
        1, word_actions.unsqueeze(1)).squeeze(1)  # [total words]

    # ---- Compute KL metrics ----
    prior_rewards = None

    # Get probabilities from policy network
    q_dists = torch.nn.functional.softmax(word_q_values, 1)
    q_probs = q_dists.gather(1, word_actions.unsqueeze(1)).squeeze(1)

    # NOTE(review): the reward computation and metric tracking below are
    # assumed to sit inside the no_grad block -- confirm against the
    # original file layout.
    with torch.no_grad():
        # Run pretrained prior network.
        # [num_sentences, max sent len, vocab size]
        all_prior_logits = self.run_seq2seq_model(self.pretrained_prior,
                                                  conversations, sent_lens,
                                                  target_conversations,
                                                  conv_lens)

        # Get relevant actions. [num_sentences, max_sent_len, vocab_size]
        conv_prior = torch.stack([
            all_prior_logits[s + l - 1, :, :]
            for s, l in zip(start_q.data.tolist(), conv_lens.data.tolist())
        ], 0)

        # Limit by actual sentence length (remove padding) and flatten.
        # [total words, vocab_size]
        word_prior_logits = torch.cat([
            conv_prior[i, :l, :]
            for i, l in enumerate(batch['action_lens'])
        ], 0)

        # Take the softmax
        prior_dists = torch.nn.functional.softmax(word_prior_logits, 1)

        # Element-wise (unreduced) KL terms: [total words, vocab_size]
        kl_div = F.kl_div(q_dists.log(), prior_dists, reduce=False)

        # [total words]
        prior_probs = prior_dists.gather(
            1, word_actions.unsqueeze(1)).squeeze(1)
        logp_logq = prior_probs.log() - q_probs.log()

        if self.config.model_averaging:
            model_avg_sentences = batch['model_averaged_probs']

            # Convert to tensors and flatten into [num_words]
            word_model_avg = torch.cat([
                to_var(torch.FloatTensor(m)) for m in model_avg_sentences
            ], 0)

            # Compute KL from model-averaged prior
            prior_rewards = word_model_avg.log() - q_probs.log()

            # Clip because KL should never be negative, so because we
            # are subtracting KL, rewards should never be positive
            prior_rewards = torch.clamp(prior_rewards, max=0.0)

        elif self.config.kl_control and self.config.kl_calc == 'integral':
            # Note: we reward the negative KL divergence to ensure the
            # RL model stays close to the prior
            prior_rewards = -1.0 * torch.sum(kl_div, dim=1)
        elif self.config.kl_control:
            prior_rewards = logp_logq

        if self.config.kl_control:
            prior_rewards = prior_rewards * self.config.kl_weight_c
            self.kl_reward_batch_history.append(
                torch.sum(prior_rewards).item())

        # Track all metrics
        self.kl_div_batch_history.append(torch.mean(kl_div).item())
        self.logp_batch_history.append(
            torch.mean(prior_probs.log()).item())
        self.logp_logq_batch_history.append(torch.mean(logp_logq).item())

    return q_values, prior_rewards
def init_token(self, batch_size, SOS_ID=SOS_ID):
    """Build the start-of-sentence input for a batch.

    Returns a ``to_var``-wrapped LongTensor of length ``batch_size`` in
    which every entry is ``SOS_ID``.
    """
    sos_indices = [SOS_ID] * batch_size
    start_tokens = torch.LongTensor(sos_indices)
    return to_var(start_tokens)
def beam_decode(self, init_h=None, encoder_outputs=None,
                input_valid_length=None, decode=False):
    """Decode with beam search.

    Args:
        encoder_outputs (Variable, FloatTensor): [batch_size, source_length, hidden_size]
        input_valid_length (Variable, LongTensor): [batch_size] (optional)
        init_h (variable, FloatTensor): initial decoder state, reshaped by
            callers to [num_layers, batch_size, hidden_size] (optional)
        decode: accepted for interface compatibility; not read here.

    Return:
        The tuple returned by ``Beam.backtrack()``:
        prediction ([batch, k, max_unroll]): predicted sequences
        final_score ([batch, k]): final scores of the top-k sequences
        length ([batch, k]): length of each top-k sequence
    """
    batch_size = self.batch_size(h=init_h)

    # [batch_size x beam_size]
    x = self.init_token(batch_size * self.beam_size, SOS_ID)

    # [num_layers, batch_size, hidden_size]
    h0 = self.init_h(batch_size, hidden=init_h)
    # [num_layers, batch_size x beam_size, hidden_size]
    # Tile batch-major (b0, b0, ..., b1, b1, ...): scores, batch_position
    # and the back-pointers below all assume batch i owns slots
    # [i * beam_size, (i + 1) * beam_size). A plain repeat(1, beam_size, 1)
    # would interleave batches and mis-assign hidden states whenever
    # batch_size > 1.
    h = h0.unsqueeze(2).repeat(1, 1, self.beam_size, 1).view(
        h0.size(0), -1, h0.size(2))

    # batch_position [batch_size]
    #   [0, beam_size, beam_size * 2, .., beam_size * (batch_size-1)]
    #   Points where batch starts in [batch_size x beam_size] tensors
    #   Ex. position_idx[5]: when 5-th batch starts
    first_slot = torch.arange(0, batch_size).long() * self.beam_size
    batch_position = to_var(first_slot)

    # Initialize scores of sequence
    # [batch_size x beam_size]
    # Ex. batch_size: 5, beam_size: 3
    # [0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf]
    score = torch.ones(batch_size * self.beam_size) * -float('inf')
    score.index_fill_(
        0, torch.arange(0, batch_size).long() * self.beam_size, 0.0)
    score = to_var(score)

    # Initialize Beam that stores decisions for backtracking
    beam = Beam(batch_size, self.hidden_size, self.vocab_size,
                self.beam_size, self.max_unroll, batch_position)

    for i in range(self.max_unroll):
        # x: [batch_size x beam_size]; (token index)
        # =>
        # out: [batch_size x beam_size, vocab_size]
        # h: [num_layers, batch_size x beam_size, hidden_size]
        out, h = self.forward_step(x, h,
                                   encoder_outputs=encoder_outputs,
                                   input_valid_length=input_valid_length)
        # log_prob: [batch_size x beam_size, vocab_size]
        log_prob = F.log_softmax(out, dim=1)

        # [batch_size x beam_size]
        # => [batch_size x beam_size, vocab_size]
        score = score.view(-1, 1) + log_prob

        # Select `beam size` transitions out of `vocab size` combinations

        # [batch_size x beam_size, vocab_size]
        # => [batch_size, beam_size x vocab_size]
        # Cutoff and retain candidates with top-k scores
        # score: [batch_size, beam_size]
        # top_k_idx: [batch_size, beam_size]
        #   each element of top_k_idx [0 ~ beam x vocab)
        score, top_k_idx = score.view(batch_size, -1).topk(
            self.beam_size, dim=1)

        # Get token ids with remainder after dividing by top_k_idx
        # Each element is among [0, vocab_size)
        # Ex. Index of token 3 in beam 4
        # (4 * vocab size) + 3 => 3
        # x: [batch_size x beam_size]
        x = (top_k_idx % self.vocab_size).view(-1)

        # top-k-pointer [batch_size x beam_size]
        #   Points top-k beam that scored best at current step
        #   Later used as back-pointer at backtracking
        #   Each element is beam index: 0 ~ beam_size
        #                     + position index: 0 ~ beam_size x (batch_size-1)
        # Floor division keeps the indices integral; `/` on integer tensors
        # is true division in recent PyTorch and would produce floats that
        # index_select rejects.
        beam_idx = top_k_idx // self.vocab_size  # [batch_size, beam_size]
        top_k_pointer = (beam_idx + batch_position.unsqueeze(1)).view(-1)

        # Select next h (size doesn't change)
        # [num_layers, batch_size * beam_size, hidden_size]
        h = h.index_select(1, top_k_pointer)

        # Update sequence scores at beam
        beam.update(score.clone(), top_k_pointer, x)

        # Erase scores for EOS so that they are not expanded
        # [batch_size, beam_size]
        eos_idx = x.data.eq(EOS_ID).view(batch_size, self.beam_size)

        if eos_idx.nonzero().dim() > 0:
            score.data.masked_fill_(eos_idx, -float('inf'))

    prediction, final_score, length = beam.backtrack()

    return prediction, final_score, length
def generate(self, context, sentence_length, n_context,
             extra_context_inputs=None, botsent=None, botmoji=None,
             vocab=None):
    """Generate ``config.n_sample_step`` follow-up sentences for a context.

    The context encoder is stepped through the ``n_context`` encoded
    context sentences. Each generation step then beam-decodes a sentence
    from the current context output, keeps the top beam, and feeds the
    encoding of that sentence back into the context encoder.

    Args:
        context: [batch_size, n_context, seq_len]
        sentence_length: indexed as ``sentence_length[:, i]`` per context
            sentence.
        extra_context_inputs: concatenated to each context sentence
            encoding when ``config.context_input_only`` is set.
        botsent, botmoji, vocab: forwarded to
            ``dynamically_assess_context_inputs`` when
            ``config.context_input_only`` is set.

    Returns:
        The per-step top-beam predictions stacked along dim 1.
    """
    # context: [batch_size, n_context, seq_len]
    batch_size = context.size(0)
    use_extra_inputs_only = None  # resolved per step from self.config

    def _sentence_code(tokens, lengths):
        # encoder hidden: [num_layers * direction, batch_size, hidden_size]
        # => [batch_size, num_layers * direction * hidden_size]
        _, enc_hidden = self.encoder(tokens, lengths)
        return enc_hidden.transpose(1, 0).contiguous().view(batch_size, -1)

    generated = []

    # Run for context
    context_hidden = None
    for step in range(n_context):
        step_input = _sentence_code(context[:, step, :],
                                    sentence_length[:, step])
        use_extra_inputs_only = self.config.context_input_only
        if use_extra_inputs_only:
            step_input = torch.cat((step_input, extra_context_inputs), 1)

        # context_outputs: [batch_size, 1, context_hidden_size * direction]
        # context_hidden: [num_layers * direction, batch_size, context_hidden_size]
        context_outputs, context_hidden = self.context_encoder.step(
            step_input, context_hidden)

    # Run for generation
    for _ in range(self.config.n_sample_step):
        # context_outputs: [batch_size, context_hidden_size * direction]
        context_outputs = context_outputs.squeeze(1)
        decoder_init = self.context2decoder(context_outputs).view(
            self.decoder.num_layers, -1, self.decoder.hidden_size)

        beams, _, beam_lengths = self.decoder.beam_decode(
            init_h=decoder_init)

        # Keep only the first beam of every batch element.
        # prediction: [batch_size, seq_len]
        prediction = beams[:, 0, :]
        # length: [batch_size]
        length = to_var(torch.LongTensor([l[0] for l in beam_lengths]))
        generated.append(prediction)

        step_input = _sentence_code(prediction, length)

        # Dynamically assess the DeepMoji and Infersent predictions on
        # generated text
        if self.config.context_input_only:
            dynamic_context_inputs = dynamically_assess_context_inputs(
                prediction, botmoji, botsent, vocab, self.config)
            step_input = torch.cat((step_input, dynamic_context_inputs), 1)

        context_outputs, context_hidden = self.context_encoder.step(
            step_input, context_hidden)

    return torch.stack(generated, 1)
def generate_response_to_input(self, raw_text_sentences,
                               max_conversation_length=5,
                               sample_by='priority', emojize=True,
                               debug=True):
    """Generate a text response to a conversation.

    For every word position and every beam, ``rl_config.bcq_n`` candidate
    words are sampled from the pretrained prior's next-word distribution
    and the candidate with the highest Q-network logit is written into the
    beam. One finished beam is then chosen by
    ``pretrained_prior.select_best_generated_response``.

    Args:
        raw_text_sentences: utterances of the conversation so far.
        max_conversation_length: number of most recent sentences used as
            context.
        sample_by: forwarded to ``select_best_generated_response``.
        emojize: when True, prefix the response with emojis inferred from
            the last input utterance.
        debug: when True, print the context and all generated beams.

    Returns:
        The detokenized response string.
    """
    with torch.no_grad():
        (input_sentences, input_sent_lens,
         input_conv_lens) = self.process_raw_text_into_input(
            raw_text_sentences, debug=debug,
            max_conversation_length=max_conversation_length)

        # Initialize a tensor for beams
        beams = to_var(torch.LongTensor(
            np.ones((self.rl_config.beam_size,
                     self.rl_config.max_sentence_length))))

        # Create a batch with the context duplicated for each beam
        (sentences, sent_lens,
         conv_lens, targets) = self.duplicate_context_for_beams(
            input_sentences, input_sent_lens, input_conv_lens, beams)

        # Continuously feed beam sentences into networks to sample the next
        # best word, add that to the beam, and continue
        for i in range(self.rl_config.max_sentence_length):
            # Run both models to obtain logits
            prior_output = self.pretrained_prior.model(
                sentences, sent_lens, conv_lens, targets, rl_mode=True)
            all_prior_logits = prior_output[0]

            q_output = self.q_net.model(
                sentences, sent_lens, conv_lens, targets, rl_mode=True)
            all_q_logits = q_output[0]

            # Select only those logits for next word. Indexing with `i`
            # already gives [rows, vocab_size]; no squeeze(), which would
            # drop the row dimension when there is a single row and break
            # the softmax over dim 1 below.
            q_logits = all_q_logits[:, i, :]
            prior_logits = all_prior_logits[:, i, :]

            # Get prior distribution for next word in each beam
            prior_dists = torch.nn.functional.softmax(prior_logits, 1)

            # NOTE(review): row `b` of the model output is used as beam
            # `b`'s next-word logits -- confirm this matches the row layout
            # produced by duplicate_context_for_beams for multi-sentence
            # contexts.
            for b in range(self.rl_config.beam_size):
                # Sample from the prior bcq_n times for each beam
                dist = torch.distributions.Categorical(prior_dists[b, :])
                sampled_idxs = dist.sample(
                    torch.Size((self.rl_config.bcq_n,)))

                # Select sample with highest q value
                q_vals = torch.stack(
                    [q_logits[b, idx] for idx in sampled_idxs])
                _, best_word_i = torch.max(q_vals, 0)
                best_word = sampled_idxs[best_word_i]

                # Update beams
                beams[b, i] = best_word

            # Rebuild the batch so the next step sees the extended beams
            (sentences, sent_lens,
             conv_lens, targets) = self.duplicate_context_for_beams(
                input_sentences, input_sent_lens, input_conv_lens, beams)

        generated_sentences = beams.cpu().numpy()

        if debug:
            print('\n**All generated responses:**')
            for gen in generated_sentences:
                print(detokenize(self.vocab.decode(list(gen))))

        gen_response = self.pretrained_prior.select_best_generated_response(
            generated_sentences, sample_by,
            beam_size=self.rl_config.beam_size)

        decoded_response = self.vocab.decode(list(gen_response))
        decoded_response = detokenize(decoded_response)

        if emojize:
            inferred_emojis = self.pretrained_prior.botmoji.emojize_text(
                raw_text_sentences[-1], 5, 0.07)
            decoded_response = inferred_emojis + " " + decoded_response

        return decoded_response
def compute_bow_loss(self, target_conversations):
    """Bag-of-words loss between predictions from ``self.z_sent`` and targets.

    Every sentence of every target conversation is turned into a
    bag-of-words vector with ``to_bow`` (size ``config.vocab_size``). The
    logits come from ``self.bow_predict(self.bow_h(self.z_sent))``.
    """
    bow_rows = []
    for conv in target_conversations:
        for sent in conv:
            bow_rows.append(to_bow(sent, self.config.vocab_size))

    stacked_bow = np.stack(bow_rows, axis=0)
    target_bow = to_var(torch.FloatTensor(stacked_bow))

    hidden = self.bow_h(self.z_sent)
    bow_logits = self.bow_predict(hidden)
    return bag_of_words_loss(bow_logits, target_bow)
def forward(self, input_sentences, input_sentence_length,
            input_conversation_length, target_sentences, decode=False,
            extra_context_inputs=None, rl_mode=False):
    """Run the hierarchical encoder-decoder on a flattened batch.

    Args:
        input_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        input_sentence_length: length of every input sentence.
        input_conversation_length: number of sentences per conversation.
        target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        decode: when True (and ``rl_mode`` is False) run beam search.
        extra_context_inputs: concatenated to the sentence encodings when
            ``config.context_input_only`` is set.
        rl_mode: when True always run the decoder on ``target_sentences``.

    Return:
        Tuple ``(decoder_outputs, emoji_preds, infersent_preds)``.
        decoder_outputs: (Variable, FloatTensor)
            - train: [batch_size, seq_len, vocab_size]
            - eval: [batch_size, seq_len]
            - beam search: [batch_size, beam_size, max_unroll]
        ``emoji_preds`` / ``infersent_preds`` are None unless
        ``config.emotion`` / ``config.infersent`` is set.
    """
    num_sentences = input_sentences.size(0)
    max_len = input_conversation_length.data.max().item()

    # raw_encoder_hidden: [num_layers * direction, num_sentences, hidden_size]
    _, raw_encoder_hidden = self.encoder(input_sentences,
                                         input_sentence_length)

    # sentence_codes: [num_sentences, num_layers * direction * hidden_size]
    sentence_codes = raw_encoder_hidden.transpose(
        1, 0).contiguous().view(num_sentences, -1)
    if self.config.context_input_only:
        sentence_codes = torch.cat((sentence_codes, extra_context_inputs), 1)

    # Offset of each conversation's first sentence in the flat batch
    leading_zero = to_var(input_conversation_length.data.new(1).zero_())
    conv_starts = torch.cumsum(
        torch.cat((leading_zero, input_conversation_length[:-1])), 0)

    # Regroup per conversation and pad to the longest one:
    # [batch_size, max_len, num_layers * direction * hidden_size]
    padded_convs = []
    for s, l in zip(conv_starts.data.tolist(),
                    input_conversation_length.data.tolist()):
        padded_convs.append(pad(sentence_codes.narrow(0, s, l), max_len))
    context_inputs = torch.stack(padded_convs, 0)

    # context_outputs: [batch_size, max_len, context_size]
    context_outputs, _ = self.context_encoder(context_inputs,
                                              input_conversation_length)

    # Flatten outputs again, dropping the padding
    # context_outputs: [num_sentences, context_size]
    context_outputs = torch.cat(
        [context_outputs[i, :l, :]
         for i, l in enumerate(input_conversation_length.data)])

    # Stop gradients from flowing from discriminator if only using input
    if self.config.context_input_only:
        discriminator_input = context_outputs.detach()
    else:
        discriminator_input = context_outputs

    # Predict emojis using discriminator.
    emoji_preds = (self.context2emoji(discriminator_input)
                   if self.config.emotion else None)

    # Predict infersent using discriminator.
    infersent_preds = (self.context2infersent(discriminator_input)
                       if self.config.infersent else None)

    # Project context_outputs to decoder init state
    # [num_layers, batch_size, hidden_size]
    decoder_init = self.context2decoder(context_outputs).view(
        self.decoder.num_layers, -1, self.decoder.hidden_size)

    if decode and not rl_mode:
        # prediction: [batch_size, beam_size, max_unroll]
        prediction, _, _ = self.decoder.beam_decode(init_h=decoder_init)
        return prediction, emoji_preds, infersent_preds

    # train: [batch_size, seq_len, vocab_size]
    # eval: [batch_size, seq_len]
    decoder_outputs = self.decoder(target_sentences,
                                   init_h=decoder_init,
                                   decode=decode)
    return decoder_outputs, emoji_preds, infersent_preds
def forward(self, sentences, sentence_length,
            input_conversation_length, target_sentences, decode=False,
            extra_context_inputs=None, rl_mode=False):
    """Decode target sentences using conversation- and sentence-level latents.

    Args:
        sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
            Each conversation contributes (length + 1) consecutive rows.
        sentence_length: per-sentence lengths, handed to the sentence
            encoder together with sentences.
        input_conversation_length: 1-D tensor, one length per conversation.
        target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        decode: forwarded to the decoder; when True and rl_mode is False,
            beam search is used.
        extra_context_inputs: tensor concatenated to the sentence encodings
            along dim 1; only read when config.context_input_only is set.
        rl_mode: when True, latents come from the priors and the regular
            decoder call is used.

    Return:
        Tuple (decoder_outputs, kl_div, log_p_z, log_q_zx, emoji_preds,
        infersent_preds).
        decoder_outputs: (Variable, FloatTensor)
            - train: [batch_size, seq_len, vocab_size]
            - eval: [batch_size, seq_len]
            - beam search: [batch_size, beam_size, max_unroll]
        kl_div and log_q_zx are None whenever the latents are sampled from
        the priors (rl_mode or decode is truthy).
    """
    batch_size = input_conversation_length.size(0)
    total_rows = sentences.size(0)
    num_sentences = total_rows - batch_size
    max_len = input_conversation_length.data.max().item()
    conv_lengths = input_conversation_length.data.tolist()
    # Posterior sampling is only used for plain (non-RL) training.
    use_posterior = not (rl_mode or decode)

    def flatten(padded):
        # Keep the first conv_lengths[i] steps of conversation i and stack
        # all of them along dim 0 (removes the padding steps).
        return torch.cat(
            [padded[i, :n, :] for i, n in enumerate(conv_lengths)])

    # raw_encoder_hidden:
    #   [num_layers * direction, num_sentences + batch_size, hidden_size]
    _, raw_encoder_hidden = self.encoder(sentences, sentence_length)

    # context_inputs_2d:
    #   [num_sentences + batch_size, num_layers * direction * hidden_size]
    context_inputs_2d = raw_encoder_hidden.transpose(
        1, 0).contiguous().view(total_rows, -1)
    if self.config.context_input_only:
        context_inputs_2d = torch.cat(
            (context_inputs_2d, extra_context_inputs), 1)

    # Row offset of each conversation: exclusive cumulative sum of
    # (length + 1).
    leading_zero = to_var(input_conversation_length.data.new(1).zero_())
    start = torch.cumsum(
        torch.cat((leading_zero, input_conversation_length[:-1] + 1)), 0)

    # context_inputs:
    #   [batch_size, max_len + 1, num_layers * direction * hidden_size]
    context_inputs = torch.stack(
        [pad(context_inputs_2d.narrow(0, offset, n + 1), max_len + 1)
         for offset, n in zip(start.data.tolist(), conv_lengths)], 0)

    # Steps 1..max_len feed the inference (posterior) side; torch.cat
    # inside flatten copies them, so later writes do not affect this.
    context_inputs_inference_flat = flatten(context_inputs[:, 1:, :])
    # Steps 0..max_len-1 feed the context encoder.
    context_inputs_input = context_inputs[:, :-1, :]

    # Noise for z_conv plus the fixed conversation prior.
    conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size]))
    conv_mu_prior, conv_var_prior = self.conv_prior()

    if use_posterior:
        if self.config.sentence_drop > 0.0:
            # Replace randomly chosen time steps (same steps for the
            # whole batch) with self.unk_sent.
            dropped = np.where(
                np.random.rand(max_len) < self.config.sentence_drop)[0]
            if len(dropped) > 0:
                # NOTE(review): context_inputs_input is a slice (view) of
                # context_inputs, so this write is also visible in the
                # tensor handed to context_inference below - confirm
                # that is intended.
                context_inputs_input[:, dropped, :] = self.unk_sent

        # inference_hidden: [num_layers * num_directions, batch_size, hidden_size]
        _, inference_hidden = self.context_inference(
            context_inputs, input_conversation_length + 1)
        # -> [batch_size, num_layers * num_directions * hidden_size]
        inference_hidden = inference_hidden.transpose(
            1, 0).contiguous().view(batch_size, -1)

        conv_mu_posterior, conv_var_posterior = self.conv_posterior(
            inference_hidden)
        z_conv = conv_mu_posterior + torch.sqrt(conv_var_posterior) * conv_eps
        log_q_zx_conv = normal_logpdf(
            z_conv, conv_mu_posterior, conv_var_posterior).sum()
        log_p_z_conv = normal_logpdf(
            z_conv, conv_mu_prior, conv_var_prior).sum()
        kl_div_conv = normal_kl_div(conv_mu_posterior, conv_var_posterior,
                                    conv_mu_prior, conv_var_prior).sum()
    else:
        z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps

    # ---- shared by both modes -------------------------------------
    context_init = self.z_conv2context(z_conv).view(
        self.config.num_layers, batch_size, self.config.context_size)

    # Repeat z_conv along a new time axis: [batch_size, max_len, z_dim]
    z_conv_expand = z_conv.unsqueeze(1).expand(-1, max_len, -1)

    # context_outputs: [batch_size, max_len, context_size]
    context_outputs, _ = self.context_encoder(
        torch.cat([context_inputs_input, z_conv_expand], 2),
        input_conversation_length,
        hidden=context_init)

    # context_outputs: [num_sentences, context_size]
    context_outputs = flatten(context_outputs)
    z_conv_flat = flatten(z_conv_expand)

    sent_mu_prior, sent_var_prior = self.sent_prior(context_outputs,
                                                    z_conv_flat)
    eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))

    if use_posterior:
        sent_mu_posterior, sent_var_posterior = self.sent_posterior(
            context_outputs, context_inputs_inference_flat, z_conv_flat)
        z_sent = sent_mu_posterior + torch.sqrt(sent_var_posterior) * eps
        log_q_zx_sent = normal_logpdf(
            z_sent, sent_mu_posterior, sent_var_posterior).sum()
        log_p_z_sent = normal_logpdf(
            z_sent, sent_mu_prior, sent_var_prior).sum()
        kl_div_sent = normal_kl_div(sent_mu_posterior, sent_var_posterior,
                                    sent_mu_prior, sent_var_prior).sum()

        # Totals over both latent levels.
        kl_div = kl_div_conv + kl_div_sent
        log_q_zx = log_q_zx_conv + log_q_zx_sent
        log_p_z = log_p_z_conv + log_p_z_sent
    else:
        z_sent = sent_mu_prior + torch.sqrt(sent_var_prior) * eps
        kl_div = None
        log_q_zx = None
        log_p_z = (normal_logpdf(z_sent, sent_mu_prior, sent_var_prior).sum()
                   + normal_logpdf(z_conv, conv_mu_prior, conv_var_prior).sum())

    # Predict emojis using discriminator.
    emoji_preds = (self.context2emoji(context_outputs)
                   if self.config.emotion else None)

    # Predict sentence embeddings using discriminator.
    infersent_preds = (self.context2infersent(context_outputs)
                       if self.config.infersent else None)

    # One copy of each conversation's z_conv per sentence in it.
    z_conv = torch.cat(
        [z.unsqueeze(0).expand(n, self.config.z_conv_size)
         for z, n in zip(z_conv, conv_lengths)])

    # latent_context:
    #   [num_sentences, context_size + z_sent_size + z_conv_size]
    latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)

    # Split each row into layers, then put the layer axis first.
    decoder_init = self.context2decoder(latent_context).view(
        -1, self.decoder.num_layers, self.decoder.hidden_size
    ).transpose(1, 0).contiguous()

    if rl_mode or not decode:
        outputs = self.decoder(target_sentences,
                               init_h=decoder_init,
                               decode=decode)
    else:
        # outputs: [batch_size, beam_size, max_unroll]
        outputs, _, _ = self.decoder.beam_decode(init_h=decoder_init)

    return (outputs, kl_div, log_p_z, log_q_zx,
            emoji_preds, infersent_preds)
def conv_prior(self):
    """Return (mean, variance) of the standard Gaussian conversation prior.

    Both values are one-element float tensors (0.0 and 1.0) passed
    through ``to_var``.
    """
    prior_mean = to_var(torch.FloatTensor([0.0]))
    prior_variance = to_var(torch.FloatTensor([1.0]))
    return prior_mean, prior_variance
def forward(self, sentences, sentence_length,
            input_conversation_length, target_sentences, decode=False,
            extra_context_inputs=None, rl_mode=False):
    """Decode target sentences using a sentence-level latent variable.

    Args:
        sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
            Each conversation contributes (length + 1) consecutive rows.
        sentence_length: per-sentence lengths, handed to the sentence
            encoder together with sentences.
        input_conversation_length: 1-D tensor, one length per conversation.
        target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        decode: forwarded to the decoder; when True and rl_mode is False,
            beam search is used.
        extra_context_inputs: tensor concatenated to the sentence encodings
            along dim 1; only read when config.context_input_only is set.
        rl_mode: when True, z_sent comes from the prior and the regular
            decoder call is used.

    Return:
        Tuple (decoder_outputs, kl_div, log_p_z, log_q_zx, emoji_preds,
        infersent_preds).
        decoder_outputs: (Variable, FloatTensor)
            - train: [batch_size, seq_len, vocab_size]
            - eval: [batch_size, seq_len]
            - beam search: [batch_size, beam_size, max_unroll]
        kl_div and log_q_zx are None whenever z_sent is sampled from the
        prior (rl_mode or decode is truthy).

    Side effect:
        The sampled latent is stored on the instance as ``self.z_sent``.
    """
    batch_size = input_conversation_length.size(0)
    total_rows = sentences.size(0)
    num_sentences = total_rows - batch_size
    max_len = input_conversation_length.data.max().item()
    conv_lengths = input_conversation_length.data.tolist()

    def flatten(padded):
        # Keep the first conv_lengths[i] steps of conversation i and stack
        # all of them along dim 0 (removes the padding steps).
        return torch.cat(
            [padded[i, :n, :] for i, n in enumerate(conv_lengths)])

    # raw_encoder_hidden:
    #   [num_layers * direction, num_sentences + batch_size, hidden_size]
    _, raw_encoder_hidden = self.encoder(sentences, sentence_length)

    # context_inputs_2d:
    #   [num_sentences + batch_size, num_layers * direction * hidden_size]
    context_inputs_2d = raw_encoder_hidden.transpose(
        1, 0).contiguous().view(total_rows, -1)
    if self.config.context_input_only:
        context_inputs_2d = torch.cat(
            (context_inputs_2d, extra_context_inputs), 1)

    # Row offset of each conversation: exclusive cumulative sum of
    # (length + 1).
    leading_zero = to_var(input_conversation_length.data.new(1).zero_())
    start = torch.cumsum(
        torch.cat((leading_zero, input_conversation_length[:-1] + 1)), 0)

    # context_inputs:
    #   [batch_size, max_len + 1, num_layers * direction * hidden_size]
    context_inputs = torch.stack(
        [pad(context_inputs_2d.narrow(0, offset, n + 1), max_len + 1)
         for offset, n in zip(start.data.tolist(), conv_lengths)], 0)

    # Steps 1..max_len, flattened, are the posterior's extra input.
    context_inputs_inference_flat = flatten(context_inputs[:, 1:, :])
    # Steps 0..max_len-1 feed the context encoder.
    context_inputs_input = context_inputs[:, :-1, :]

    # padded_context: [batch_size, max_len, context_size]
    padded_context, _ = self.context_encoder(context_inputs_input,
                                             input_conversation_length)
    # context_outputs: [num_sentences, context_size]
    context_outputs = flatten(padded_context)

    # Stop gradients from flowing from discriminator if only using input
    discriminator_input = (context_outputs.detach()
                           if self.config.context_input_only
                           else context_outputs)

    # Predict emojis using discriminator.
    emoji_preds = (self.context2emoji(discriminator_input)
                   if self.config.emotion else None)

    # Predict sentence embeddings using discriminator.
    infersent_preds = (self.context2infersent(discriminator_input)
                       if self.config.infersent else None)

    mu_prior, var_prior = self.prior(context_outputs)
    eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))

    if rl_mode or decode:
        # Sample from the prior; no posterior terms are available.
        z_sent = mu_prior + torch.sqrt(var_prior) * eps
        kl_div = None
        log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
        log_q_zx = None
    else:
        # Sample from the posterior (reparameterised with eps).
        mu_posterior, var_posterior = self.posterior(
            context_outputs, context_inputs_inference_flat)
        z_sent = mu_posterior + torch.sqrt(var_posterior) * eps
        log_q_zx = normal_logpdf(z_sent, mu_posterior, var_posterior).sum()
        log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
        # Summed over everything normal_kl_div returns.
        kl_div = torch.sum(
            normal_kl_div(mu_posterior, var_posterior, mu_prior, var_prior))

    # Keep the sample around on the instance.
    self.z_sent = z_sent

    latent_context = torch.cat([context_outputs, z_sent], 1)

    # Split each row into layers, then put the layer axis first.
    decoder_init = self.context2decoder(latent_context).view(
        -1, self.decoder.num_layers, self.decoder.hidden_size
    ).transpose(1, 0).contiguous()

    if rl_mode or not decode:
        outputs = self.decoder(target_sentences,
                               init_h=decoder_init,
                               decode=decode)
    else:
        # outputs: [batch_size, beam_size, max_unroll]
        outputs, _, _ = self.decoder.beam_decode(init_h=decoder_init)

    return (outputs, kl_div, log_p_z, log_q_zx,
            emoji_preds, infersent_preds)