class BatchQ:
    def __init__(self, config, val_config):
        self.config = config
        self.val_config = val_config

        # Load experience replay buffer from file
        self.experience = replay_buffer.CsvReplayBuffer(
            config.experience_path, raw=config.raw_buffer,
            history_len=config.max_conversation_length, config=config,
            max_sentence_length=config.max_sentence_length,
            rewards=config.rewards, reward_weights=config.reward_weights,
            model_averaging=config.model_averaging)
        self.vocab = self.experience.vocab
        self.config.vocab_size = self.experience.vocab.vocab_size
        self.action_dim = self.experience.vocab.vocab_size

        # Check that all required rewards are in the buffer; if not, compute
        for r in config.rewards:
            if r not in self.experience.buffer.columns.values:
                reward_func = getattr(rewards, r)
                self.experience = reward_func(self.experience)

        # Build internal hierarchical models
        self.eval_data = self.get_data_loader()
        self.build_models()

        if self.config.load_rl_ckpt:
            self.load_models()

        self.q_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.q_net.model.parameters()),
            lr=self.config.learning_rate)

        self.set_up_logging()

    def q_update(self):
        """General Q-learning update."""
        # Sample a batch
        batch = self.experience.sample(self.config.rl_batch_size)

        # Run the underlying Q network to get the q value of each word in
        # each conversation in the batch. Use the same data to run the prior
        # network and get the rewards based on KL divergence from the prior.
        q_values, prior_rewards = self.get_q_values(batch)

        # Compute target Q values. These include the rewards observed in the
        # batch (i.e. r + (1 - done) * gamma * max_a' Q_T(s', a'))
        with torch.no_grad():
            target_q_values = self.get_target_q_values(batch, prior_rewards)

        loss_func = getattr(F, self.config.q_loss_func)
        loss = loss_func(q_values, target_q_values)
        assert not isnan(loss.item())
        self.q_loss_batch_history.append(loss.item())

        # Optimize the model
        self.q_optimizer.zero_grad()
        loss.backward()

        # Clip gradients - absolutely crucial
        torch.nn.utils.clip_grad_value_(self.q_net.model.parameters(),
                                        self.config.gradient_clip)

        self.q_optimizer.step()

        # Update target networks with a soft (Polyak) update
        tau = self.config.target_update_rate
        for param, target_param in zip(self.q_net.model.parameters(),
                                       self.target_q_net.model.parameters()):
            target_param.data.copy_(
                tau * param.data + (1 - tau) * target_param.data)

    def get_q_values(self, batch):
        """Q-learning update in which states are whole conversations, each
        containing several sentences, and actions are a sentence (a series
        of words). Q values are per word. Target Q values are over the next
        word in the sentence, or, at the end of the sentence, the first word
        of a new sentence after the user response.
        """
        actions = to_var(torch.LongTensor(batch['action']))  # [batch_size]

        # Prepare inputs to Q network
        conversations = [
            np.concatenate((conv, np.atleast_2d(batch['action'][i])))
            for i, conv in enumerate(batch['state'])
        ]
        sent_lens = [
            np.concatenate((lens, np.atleast_1d(batch['action_lens'][i])))
            for i, lens in enumerate(batch['state_lens'])
        ]
        target_conversations = [conv[1:] for conv in conversations]
        conv_lens = [len(c) - 1 for c in conversations]
        if self.config.model not in VariationalModels:
            conversations = [conv[:-1] for conv in conversations]
            sent_lens = np.concatenate([l[:-1] for l in sent_lens])
        else:
            sent_lens = np.concatenate([l for l in sent_lens])
        conv_lens = to_var(torch.LongTensor(conv_lens))

        # Run Q network. Will produce [num_sentences, max sent len, vocab size]
        all_q_values = self.run_seq2seq_model(
            self.q_net, conversations, sent_lens, target_conversations,
            conv_lens)

        # Index to get only q values for actions taken (last sentence in
        # each conversation)
        start_q = torch.cumsum(
            torch.cat((to_var(conv_lens.data.new(1).zero_()),
                       conv_lens[:-1])), 0)
        conv_q_values = torch.stack([
            all_q_values[s + l - 1, :, :]
            for s, l in zip(start_q.data.tolist(), conv_lens.data.tolist())
        ], 0)  # [num_sentences, max_sent_len, vocab_size]

        # Limit by actual sentence length (remove padding) and flatten into
        # a long list of words
        word_q_values = torch.cat([
            conv_q_values[i, :l, :]
            for i, l in enumerate(batch['action_lens'])
        ], 0)  # [total words, vocab_size]
        word_actions = torch.cat(
            [actions[i, :l] for i, l in enumerate(batch['action_lens'])],
            0)  # [total words]

        # Extract q values corresponding to actions taken
        q_values = word_q_values.gather(
            1, word_actions.unsqueeze(1)).squeeze()  # [total words]

        """ Compute KL metrics """
        prior_rewards = None

        # Get probabilities from policy network
        q_dists = torch.nn.functional.softmax(word_q_values, 1)
        q_probs = q_dists.gather(1, word_actions.unsqueeze(1)).squeeze()

        with torch.no_grad():
            # Run pretrained prior network.
            # [num_sentences, max sent len, vocab size]
            all_prior_logits = self.run_seq2seq_model(
                self.pretrained_prior, conversations, sent_lens,
                target_conversations, conv_lens)

            # Get relevant actions. [num_sentences, max_sent_len, vocab_size]
            conv_prior = torch.stack([
                all_prior_logits[s + l - 1, :, :]
                for s, l in zip(start_q.data.tolist(),
                                conv_lens.data.tolist())
            ], 0)

            # Limit by actual sentence length (remove padding) and flatten.
            # [total words, vocab_size]
            word_prior_logits = torch.cat([
                conv_prior[i, :l, :]
                for i, l in enumerate(batch['action_lens'])
            ], 0)

            # Take the softmax
            prior_dists = torch.nn.functional.softmax(word_prior_logits, 1)

        kl_div = F.kl_div(q_dists.log(), prior_dists,
                          reduce=False)  # [total words, vocab_size]
        prior_probs = prior_dists.gather(
            1, word_actions.unsqueeze(1)).squeeze()
        logp_logq = prior_probs.log() - q_probs.log()

        if self.config.model_averaging:
            model_avg_sentences = batch['model_averaged_probs']

            # Convert to tensors and flatten into [num_words]
            word_model_avg = torch.cat([
                to_var(torch.FloatTensor(m)) for m in model_avg_sentences
            ], 0)

            # Compute KL from model-averaged prior
            prior_rewards = word_model_avg.log() - q_probs.log()

            # Clip because KL should never be negative; since we are
            # subtracting KL, rewards should never be positive
            prior_rewards = torch.clamp(prior_rewards, max=0.0)
        elif self.config.kl_control and self.config.kl_calc == 'integral':
            # Note: we reward the negative KL divergence to ensure the
            # RL model stays close to the prior
            prior_rewards = -1.0 * torch.sum(kl_div, dim=1)
        elif self.config.kl_control:
            prior_rewards = logp_logq

        if self.config.kl_control:
            prior_rewards = prior_rewards * self.config.kl_weight_c
            self.kl_reward_batch_history.append(
                torch.sum(prior_rewards).item())

        # Track all metrics
        self.kl_div_batch_history.append(torch.mean(kl_div).item())
        self.logp_batch_history.append(
            torch.mean(prior_probs.log()).item())
        self.logp_logq_batch_history.append(torch.mean(logp_logq).item())

        return q_values, prior_rewards

    def get_target_q_values(self, batch, prior_rewards=None):
        rewards = to_var(torch.FloatTensor(batch['rewards']))  # [batch_size]
        not_done = to_var(torch.FloatTensor(1 - batch['done']))  # [batch_size]
        self.sampled_reward_batch_history.append(torch.sum(rewards).item())

        # Prepare inputs to target Q network. Append a blank sentence to get
        # the best response at the next utterance to user input. (Next state
        # includes user input.)
        blank_sentence = np.zeros((1, self.config.max_sentence_length))
        next_state_convs = [
            np.concatenate((conv, blank_sentence))
            for conv in batch['next_state']
        ]
        next_state_lens = [
            np.concatenate((lens, [1])) for lens in batch['next_state_lens']
        ]
        next_targets = [conv[1:] for conv in next_state_convs]
        next_conv_lens = [len(c) - 1 for c in next_state_convs]
        if self.config.model not in VariationalModels:
            next_state_convs = [conv[:-1] for conv in next_state_convs]
            next_state_lens = np.concatenate(
                [l[:-1] for l in next_state_lens])
        else:
            next_state_lens = np.concatenate([l for l in next_state_lens])
        next_conv_lens = to_var(torch.LongTensor(next_conv_lens))

        # [monte_carlo_count, num_sentences, max sent len, vocab size]
        _mc_target_q_values = [[]] * self.config.monte_carlo_count

        for t in range(self.config.monte_carlo_count):
            # Run target Q network. Output is size:
            # [num_sentences, max sent len, vocab size]
            if self.config.monte_carlo_count == 1:
                # In this setting, we don't use dropout at inference time
                all_target_q_values = self.run_seq2seq_model(
                    self.target_q_net, next_state_convs, next_state_lens,
                    next_targets, next_conv_lens)
            else:
                # In this setting, we draw a new dropout mask each time
                # (at inference time)
                all_target_q_values = self.run_seq2seq_model(
                    self.target_q_net, next_state_convs, next_state_lens,
                    next_targets, next_conv_lens)

            # Target indexing: the last sentence is a blank used to get the
            # value of the next response. The second last is the user
            # response. The third last is the model's own actions. Note that
            # targets begin at the 2nd word of each sentence.
            start_t = torch.cumsum(
                torch.cat((to_var(next_conv_lens.data.new(1).zero_()),
                           next_conv_lens[:-1])), 0)
            conv_target_q_values = torch.stack([
                all_target_q_values[s + l - 3, 1:, :] for s, l in zip(
                    start_t.data.tolist(), next_conv_lens.data.tolist())
            ], 0)  # [num_sentences, max_sent_len - 1, vocab_size]

            # At the end of a sentence, we want the value of starting a new
            # response after the user's response. So index into the first
            # word of the last blank sentence that was appended to the end
            # of the conversation.
            next_response_targets = torch.stack([
                all_target_q_values[s + l - 1, 0, :] for s, l in zip(
                    start_t.data.tolist(), next_conv_lens.data.tolist())
            ], 0)
            next_response_targets = torch.reshape(
                next_response_targets,
                [self.config.rl_batch_size, 1, -1])  # [num_sentences, 1, vocab_size]
            conv_target_q_values = torch.cat(
                [conv_target_q_values, next_response_targets],
                1)  # [num_sentences, max_sent_len, vocab_size]

            # Limit target Q values by conversation length
            limit_conv_targets = [
                conv_target_q_values[i, :l, :]
                for i, l in enumerate(batch['action_lens'])
            ]

            if self.config.psi_learning:
                # Target is r + gamma * log sum_a' exp(Q_target(s', a'))
                conv_max_targets = [
                    torch.distributions.utils.log_sum_exp(c)
                    for c in limit_conv_targets
                ]
                target_q_values = torch.cat([
                    rewards[i] + not_done[i] * self.config.gamma * c.squeeze()
                    for i, c in enumerate(conv_max_targets)
                ], 0)  # [total words]
            else:
                # Target is r + gamma * max_a' Q_target(s', a'). Reward and
                # done are at the level of the conversation, so add and
                # multiply them in before flattening and taking the max.
                word_target_q_values = torch.cat([
                    rewards[i] + not_done[i] * self.config.gamma * c
                    for i, c in enumerate(limit_conv_targets)
                ], 0)  # [total words, vocab_size]
                target_q_values, _ = word_target_q_values.max(1)

            _mc_target_q_values[t] = target_q_values

        mc_target_q_values = torch.stack(_mc_target_q_values, 0)
        min_target_q_values, _ = mc_target_q_values.min(0)

        if self.config.kl_control:
            min_target_q_values += prior_rewards

        return min_target_q_values

    def q_learn(self):
        self.q_loss_history = []
        self.q_loss_batch_history = []
        self.sampled_reward_history = []
        self.sampled_reward_batch_history = []
        if self.config.kl_control:
            self.kl_reward_history = []
            self.kl_reward_batch_history = []
        # Need to track KL metrics even for baselines for plots
        self.kl_div_history = []
        self.kl_div_batch_history = []
        self.logp_history = []
        self.logp_batch_history = []
        self.logp_logq_history = []
        self.logp_logq_batch_history = []

        print('Commencing training at step', self.t)
        while self.t <= self.config.num_steps:
            self.q_update()

            # Log metrics
            if self.t % self.config.log_every_n == 0:
                self.epoch_q_loss = np.sum(
                    self.q_loss_batch_history) / self.config.log_every_n
                self.q_loss_history.append(self.epoch_q_loss)
                self.q_loss_batch_history = []
                print('Average Q loss at step', self.t, '=',
                      self.epoch_q_loss)

                self.epoch_sampled_reward = np.sum(
                    self.sampled_reward_batch_history) \
                    / self.config.log_every_n
                self.sampled_reward_history.append(self.epoch_sampled_reward)
                self.sampled_reward_batch_history = []
                print('\tAverage sampled batch reward =',
                      self.epoch_sampled_reward)

                if self.config.kl_control:
                    self.epoch_kl_reward = np.sum(
                        self.kl_reward_batch_history) \
                        / self.config.log_every_n
                    self.kl_reward_history.append(self.epoch_kl_reward)
                    self.kl_reward_batch_history = []
                    print('\tAverage data prior reward =',
                          self.epoch_kl_reward)

                # Logging KL for plots
                self.epoch_kl_div = np.sum(
                    self.kl_div_batch_history) / self.config.log_every_n
                self.kl_div_history.append(self.epoch_kl_div)
                self.kl_div_batch_history = []

                self.epoch_logp = np.sum(
                    self.logp_batch_history) / self.config.log_every_n
                self.logp_history.append(self.epoch_logp)
                self.logp_batch_history = []

                self.epoch_logp_logq = np.sum(
                    self.logp_logq_batch_history) / self.config.log_every_n
                self.logp_logq_history.append(self.epoch_logp_logq)
                self.logp_logq_batch_history = []

                sys.stdout.flush()
                self.write_summary(self.t)

            if self.t > 0 and self.t % self.config.save_every_n == 0:
                self.save_model(self.t)

            self.t += 1

    def build_models(self):
        config = copy.deepcopy(self.config)

        # If loading an RL checkpoint, ensure it doesn't try to load the
        # ckpt through Solver
        if self.config.load_rl_ckpt:
            config.checkpoint = None

        if self.config.model in VariationalModels:
            self.q_net = VariationalSolver(
                config, None, self.eval_data, vocab=self.vocab,
                is_train=True)
            self.target_q_net = VariationalSolver(
                config, None, self.eval_data, vocab=self.vocab,
                is_train=True)
        else:
            self.q_net = Solver(
                config, None, self.eval_data, vocab=self.vocab,
                is_train=True)
            self.target_q_net = Solver(
                config, None, self.eval_data, vocab=self.vocab,
                is_train=True)

        print('Building Q network')
        self.q_net.build()
        print('\nBuilding Target Q network')
        self.target_q_net.build()

        if self.config.model in VariationalModels:
            self.pretrained_prior = VariationalSolver(
                self.config, None, self.eval_data, vocab=self.vocab,
                is_train=True)
        else:
            self.pretrained_prior = Solver(
                self.config, None, self.eval_data, vocab=self.vocab,
                is_train=True)

        print('Building prior network')
        self.pretrained_prior.build()

        # Freeze the weights of the prior so it stays constant
        self.pretrained_prior.model.eval()
        for params in self.pretrained_prior.model.parameters():
            params.requires_grad = False

        print('Successfully initialized Q networks')

        self.t = 0

    def run_seq2seq_model(self, q_net, input_conversations, sent_lens,
                          target_conversations, conv_lens):
        # Prepare the batch
        sentences = [sent for conv in input_conversations for sent in conv]
        targets = [sent for conv in target_conversations for sent in conv]

        if not (np.all(np.isfinite(sentences))
                and np.all(np.isfinite(targets))
                and np.all(np.isfinite(sent_lens))):
            print("Input isn't finite")

        sentences = to_var(torch.LongTensor(sentences))
        targets = to_var(torch.LongTensor(targets))
        sent_lens = to_var(torch.LongTensor(sent_lens))

        # Run Q network
        q_outputs = q_net.model(sentences, sent_lens, conv_lens, targets,
                                rl_mode=True)
        return q_outputs[0]  # [num_sentences, max_sentence_len, vocab_size]

    def write_summary(self, t):
        metrics_to_log = [
            'epoch_q_loss', 'epoch_sampled_reward', 'epoch_kl_div',
            'epoch_logp', 'epoch_logp_logq'
        ]
        if self.config.kl_control:
            metrics_to_log.append('epoch_kl_reward')

        metrics_dict = {}
        for metric in metrics_to_log:
            met_val = getattr(self, metric, None)
            metrics_dict[metric] = met_val
            if met_val is not None:
                self.writer.update_loss(loss=met_val, step_i=t, name=metric)

        # Write pandas csv with metrics to save dir
        self.df = self.df.append(metrics_dict, ignore_index=True)
        self.df.to_csv(self.pandas_path)

    def set_up_logging(self):
        # Get save path
        time_now = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        default_save_path = Path(
            os.path.expanduser('~') + '/dialog/model_checkpoints/rl/')

        # Folder for type of RL used
        experiment_name = self.config.experiment_name
        if experiment_name is None:
            # if self.config.double_q: experiment_name = 'double_q'
            if self.config.model_averaging:
                experiment_name = 'model_averaging'
            elif self.config.kl_control:
                experiment_name = 'kl_control'
                if self.config.kl_calc == 'sample':
                    experiment_name += '_sample'
            else:
                experiment_name = 'batch_q'
            if self.config.psi_learning:
                experiment_name += '/psi_learning'
            if self.config.monte_carlo_count > 1:
                experiment_name = 'monte_carlo_targets/' + experiment_name

        # Folder for type of rewards used
        extra_save_dir = self.config.extra_save_dir
        if not extra_save_dir:
            if len(self.config.rewards) == 1:
                extra_save_dir = self.config.rewards[0]
            else:
                extra_save_dir = 'reward_combo'

        # Folder for which model was used
        extra_model_desc = ""
        if self.config.context_input_only:
            extra_model_desc = 'input_only_'
        if self.config.emotion and 'input_only' not in extra_model_desc:
            extra_model_desc += "emotion_"
        if self.config.infersent and 'input_only' not in extra_model_desc:
            extra_model_desc += "infersent_"

        # Make save path
        self.save_dir = default_save_path.joinpath(
            self.q_net.config.data, extra_save_dir, experiment_name,
            extra_model_desc + self.q_net.config.model, time_now)

        # Make directory and save config
        print("Saving output to", self.save_dir)
        os.makedirs(self.save_dir, exist_ok=True)
        with open(os.path.join(self.save_dir, 'config.txt'), 'w') as f:
            print(self.config, file=f)

        # Make loggers
        self.writer = TensorboardWriter(self.save_dir)
        self.pandas_path = os.path.join(self.save_dir, "metrics.csv")
        self.df = pd.DataFrame()

    def save_model(self, t):
        """Save parameters to checkpoint"""
        ckpt_path = os.path.join(self.save_dir, f'q_net{t}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.q_net.model.state_dict(), ckpt_path)

        ckpt_path = os.path.join(self.save_dir, f'target_q_net{t}.pkl')
        torch.save(self.target_q_net.model.state_dict(), ckpt_path)

    def load_models(self):
        """Load parameters from RL checkpoint"""
        # Override specific checkpoint with particular one
        base_checkpoint_dir = str(Path(self.config.checkpoint).parent)
        if self.config.rl_ckpt_epoch is not None:
            q_ckpt_path = os.path.join(
                base_checkpoint_dir,
                'q_net' + str(self.config.rl_ckpt_epoch) + '.pkl')
            target_q_ckpt_path = os.path.join(
                base_checkpoint_dir,
                'target_q_net' + str(self.config.rl_ckpt_epoch) + '.pkl')
            self.t = int(self.config.rl_ckpt_epoch)
        else:
            ckpt_file = self.config.checkpoint.replace(
                base_checkpoint_dir, '')
            ckpt_file = ckpt_file.replace('/', '')
            ckpt_num = ckpt_file[len('q_net'):ckpt_file.find('.')]
            q_ckpt_path = self.config.checkpoint
            target_q_ckpt_path = os.path.join(
                base_checkpoint_dir, 'target_q_net' + ckpt_num + '.pkl')
            self.t = int(ckpt_num)

        print(f'Loading parameters for Q net from {q_ckpt_path}')
        q_ckpt = torch.load(q_ckpt_path)
        q_ckpt = convert_old_checkpoint_format(q_ckpt)
        self.q_net.model.load_state_dict(q_ckpt)

        print(f'Loading parameters for target Q net from {target_q_ckpt_path}')
        target_q_ckpt = torch.load(target_q_ckpt_path)
        target_q_ckpt = convert_old_checkpoint_format(target_q_ckpt)
        self.target_q_net.model.load_state_dict(target_q_ckpt)

        # Ensure weights are initialized to be on the GPU when necessary
        if torch.cuda.is_available():
            print('Converting checkpointed model to cuda tensors')
            self.q_net.model.cuda()
            self.target_q_net.model.cuda()

    def get_data_loader(self):
        # If checkpoint is for an emotion model, load that pickle file
        emotion_sentences = None
        if self.config.emotion:
            emotion_sentences = load_pickle(self.val_config.emojis_path)

        # Load infersent embeddings if necessary
        infersent_sentences = None
        if self.config.infersent:
            print('Loading infersent sentence embeddings...')
            infersent_sentences = load_pickle(self.val_config.infersent_path)
            embedding_size = infersent_sentences[0][0].shape[0]
            self.config.infersent_output_size = embedding_size
            self.val_config.infersent_output_size = embedding_size

        return get_loader(
            sentences=load_pickle(self.val_config.sentences_path),
            conversation_length=load_pickle(
                self.val_config.conversation_length_path),
            sentence_length=load_pickle(self.val_config.sentence_length_path),
            vocab=self.vocab,
            batch_size=self.config.batch_size,
            emojis=emotion_sentences,
            infersent=infersent_sentences)

    def interact(self):
        print("Commencing interaction with bot trained with RL")
        self.q_net.interact(
            max_sentence_length=self.config.max_sentence_length,
            max_conversation_length=self.config.max_conversation_length,
            sample_by='priority',
            debug=True,
            print_history=True)
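
# ---------------------------------------------------------------------------
# Illustrative sketch only (not called anywhere in this module): a minimal,
# self-contained example of the per-word target that BatchQ.q_update()
# regresses towards,
#     target = r + (1 - done) * gamma * max_a' Q_target(s', a') + c * r_KL,
# and of the soft (Polyak) target-network update. All tensors below are toy
# placeholders; the real code works with vocabulary-sized network outputs and
# takes the min of the target over several MC dropout samples.
def _batch_q_target_sketch():
    import torch

    gamma = 0.99       # discount factor (config.gamma)
    kl_weight_c = 0.1  # KL-control reward weight (config.kl_weight_c)

    # Toy per-word quantities: 4 words, a 6-word "vocabulary"
    rewards = torch.tensor([0.5, 0.5, 0.5, 0.5])   # conversation reward, broadcast per word
    not_done = torch.tensor([1.0, 1.0, 1.0, 0.0])  # zero for the terminal word
    next_q = torch.randn(4, 6)                     # Q_target(s', a') for every candidate next word
    kl_reward = -torch.rand(4)                     # -KL(q || prior), always <= 0

    # r + (1 - done) * gamma * max_a' Q_target(s', a')
    max_next_q, _ = next_q.max(dim=1)
    target = rewards + not_done * gamma * max_next_q

    # KL-control adds the weighted, non-positive prior reward to the target
    target = target + kl_weight_c * kl_reward

    # Soft target update: theta_target <- tau * theta + (1 - tau) * theta_target
    tau = 0.005
    param, target_param = torch.ones(3), torch.zeros(3)
    target_param = tau * param + (1 - tau) * target_param

    return target, target_param
# ---------------------------------------------------------------------------
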
class DBCQ:
    def __init__(self, prior_config, rl_config, beam_size=5):
        self.prior_config = prior_config
        self.rl_config = rl_config
        self.rl_config.beam_size = beam_size

        print('Loading Vocabulary...')
        self.vocab = Vocab()
        self.vocab.load(prior_config.word2id_path, prior_config.id2word_path)
        self.prior_config.vocab_size = self.vocab.vocab_size
        self.rl_config.vocab_size = self.vocab.vocab_size
        print(f'Vocabulary size: {self.vocab.vocab_size}')

        self.eval_data = self.get_data_loader()
        self.build_models()

    def build_models(self):
        rl_config = copy.deepcopy(self.rl_config)
        rl_config.checkpoint = None

        print('Building Q network')
        if rl_config.model in VariationalModels:
            self.q_net = VariationalSolver(
                rl_config, None, self.eval_data, vocab=self.vocab,
                is_train=False)
        else:
            self.q_net = Solver(
                rl_config, None, self.eval_data, vocab=self.vocab,
                is_train=False)
        self.q_net.build()
        self.load_q_network()

        print('Building prior network')
        if self.prior_config.model in VariationalModels:
            self.pretrained_prior = VariationalSolver(
                self.prior_config, None, self.eval_data, vocab=self.vocab,
                is_train=False)
        else:
            self.pretrained_prior = Solver(
                self.prior_config, None, self.eval_data, vocab=self.vocab,
                is_train=False)
        self.pretrained_prior.build()

        # Freeze the weights so they stay constant
        self.pretrained_prior.model.eval()
        for params in self.pretrained_prior.model.parameters():
            params.requires_grad = False
        self.q_net.model.eval()
        for params in self.q_net.model.parameters():
            params.requires_grad = False

    def load_q_network(self):
        """Load parameters from RL checkpoint"""
        print(f'Loading parameters for Q net from {self.rl_config.checkpoint}')
        q_ckpt = torch.load(self.rl_config.checkpoint)
        q_ckpt = convert_old_checkpoint_format(q_ckpt)
        self.q_net.model.load_state_dict(q_ckpt)

        # Ensure weights are initialized to be on the GPU when necessary
        if torch.cuda.is_available():
            print('Converting checkpointed model to cuda tensors')
            self.q_net.model.cuda()

    def get_data_loader(self):
        # If checkpoint is for an emotion model, load that pickle file
        emotion_sentences = None
        if self.prior_config.emotion:
            emotion_sentences = load_pickle(self.prior_config.emojis_path)

        # Load infersent embeddings if necessary
        infersent_sentences = None
        if self.prior_config.infersent:
            print('Loading infersent sentence embeddings...')
            infersent_sentences = load_pickle(self.prior_config.infersent_path)
            embedding_size = infersent_sentences[0][0].shape[0]
            self.prior_config.infersent_output_size = embedding_size

        return get_loader(
            sentences=load_pickle(self.prior_config.sentences_path),
            conversation_length=load_pickle(
                self.prior_config.conversation_length_path),
            sentence_length=load_pickle(
                self.prior_config.sentence_length_path),
            vocab=self.vocab,
            batch_size=self.prior_config.batch_size,
            emojis=emotion_sentences,
            infersent=infersent_sentences)

    def interact(self, max_conversation_length=5, sample_by='priority',
                 debug=True):
        model_name = self.prior_config.model
        context_sentences = []

        print("Time to start a conversation with the chatbot! Its name is",
              model_name)
        username = input("What is your name? ")

        print("Let's start chatting. You can type 'quit' at any time to quit.")
        utterance = input("Input: ")
        print("\033[1A\033[K")  # Erases last line of output
        while (utterance.lower() != 'quit' and utterance.lower() != 'exit'):
            # Process utterance
            sentences = utterance.split('/')

            # Encode and decode user input to show how it is transformed
            # for the model
            coded, lens = self.pretrained_prior.process_user_input(sentences)
            decoded = [self.vocab.decode(sent)
                       for i, sent in enumerate(coded) if i < lens[i]]
            print(username + ':', '. '.join(decoded))

            # Append to conversation
            context_sentences.extend(sentences)

            gen_response = self.generate_response_to_input(
                context_sentences, max_conversation_length,
                sample_by=sample_by, debug=debug)

            # Append generated sentences to conversation
            context_sentences.append(gen_response)

            # Print and get next user input
            print("\n" + model_name + ": " + gen_response)
            utterance = input("Input: ")
            print("\033[1A\033[K")

    def process_raw_text_into_input(self, raw_text_sentences,
                                    max_conversation_length=5,
                                    debug=False):
        sentences, lens = self.pretrained_prior.process_user_input(
            raw_text_sentences, self.rl_config.max_sentence_length)

        # Remove any sentences of length 0
        sentences = [sent for i, sent in enumerate(sentences) if lens[i] > 0]
        good_raw_sentences = [sent for i, sent in enumerate(
            raw_text_sentences) if lens[i] > 0]
        lens = [l for l in lens if l > 0]

        # Trim conversation to max length
        sentences = sentences[-max_conversation_length:]
        lens = lens[-max_conversation_length:]
        good_raw_sentences = good_raw_sentences[-max_conversation_length:]
        convo_length = len(sentences)

        # Convert to torch variables
        input_sentences = to_var(torch.LongTensor(sentences))
        input_sentence_length = to_var(torch.LongTensor(lens))
        input_conversation_length = to_var(torch.LongTensor([convo_length]))

        if debug:
            print('\n**Conversation history:**')
            for sent in sentences:
                print(self.vocab.decode(list(sent)))

        return (input_sentences, input_sentence_length,
                input_conversation_length)

    def duplicate_context_for_beams(self, sentences, sent_lens, conv_lens,
                                    beams):
        conv_lens = conv_lens.repeat(len(beams))

        # [beam_size * sentences, sentence_len]
        if len(sentences) > 1:
            targets = torch.cat(
                [torch.cat([sentences[1:, :], beams[i, :].unsqueeze(0)], 0)
                 for i in range(len(beams))], 0)
        else:
            targets = beams

        # HRED
        if self.rl_config.model not in VariationalModels:
            sent_lens = sent_lens.repeat(len(beams))
            return sentences, sent_lens, conv_lens, targets

        # VHRED, VHCR
        new_sentences = torch.cat(
            [torch.cat([sentences, beams[i, :].unsqueeze(0)], 0)
             for i in range(len(beams))], 0)
        new_len = to_var(
            torch.LongTensor([self.rl_config.max_sentence_length]))
        sent_lens = torch.cat(
            [torch.cat([sent_lens, new_len], 0) for i in range(len(beams))])

        return new_sentences, sent_lens, conv_lens, targets

    def generate_response_to_input(self, raw_text_sentences,
                                   max_conversation_length=5,
                                   sample_by='priority', emojize=True,
                                   debug=True):
        with torch.no_grad():
            (input_sentences, input_sent_lens,
             input_conv_lens) = self.process_raw_text_into_input(
                raw_text_sentences, debug=debug,
                max_conversation_length=max_conversation_length)

            # Initialize a tensor for beams
            beams = to_var(torch.LongTensor(
                np.ones((self.rl_config.beam_size,
                         self.rl_config.max_sentence_length))))

            # Create a batch with the context duplicated for each beam
            (sentences, sent_lens, conv_lens,
             targets) = self.duplicate_context_for_beams(
                input_sentences, input_sent_lens, input_conv_lens, beams)

            # Continuously feed beam sentences into the networks to sample
            # the next best word, add that to the beam, and continue
            for i in range(self.rl_config.max_sentence_length):
                # Run both models to obtain logits
                prior_output = self.pretrained_prior.model(
                    sentences, sent_lens, conv_lens, targets, rl_mode=True)
                all_prior_logits = prior_output[0]
                q_output = self.q_net.model(
                    sentences, sent_lens, conv_lens, targets, rl_mode=True)
                all_q_logits = q_output[0]

                # Select only the logits for the next word
                q_logits = all_q_logits[:, i, :].squeeze()
                prior_logits = all_prior_logits[:, i, :].squeeze()

                # Get prior distribution for next word in each beam
                prior_dists = torch.nn.functional.softmax(prior_logits, 1)

                for b in range(self.rl_config.beam_size):
                    # Sample from the prior bcq_n times for each beam
                    dist = torch.distributions.Categorical(prior_dists[b, :])
                    sampled_idxs = dist.sample_n(self.rl_config.bcq_n)

                    # Select the sample with the highest q value
                    q_vals = torch.stack(
                        [q_logits[b, idx] for idx in sampled_idxs])
                    _, best_word_i = torch.max(q_vals, 0)
                    best_word = sampled_idxs[best_word_i]

                    # Update beams
                    beams[b, i] = best_word

                (sentences, sent_lens, conv_lens,
                 targets) = self.duplicate_context_for_beams(
                    input_sentences, input_sent_lens, input_conv_lens, beams)

            generated_sentences = beams.cpu().numpy()

            if debug:
                print('\n**All generated responses:**')
                for gen in generated_sentences:
                    print(detokenize(self.vocab.decode(list(gen))))

            gen_response = self.pretrained_prior.select_best_generated_response(
                generated_sentences, sample_by,
                beam_size=self.rl_config.beam_size)

            decoded_response = self.vocab.decode(list(gen_response))
            decoded_response = detokenize(decoded_response)

            if emojize:
                inferred_emojis = self.pretrained_prior.botmoji.emojize_text(
                    raw_text_sentences[-1], 5, 0.07)
                decoded_response = inferred_emojis + " " + decoded_response

            return decoded_response
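
# ---------------------------------------------------------------------------
# Illustrative sketch only (not called anywhere in this module): the core
# discrete batch-constrained (BCQ-style) selection step that
# DBCQ.generate_response_to_input() applies at each beam position. Candidate
# next words are sampled from the prior's softmax distribution, and the
# candidate with the highest Q value is kept, so the Q network only chooses
# among words the prior considers plausible. The logits below are toy
# stand-ins for a single beam.
def _bcq_word_selection_sketch(bcq_n=10):
    import torch

    vocab_size = 20
    prior_logits = torch.randn(vocab_size)  # prior network logits for the next word
    q_logits = torch.randn(vocab_size)      # Q network logits for the next word

    # Sample bcq_n candidate word indices from the prior distribution
    prior_dist = torch.distributions.Categorical(logits=prior_logits)
    candidates = prior_dist.sample((bcq_n,))

    # Keep the candidate with the highest Q value
    candidate_q = q_logits[candidates]
    best_word = candidates[candidate_q.argmax()]

    return best_word
# ---------------------------------------------------------------------------
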
class Chatbot(ABC):
    def __init__(self, id, name, checkpoint_path, max_conversation_length=5,
                 max_sentence_length=30, is_test_bot=False, rl=False,
                 safe_mode=True):
        """
        All chatbots should extend this class and be registered with the
        @registerbot decorator

        :param id: An id string, must be unique!
        :param name: A user-friendly string shown to the end user to
            identify the chatbot. Should be unique.
        :param checkpoint_path: Directory where the trained model checkpoint
            is saved.
        :param max_conversation_length: Maximum number of conversation turns
            to condition on.
        :param max_sentence_length: Maximum number of tokens per sentence.
        :param is_test_bot: If True, this bot can be chosen from the list of
            bots you see at the /dialogadmins screen, but will never be
            randomly assigned to users landing on the home page.
        """
        self.id = id
        self.name = name
        self.checkpoint_path = checkpoint_path
        self.max_conversation_length = max_conversation_length
        self.max_sentence_length = max_sentence_length
        self.is_test_bot = is_test_bot
        self.safe_mode = safe_mode

        print("\n\nCreating chatbot", name)

        self.config = get_config_from_dir(checkpoint_path, mode='test',
                                          load_rl_ckpt=rl)
        self.config.beam_size = 5

        print('Loading Vocabulary...')
        self.vocab = Vocab()
        self.vocab.load(self.config.word2id_path, self.config.id2word_path)
        print(f'Vocabulary size: {self.vocab.vocab_size}')

        self.config.vocab_size = self.vocab.vocab_size

        # If checkpoint is for an emotion model, load that pickle file
        emotion_sentences = None
        if self.config.emotion:
            emotion_sentences = load_pickle(self.config.emojis_path)

        # Load infersent embeddings if necessary
        infersent_sentences = None
        if self.config.infersent:
            print('Loading infersent sentence embeddings...')
            infersent_sentences = load_pickle(self.config.infersent_path)
            embedding_size = infersent_sentences[0][0].shape[0]
            self.config.infersent_output_size = embedding_size

        self.data_loader = get_loader(
            sentences=load_pickle(self.config.sentences_path),
            conversation_length=load_pickle(
                self.config.conversation_length_path),
            sentence_length=load_pickle(self.config.sentence_length_path),
            vocab=self.vocab,
            batch_size=self.config.batch_size,
            emojis=emotion_sentences)

        if self.config.model in VariationalModels:
            self.solver = VariationalSolver(
                self.config, None, self.data_loader, vocab=self.vocab,
                is_train=False)
        elif self.config.model == 'Transformer':
            self.solver = ParlAISolver(self.config)
        else:
            self.solver = Solver(
                self.config, None, self.data_loader, vocab=self.vocab,
                is_train=False)

        self.solver.build()

    def handle_messages(self, messages):
        """
        Takes a list of messages, and combines those with magic to return a
        response string

        :param messages: list of strings
        :return: string
        """
        greetings = [
            "hey , how are you ?",
            "hi , how 's it going ?",
            "hey , what 's up ?",
            "hi . how are you ?",
            "hello , how are you doing today ? ",
            "hello . how are things with you ?",
            "hey ! so, tell me about yourself .",
            "hi . nice to meet you ."
        ]

        # Check for no response
        if len(messages) == 0:
            # Respond with canned greeting response
            return np.random.choice(greetings)

        # Check for overly short intro messages
        if len(messages) < 2 and len(messages[0]) <= 6:  # 6 for "hello."
            first_m = messages[0].lower()
            if 'hi' in first_m or 'hey' in first_m or 'hello' in first_m:
                # Respond with canned greeting response
                return np.random.choice(greetings)

        response = self.solver.generate_response_to_input(
            messages,
            max_conversation_length=self.max_conversation_length,
            emojize=True,
            debug=False)

        # Manually remove inappropriate language from the response.
        # WARNING: the following code contains inappropriate language
        if self.safe_mode:
            response = response.replace("f*g", "<unknown>")
            response = response.replace("gays", "<unknown>")
            response = response.replace("c**t", "%@#$")
            response = response.replace("f**k", "%@#$")
            response = response.replace("shit", "%@#$")
            response = response.replace("dyke", "%@#$")
            response = response.replace("hell", "heck")
            response = response.replace("dick", "d***")
            response = response.replace("bitch", "%@#$")

        return response
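
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of this module): a concrete bot would
# subclass Chatbot and be registered with the @registerbot decorator
# mentioned in the docstring above. `ExampleBot`, its id/name, and the
# checkpoint directory are made-up placeholders, not real project values.
#
# @registerbot
# class ExampleBot(Chatbot):
#     def __init__(self):
#         super().__init__(id='example_bot',
#                          name='Example Bot',
#                          checkpoint_path='model_checkpoints/example_dir/')
#
# bot = ExampleBot()
# print(bot.handle_messages(["hi , how are you ?"]))
# ---------------------------------------------------------------------------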