Example #1
class BatchQ:
    def __init__(self, config, val_config):
        self.config = config
        self.val_config = val_config

        # Load experience replay buffer from file
        self.experience = replay_buffer.CsvReplayBuffer(
            config.experience_path,
            raw=config.raw_buffer,
            history_len=config.max_conversation_length,
            config=config,
            max_sentence_length=config.max_sentence_length,
            rewards=config.rewards,
            reward_weights=config.reward_weights,
            model_averaging=config.model_averaging)
        self.vocab = self.experience.vocab
        self.config.vocab_size = self.experience.vocab.vocab_size
        self.action_dim = self.experience.vocab.vocab_size

        # Check that all required rewards are in the buffer; if not, compute
        for r in config.rewards:
            if r not in self.experience.buffer.columns.values:
                reward_func = getattr(rewards, r)
                self.experience = reward_func(self.experience)

        # Build internal hierarchical models
        self.eval_data = self.get_data_loader()
        self.build_models()

        if self.config.load_rl_ckpt:
            self.load_models()

        self.q_optimizer = torch.optim.Adam(filter(
            lambda p: p.requires_grad, self.q_net.model.parameters()),
                                            lr=self.config.learning_rate)

        self.set_up_logging()

    def q_update(self):
        """General Q learning update."""
        # Sample a batch
        batch = self.experience.sample(self.config.rl_batch_size)

        # Run underlying q network to get q value of each word in each
        # conversation in the batch. Use the same data to run the prior network
        # and get the rewards based on KL divergence from the prior.
        q_values, prior_rewards = self.get_q_values(batch)

        # Compute target Q values. These include the rewards observed in the
        # batch (i.e. r + (1 - done) * gamma * max_a' Q_T(s', a')); a toy
        # sketch of this target appears after the class.
        with torch.no_grad():
            target_q_values = self.get_target_q_values(batch, prior_rewards)

        loss_func = getattr(F, self.config.q_loss_func)
        loss = loss_func(q_values, target_q_values)

        assert not isnan(loss.item())
        self.q_loss_batch_history.append(loss.item())

        # Optimize the model
        self.q_optimizer.zero_grad()
        loss.backward()

        # Clip gradients - absolutely crucial
        torch.nn.utils.clip_grad_value_(self.q_net.model.parameters(),
                                        self.config.gradient_clip)

        self.q_optimizer.step()

        # Update Target Networks
        tau = self.config.target_update_rate
        for param, target_param in zip(self.q_net.model.parameters(),
                                       self.target_q_net.model.parameters()):
            target_param.data.copy_(tau * param.data +
                                    (1 - tau) * target_param.data)

    def get_q_values(self, batch):
        """Compute Q values for a batch in which states are whole
        conversations (each containing several sentences) and actions are
        sentences (series of words). Q values are per word. Target Q values
        are over the next word in the sentence, or, if at the end of the
        sentence, the first word of a new sentence after the user response.
        """
        actions = to_var(torch.LongTensor(batch['action']))  # [batch_size, max_sent_len]

        # Prepare inputs to Q network
        conversations = [
            np.concatenate((conv, np.atleast_2d(batch['action'][i])))
            for i, conv in enumerate(batch['state'])
        ]
        sent_lens = [
            np.concatenate((lens, np.atleast_1d(batch['action_lens'][i])))
            for i, lens in enumerate(batch['state_lens'])
        ]
        target_conversations = [conv[1:] for conv in conversations]
        conv_lens = [len(c) - 1 for c in conversations]
        if self.config.model not in VariationalModels:
            conversations = [conv[:-1] for conv in conversations]
            sent_lens = np.concatenate([l[:-1] for l in sent_lens])
        else:
            sent_lens = np.concatenate([l for l in sent_lens])
        conv_lens = to_var(torch.LongTensor(conv_lens))

        # Run Q network. Will produce [num_sentences, max sent len, vocab size]
        all_q_values = self.run_seq2seq_model(self.q_net, conversations,
                                              sent_lens, target_conversations,
                                              conv_lens)

        # Index to get only q values for actions taken (last sentence in each
        # conversation)
        start_q = torch.cumsum(
            torch.cat((to_var(conv_lens.data.new(1).zero_()), conv_lens[:-1])),
            0)
        conv_q_values = torch.stack([
            all_q_values[s + l - 1, :, :]
            for s, l in zip(start_q.data.tolist(), conv_lens.data.tolist())
        ], 0)  # [num_sentences, max_sent_len, vocab_size]

        # Limit by actual sentence length (remove padding) and flatten into
        # long list of words
        word_q_values = torch.cat([
            conv_q_values[i, :l, :] for i, l in enumerate(batch['action_lens'])
        ], 0)  # [total words, vocab_size]
        word_actions = torch.cat(
            [actions[i, :l] for i, l in enumerate(batch['action_lens'])],
            0)  # [total words]

        # Extract q values corresponding to actions taken
        q_values = word_q_values.gather(
            1, word_actions.unsqueeze(1)).squeeze()  # [total words]
        """ Compute KL metrics """
        prior_rewards = None

        # Get probabilities from policy network
        q_dists = torch.nn.functional.softmax(word_q_values, 1)
        q_probs = q_dists.gather(1, word_actions.unsqueeze(1)).squeeze()

        with torch.no_grad():
            # Run pretrained prior network.
            # [num_sentences, max sent len, vocab size]
            all_prior_logits = self.run_seq2seq_model(self.pretrained_prior,
                                                      conversations, sent_lens,
                                                      target_conversations,
                                                      conv_lens)

            # Get relevant actions. [num_sentences, max_sent_len, vocab_size]
            conv_prior = torch.stack([
                all_prior_logits[s + l - 1, :, :]
                for s, l in zip(start_q.data.tolist(), conv_lens.data.tolist())
            ], 0)

            # Limit by actual sentence length (remove padding) and flatten.
            # [total words, vocab_size]
            word_prior_logits = torch.cat([
                conv_prior[i, :l, :]
                for i, l in enumerate(batch['action_lens'])
            ], 0)

            # Take the softmax
            prior_dists = torch.nn.functional.softmax(word_prior_logits, 1)

            kl_div = F.kl_div(q_dists.log(), prior_dists, reduce=False)

            # [total words]
            prior_probs = prior_dists.gather(
                1, word_actions.unsqueeze(1)).squeeze()
            logp_logq = prior_probs.log() - q_probs.log()

            if self.config.model_averaging:
                model_avg_sentences = batch['model_averaged_probs']

                # Convert to tensors and flatten into [num_words]
                word_model_avg = torch.cat([
                    to_var(torch.FloatTensor(m)) for m in model_avg_sentences
                ], 0)

                # Compute KL from model-averaged prior
                prior_rewards = word_model_avg.log() - q_probs.log()

                # Clip because KL divergence is never negative; since we are
                # subtracting KL, these rewards should never be positive
                prior_rewards = torch.clamp(prior_rewards, max=0.0)

            elif self.config.kl_control and self.config.kl_calc == 'integral':
                # Note: we reward the negative KL divergence to ensure the
                # RL model stays close to the prior
                prior_rewards = -1.0 * torch.sum(kl_div, dim=1)
            elif self.config.kl_control:
                prior_rewards = logp_logq

            if self.config.kl_control:
                prior_rewards = prior_rewards * self.config.kl_weight_c
                self.kl_reward_batch_history.append(
                    torch.sum(prior_rewards).item())

            # Track all metrics
            self.kl_div_batch_history.append(torch.mean(kl_div).item())
            self.logp_batch_history.append(
                torch.mean(prior_probs.log()).item())
            self.logp_logq_batch_history.append(torch.mean(logp_logq).item())

        return q_values, prior_rewards

    def get_target_q_values(self, batch, prior_rewards=None):
        rewards = to_var(torch.FloatTensor(batch['rewards']))  # [batch_size]
        not_done = to_var(torch.FloatTensor(1 - batch['done']))  # [batch_size]
        self.sampled_reward_batch_history.append(torch.sum(rewards).item())

        # Prepare inputs to target Q network. Append a blank sentence to get
        # best response at next utterance to user input. (Next state
        # includes user input).
        blank_sentence = np.zeros((1, self.config.max_sentence_length))
        next_state_convs = [
            np.concatenate((conv, blank_sentence))
            for conv in batch['next_state']
        ]
        next_state_lens = [
            np.concatenate((lens, [1])) for lens in batch['next_state_lens']
        ]
        next_targets = [conv[1:] for conv in next_state_convs]
        next_conv_lens = [len(c) - 1 for c in next_state_convs]
        if self.config.model not in VariationalModels:
            next_state_convs = [conv[:-1] for conv in next_state_convs]
            next_state_lens = np.concatenate([l[:-1] for l in next_state_lens])
        else:
            next_state_lens = np.concatenate([l for l in next_state_lens])
        next_conv_lens = to_var(torch.LongTensor(next_conv_lens))

        # [monte_carlo_count, num_sentences, max sent len, vocab size]
        _mc_target_q_values = [[]] * self.config.monte_carlo_count
        for t in range(self.config.monte_carlo_count):
            # Run target Q network. Output is size:
            # [num_sentences, max sent len, vocab size]
            if self.config.monte_carlo_count == 1:
                # In this setting, we don't use dropout at inference time at all
                all_target_q_values = self.run_seq2seq_model(
                    self.target_q_net, next_state_convs, next_state_lens,
                    next_targets, next_conv_lens)
            else:
                # In this setting, each time we draw a new dropout mask (at inference time)
                all_target_q_values = self.run_seq2seq_model(
                    self.target_q_net, next_state_convs, next_state_lens,
                    next_targets, next_conv_lens)

            # Target indexing: the last sentence is a blank used to get the
            # value of the next response, the second-to-last is the user
            # response, and the third-to-last is the model's own actions.
            # Note that targets begin at the 2nd word of each sentence.
            start_t = torch.cumsum(
                torch.cat((to_var(next_conv_lens.data.new(1).zero_()),
                           next_conv_lens[:-1])), 0)
            conv_target_q_values = torch.stack([
                all_target_q_values[s + l - 3, 1:, :] for s, l in zip(
                    start_t.data.tolist(), next_conv_lens.data.tolist())
            ], 0)  # Dimension [num_sentences, max_sent_len - 1, vocab_size]

            # At the end of a sentence, want value of starting a new response
            # after user's response. So index into first word of last blank
            # sentence that was appended to the end of the conversation.
            next_response_targets = torch.stack([
                all_target_q_values[s + l - 1, 0, :] for s, l in zip(
                    start_t.data.tolist(), next_conv_lens.data.tolist())
            ], 0)
            next_response_targets = torch.reshape(
                next_response_targets, [self.config.rl_batch_size, 1, -1
                                        ])  # [num_sentences, 1, vocab_size]
            conv_target_q_values = torch.cat(
                [conv_target_q_values, next_response_targets],
                1)  # [num_sentences, max_sent_len, vocab_size]

            # Limit target Q values by conversation length
            limit_conv_targets = [
                conv_target_q_values[i, :l, :]
                for i, l in enumerate(batch['action_lens'])
            ]

            if self.config.psi_learning:
                # Target is r + (1 - done) * gamma * log sum_a' exp(Q_target(s', a'))
                conv_max_targets = [
                    torch.distributions.utils.log_sum_exp(c)
                    for c in limit_conv_targets
                ]
                target_q_values = torch.cat([
                    rewards[i] + not_done[i] * self.config.gamma * c.squeeze()
                    for i, c in enumerate(conv_max_targets)
                ], 0)  # [total words]
            else:
                # Target is r + (1 - done) * gamma * max_a' Q_target(s', a').
                # Reward and done are at the level of the conversation, so add
                # and multiply them in before flattening and taking the max.
                word_target_q_values = torch.cat([
                    rewards[i] + not_done[i] * self.config.gamma * c
                    for i, c in enumerate(limit_conv_targets)
                ], 0)  # [total words, vocab_size]
                target_q_values, _ = word_target_q_values.max(1)

            _mc_target_q_values[t] = target_q_values
        mc_target_q_values = torch.stack(_mc_target_q_values, 0)

        min_target_q_values, _ = mc_target_q_values.min(0)

        if self.config.kl_control:
            min_target_q_values += prior_rewards

        return min_target_q_values

    def q_learn(self):
        self.q_loss_history = []
        self.q_loss_batch_history = []
        self.sampled_reward_history = []
        self.sampled_reward_batch_history = []
        if self.config.kl_control:
            self.kl_reward_history = []
            self.kl_reward_batch_history = []

        # Need to track KL metrics even for baselines for plots
        self.kl_div_history = []
        self.kl_div_batch_history = []
        self.logp_history = []
        self.logp_batch_history = []
        self.logp_logq_history = []
        self.logp_logq_batch_history = []

        print('Commencing training at step', self.t)
        while self.t <= self.config.num_steps:
            self.q_update()

            # Log metrics
            if self.t % self.config.log_every_n == 0:
                self.epoch_q_loss = np.sum(self.q_loss_batch_history) \
                    / self.config.log_every_n
                self.q_loss_history.append(self.epoch_q_loss)
                self.q_loss_batch_history = []
                print('Average Q loss at step', self.t, '=', self.epoch_q_loss)

                self.epoch_sampled_reward = np.sum(
                    self.sampled_reward_batch_history
                ) / self.config.log_every_n
                self.sampled_reward_history.append(self.epoch_sampled_reward)
                self.sampled_reward_batch_history = []
                print('\tAverage sampled batch reward =',
                      self.epoch_sampled_reward)

                if self.config.kl_control:
                    self.epoch_kl_reward = np.sum(
                        self.kl_reward_batch_history) \
                        / self.config.log_every_n
                    self.kl_reward_history.append(self.epoch_kl_reward)
                    self.kl_reward_batch_history = []
                    print('\tAverage data prior reward =',
                          self.epoch_kl_reward)

                # Logging KL for plots
                self.epoch_kl_div = np.sum(
                    self.kl_div_batch_history) / self.config.log_every_n
                self.kl_div_history.append(self.epoch_kl_div)
                self.kl_div_batch_history = []
                self.epoch_logp = np.sum(
                    self.logp_batch_history) / self.config.log_every_n
                self.logp_history.append(self.epoch_logp)
                self.logp_batch_history = []
                self.epoch_logp_logq = np.sum(
                    self.logp_logq_batch_history) / self.config.log_every_n
                self.logp_logq_history.append(self.epoch_logp_logq)
                self.logp_logq_batch_history = []

                sys.stdout.flush()
                self.write_summary(self.t)

            if self.t > 0 and self.t % self.config.save_every_n == 0:
                self.save_model(self.t)

            self.t += 1

    def build_models(self):
        config = copy.deepcopy(self.config)

        # If loading RL checkpoint, ensure it doesn't try to load the ckpt
        # through Solver
        if self.config.load_rl_ckpt:
            config.checkpoint = None

        if self.config.model in VariationalModels:
            self.q_net = VariationalSolver(config,
                                           None,
                                           self.eval_data,
                                           vocab=self.vocab,
                                           is_train=True)
            self.target_q_net = VariationalSolver(config,
                                                  None,
                                                  self.eval_data,
                                                  vocab=self.vocab,
                                                  is_train=True)
        else:
            self.q_net = Solver(config,
                                None,
                                self.eval_data,
                                vocab=self.vocab,
                                is_train=True)
            self.target_q_net = Solver(config,
                                       None,
                                       self.eval_data,
                                       vocab=self.vocab,
                                       is_train=True)
        print('Building Q network')
        self.q_net.build()

        print('\nBuilding Target Q network')
        self.target_q_net.build()

        if self.config.model in VariationalModels:
            self.pretrained_prior = VariationalSolver(self.config,
                                                      None,
                                                      self.eval_data,
                                                      vocab=self.vocab,
                                                      is_train=True)
        else:
            self.pretrained_prior = Solver(self.config,
                                           None,
                                           self.eval_data,
                                           vocab=self.vocab,
                                           is_train=True)
        print('Building prior network')
        self.pretrained_prior.build()

        # Freeze the weights of the prior so it stays constant
        self.pretrained_prior.model.eval()
        for params in self.pretrained_prior.model.parameters():
            params.requires_grad = False

        print('Successfully initialized Q networks')
        self.t = 0

    def run_seq2seq_model(self, q_net, input_conversations, sent_lens,
                          target_conversations, conv_lens):
        # Prepare the batch
        sentences = [sent for conv in input_conversations for sent in conv]
        targets = [sent for conv in target_conversations for sent in conv]

        if not (np.all(np.isfinite(sentences)) and np.all(np.isfinite(targets))
                and np.all(np.isfinite(sent_lens))):
            print("Input isn't finite")

        sentences = to_var(torch.LongTensor(sentences))
        targets = to_var(torch.LongTensor(targets))
        sent_lens = to_var(torch.LongTensor(sent_lens))

        # Run Q network
        q_outputs = q_net.model(sentences,
                                sent_lens,
                                conv_lens,
                                targets,
                                rl_mode=True)
        return q_outputs[0]  # [num_sentences, max_sentence_len, vocab_size]

    def write_summary(self, t):
        metrics_to_log = [
            'epoch_q_loss', 'epoch_sampled_reward', 'epoch_kl_div',
            'epoch_logp', 'epoch_logp_logq'
        ]

        if self.config.kl_control:
            metrics_to_log.append('epoch_kl_reward')

        metrics_dict = {}
        for metric in metrics_to_log:
            met_val = getattr(self, metric, None)
            metrics_dict[metric] = met_val
            if met_val is not None:
                self.writer.update_loss(loss=met_val, step_i=t, name=metric)

        # Write pandas csv with metrics to save dir
        self.df = self.df.append(metrics_dict, ignore_index=True)
        self.df.to_csv(self.pandas_path)

    def set_up_logging(self):
        # Get save path
        time_now = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        default_save_path = Path(
            os.path.expanduser('~') + '/dialog/model_checkpoints/rl/')

        # Folder for type of RL used
        experiment_name = self.config.experiment_name
        if experiment_name is None:
            # if self.config.double_q: experiment_name = 'double_q'
            if self.config.model_averaging:
                experiment_name = 'model_averaging'
            elif self.config.kl_control:
                experiment_name = 'kl_control'
                if self.config.kl_calc == 'sample':
                    experiment_name += '_sample'
            else:
                experiment_name = 'batch_q'
            if self.config.psi_learning:
                experiment_name += '/psi_learning'
            if self.config.monte_carlo_count > 1:
                experiment_name = 'monte_carlo_targets/' + experiment_name

        # Folder for type of rewards used
        extra_save_dir = self.config.extra_save_dir
        if not extra_save_dir:
            if len(self.config.rewards) == 1:
                extra_save_dir = self.config.rewards[0]
            else:
                extra_save_dir = 'reward_combo'

        # Folder for which model was used
        extra_model_desc = ""
        if self.config.context_input_only:
            extra_model_desc = 'input_only_'
        if self.config.emotion and 'input_only' not in extra_model_desc:
            extra_model_desc += "emotion_"
        if self.config.infersent and 'input_only' not in extra_model_desc:
            extra_model_desc += "infersent_"

        # Make save path
        self.save_dir = default_save_path.joinpath(
            self.q_net.config.data, extra_save_dir, experiment_name,
            extra_model_desc + self.q_net.config.model, time_now)

        # Make directory and save config
        print("Saving output to", self.save_dir)
        os.makedirs(self.save_dir, exist_ok=True)
        with open(os.path.join(self.save_dir, 'config.txt'), 'w') as f:
            print(self.config, file=f)

        # Make loggers
        self.writer = TensorboardWriter(self.save_dir)
        self.pandas_path = os.path.join(self.save_dir, "metrics.csv")
        self.df = pd.DataFrame()

    def save_model(self, t):
        """Save parameters to checkpoint"""
        ckpt_path = os.path.join(self.save_dir, f'q_net{t}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.q_net.model.state_dict(), ckpt_path)

        ckpt_path = os.path.join(self.save_dir, f'target_q_net{t}.pkl')
        torch.save(self.target_q_net.model.state_dict(), ckpt_path)

    def load_models(self):
        """Load parameters from RL checkpoint"""
        # Override specific checkpoint with particular one
        base_checkpoint_dir = str(Path(self.config.checkpoint).parent)
        if self.config.rl_ckpt_epoch is not None:
            q_ckpt_path = os.path.join(
                base_checkpoint_dir,
                'q_net' + str(self.config.rl_ckpt_epoch) + '.pkl')
            target_q_ckpt_path = os.path.join(
                base_checkpoint_dir,
                'target_q_net' + str(self.config.rl_ckpt_epoch) + '.pkl')
            self.t = int(self.config.rl_ckpt_epoch)
        else:
            ckpt_file = self.config.checkpoint.replace(base_checkpoint_dir, '')
            ckpt_file = ckpt_file.replace('/', '')
            ckpt_num = ckpt_file[len('q_net'):ckpt_file.find('.')]
            q_ckpt_path = self.config.checkpoint
            target_q_ckpt_path = os.path.join(
                base_checkpoint_dir, 'target_q_net' + ckpt_num + '.pkl')
            self.t = int(ckpt_num)

        print(f'Loading parameters for Q net from {q_ckpt_path}')
        q_ckpt = torch.load(q_ckpt_path)
        q_ckpt = convert_old_checkpoint_format(q_ckpt)
        self.q_net.model.load_state_dict(q_ckpt)

        print(f'Loading parameters for target Q net from {target_q_ckpt_path}')
        target_q_ckpt = torch.load(target_q_ckpt_path)
        target_q_ckpt = convert_old_checkpoint_format(target_q_ckpt)
        self.target_q_net.model.load_state_dict(target_q_ckpt)

        # Ensure weights are initialized to be on the GPU when necessary
        if torch.cuda.is_available():
            print('Converting checkpointed model to cuda tensors')
            self.q_net.model.cuda()
            self.target_q_net.model.cuda()

    def get_data_loader(self):
        # If checkpoint is for an emotion model, load that pickle file
        emotion_sentences = None
        if self.config.emotion:
            emotion_sentences = load_pickle(self.val_config.emojis_path)

        # Load infersent embeddings if necessary
        infersent_sentences = None
        if self.config.infersent:
            print('Loading infersent sentence embeddings...')
            infersent_sentences = load_pickle(self.val_config.infersent_path)
            embedding_size = infersent_sentences[0][0].shape[0]
            self.config.infersent_output_size = embedding_size
            self.val_config.infersent_output_size = embedding_size

        return get_loader(
            sentences=load_pickle(self.val_config.sentences_path),
            conversation_length=load_pickle(
                self.val_config.conversation_length_path),
            sentence_length=load_pickle(self.val_config.sentence_length_path),
            vocab=self.vocab,
            batch_size=self.config.batch_size,
            emojis=emotion_sentences,
            infersent=infersent_sentences)

    def interact(self):
        print("Commencing interaction with bot trained with RL")
        self.q_net.interact(
            max_sentence_length=self.config.max_sentence_length,
            max_conversation_length=self.config.max_conversation_length,
            sample_by='priority',
            debug=True,
            print_history=True)
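
The q_update and get_target_q_values methods above combine a per-word Q-learning target, an optional KL-control reward that keeps the Q policy close to the pretrained prior, and a soft (Polyak) target-network update. The snippet below is a minimal, self-contained sketch of those three steps on made-up tensors; the shapes and the gamma, kl_weight_c, and tau values are placeholders and do not come from any real config.

import torch
import torch.nn.functional as F

# Illustrative only: toy versions of the BatchQ update steps on made-up
# tensors. Shapes are [batch, sent_len] and [batch, sent_len, vocab].
gamma, kl_weight_c, tau = 0.99, 0.5, 0.005
rewards = torch.tensor([1.0, 0.0])            # one reward per conversation
not_done = torch.tensor([1.0, 1.0])           # 1 - done flags
target_q = torch.randn(2, 5, 100)             # output of the target Q network

# (1) Q-learning target: r + (1 - done) * gamma * max_a' Q_target(s', a')
targets = rewards[:, None] + not_done[:, None] * gamma * target_q.max(dim=2).values

# (2) KL-control reward: penalize divergence between the Q-policy softmax and
# the prior softmax (same F.kl_div call pattern as in get_q_values)
q_dists = F.softmax(torch.randn(2, 5, 100), dim=2)
prior_dists = F.softmax(torch.randn(2, 5, 100), dim=2)
kl_div = F.kl_div(q_dists.log(), prior_dists, reduction='none').sum(dim=2)
targets = targets - kl_weight_c * kl_div

# (3) Loss on the Q values of the actions taken, then a soft target update
q_values = torch.randn(2, 5, requires_grad=True)
loss = F.smooth_l1_loss(q_values, targets)
loss.backward()

param, target_param = torch.randn(3), torch.randn(3)
target_param = tau * param + (1 - tau) * target_param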
Example #2
class DBCQ:
    def __init__(self, prior_config, rl_config, beam_size=5):
        self.prior_config = prior_config
        self.rl_config = rl_config
        self.rl_config.beam_size = beam_size

        print('Loading Vocabulary...')
        self.vocab = Vocab()
        self.vocab.load(prior_config.word2id_path, prior_config.id2word_path)
        self.prior_config.vocab_size = self.vocab.vocab_size
        self.rl_config.vocab_size = self.vocab.vocab_size
        print(f'Vocabulary size: {self.vocab.vocab_size}')
        
        self.eval_data = self.get_data_loader()
        self.build_models()

    def build_models(self):
        rl_config = copy.deepcopy(self.rl_config)
        rl_config.checkpoint = None  

        print('Building Q network')
        if rl_config.model in VariationalModels:
            self.q_net = VariationalSolver(
                rl_config, None, self.eval_data, vocab=self.vocab, 
                is_train=False)
        else:
            self.q_net = Solver(
                rl_config, None, self.eval_data, vocab=self.vocab, 
                is_train=False)
        self.q_net.build()
        self.load_q_network()

        print('Building prior network')
        if self.prior_config.model in VariationalModels:
            self.pretrained_prior = VariationalSolver(
                self.prior_config, None, self.eval_data, vocab=self.vocab, 
                is_train=False)
        else:
            self.pretrained_prior = Solver(
                self.prior_config, None, self.eval_data, vocab=self.vocab, 
                is_train=False)
        self.pretrained_prior.build()
       
        # Freeze the weights so they stay constant
        self.pretrained_prior.model.eval()
        for params in self.pretrained_prior.model.parameters():
            params.requires_grad = False
        self.q_net.model.eval()
        for params in self.q_net.model.parameters():
            params.requires_grad = False

    def load_q_network(self):
        """Load parameters from RL checkpoint"""
        print(f'Loading parameters for Q net from {self.rl_config.checkpoint}')
        q_ckpt = torch.load(self.rl_config.checkpoint)
        q_ckpt = convert_old_checkpoint_format(q_ckpt)
        self.q_net.model.load_state_dict(q_ckpt)

        # Ensure weights are initialized to be on the GPU when necessary
        if torch.cuda.is_available():
            print('Converting checkpointed model to cuda tensors')
            self.q_net.model.cuda()

    def get_data_loader(self):
        # If checkpoint is for an emotion model, load that pickle file
        emotion_sentences = None
        if self.prior_config.emotion:
            emotion_sentences = load_pickle(self.prior_config.emojis_path)

        # Load infersent embeddings if necessary
        infersent_sentences = None
        if self.prior_config.infersent:
            print('Loading infersent sentence embeddings...')
            infersent_sentences = load_pickle(self.prior_config.infersent_path)
            embedding_size = infersent_sentences[0][0].shape[0]
            self.prior_config.infersent_output_size = embedding_size

        return get_loader(
            sentences=load_pickle(self.prior_config.sentences_path),
            conversation_length=load_pickle(
                self.prior_config.conversation_length_path),
            sentence_length=load_pickle(self.prior_config.sentence_length_path),
            vocab=self.vocab,
            batch_size=self.prior_config.batch_size,
            emojis=emotion_sentences,
            infersent=infersent_sentences)

    def interact(self, max_conversation_length=5, sample_by='priority', 
                 debug=True):
        model_name = self.prior_config.model
        context_sentences = []

        print("Time to start a conversation with the chatbot! It's name is", 
              model_name)
        username = input("What is your name? ")

        print("Let's start chatting. You can type 'quit' at any time to quit.")
        utterance = input("Input: ")
        print("\033[1A\033[K")  # Erases last line of output

        while (utterance.lower() != 'quit' and utterance.lower() != 'exit'):
            # Process utterance
            sentences = utterance.split('/')

            # Code and decode user input to show how it is transformed for model
            coded, lens = self.pretrained_prior.process_user_input(sentences)
            decoded = [self.vocab.decode(sent) for i, sent in enumerate(
                coded) if i < lens[i]]
            print(username + ':', '. '.join(decoded))

            # Append to conversation
            context_sentences.extend(sentences)

            gen_response = self.generate_response_to_input(
                context_sentences, max_conversation_length, sample_by=sample_by,
                debug=debug)

            # Append generated sentences to conversation
            context_sentences.append(gen_response)

            # Print and get next user input
            print("\n" + model_name + ": " + gen_response)
            utterance = input("Input: ")
            print("\033[1A\033[K")

    def process_raw_text_into_input(self, raw_text_sentences,
                                    max_conversation_length=5, debug=False,):
        sentences, lens = self.pretrained_prior.process_user_input(
            raw_text_sentences, self.rl_config.max_sentence_length)

        # Remove any sentences of length 0
        sentences = [sent for i, sent in enumerate(sentences) if lens[i] > 0]
        good_raw_sentences = [sent for i, sent in enumerate(
            raw_text_sentences) if lens[i] > 0]
        lens = [l for l in lens if l > 0]

        # Trim conversation to max length
        sentences = sentences[-max_conversation_length:]
        lens = lens[-max_conversation_length:]
        good_raw_sentences = good_raw_sentences[-max_conversation_length:]
        convo_length = len(sentences)

        # Convert to torch variables
        input_sentences = to_var(torch.LongTensor(sentences))
        input_sentence_length = to_var(torch.LongTensor(lens))
        input_conversation_length = to_var(torch.LongTensor([convo_length]))

        if debug:
            print('\n**Conversation history:**')
            for sent in sentences:
                print(self.vocab.decode(list(sent)))

        return (input_sentences, input_sentence_length, 
                input_conversation_length)

    def duplicate_context_for_beams(self, sentences, sent_lens, conv_lens,
                                    beams):
        conv_lens = conv_lens.repeat(len(beams))
        # [beam_size * sentences, sentence_len]
        if len(sentences) > 1:
            targets = torch.cat(
                [torch.cat([sentences[1:,:], beams[i,:].unsqueeze(0)], 0) 
                for i in range(len(beams))], 0)
        else:
            targets = beams

        # HRED
        if self.rl_config.model not in VariationalModels:
            sent_lens = sent_lens.repeat(len(beams))
            return sentences, sent_lens, conv_lens, targets
        
        # VHRED, VHCR
        new_sentences = torch.cat(
            [torch.cat([sentences, beams[i,:].unsqueeze(0)], 0) 
            for i in range(len(beams))], 0)
        new_len = to_var(torch.LongTensor([self.rl_config.max_sentence_length]))
        sent_lens = torch.cat(
            [torch.cat([sent_lens, new_len], 0) for i in range(len(beams))])
        return new_sentences, sent_lens, conv_lens, targets

    def generate_response_to_input(self, raw_text_sentences, 
                                   max_conversation_length=5,
                                   sample_by='priority', emojize=True,
                                   debug=True):
        with torch.no_grad():
            (input_sentences, input_sent_lens, 
             input_conv_lens) = self.process_raw_text_into_input(
                raw_text_sentences, debug=debug,
                max_conversation_length=max_conversation_length)

            # Initialize a tensor for beams
            beams = to_var(torch.LongTensor(
                np.ones((self.rl_config.beam_size, 
                        self.rl_config.max_sentence_length))))

            # Create a batch with the context duplicated for each beam
            (sentences, sent_lens, 
            conv_lens, targets) = self.duplicate_context_for_beams(
                input_sentences, input_sent_lens, input_conv_lens, beams)

            # Continuously feed beam sentences into networks to sample the next 
            # best word, add that to the beam, and continue
            for i in range(self.rl_config.max_sentence_length):
                # Run both models to obtain logits
                prior_output = self.pretrained_prior.model(
                    sentences, sent_lens, conv_lens, targets, rl_mode=True)
                all_prior_logits = prior_output[0]
                
                q_output = self.q_net.model(
                    sentences, sent_lens, conv_lens, targets, rl_mode=True)
                all_q_logits = q_output[0]

                # Select only those logits for next word
                q_logits = all_q_logits[:, i, :].squeeze()
                prior_logits = all_prior_logits[:, i, :].squeeze()

                # Get prior distribution for next word in each beam
                prior_dists = torch.nn.functional.softmax(prior_logits, 1)

                for b in range(self.rl_config.beam_size):
                    # Sample from the prior bcq_n times for each beam
                    dist = torch.distributions.Categorical(prior_dists[b,:])
                    sampled_idxs = dist.sample_n(self.rl_config.bcq_n)

                    # Select sample with highest q value
                    q_vals = torch.stack(
                        [q_logits[b, idx] for idx in sampled_idxs])
                    _, best_word_i = torch.max(q_vals, 0) 
                    best_word = sampled_idxs[best_word_i]

                    # Update beams
                    beams[b, i] = best_word

                (sentences, sent_lens, 
                 conv_lens, targets) = self.duplicate_context_for_beams(
                    input_sentences, input_sent_lens, input_conv_lens, beams)
            
            generated_sentences = beams.cpu().numpy()

        if debug:
            print('\n**All generated responses:**')
            for gen in generated_sentences:
                print(detokenize(self.vocab.decode(list(gen))))
        
        gen_response = self.pretrained_prior.select_best_generated_response(
            generated_sentences, sample_by, beam_size=self.rl_config.beam_size)

        decoded_response = self.vocab.decode(list(gen_response))
        decoded_response = detokenize(decoded_response)

        if emojize:
            inferred_emojis = self.pretrained_prior.botmoji.emojize_text(
                raw_text_sentences[-1], 5, 0.07)
            decoded_response = inferred_emojis + " " + decoded_response
        
        return decoded_response
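
The decoding loop in generate_response_to_input follows a batch-constrained pattern: candidate next words are sampled from the pretrained prior, and the candidate with the highest Q value is kept. The snippet below sketches that single selection step with dummy logits; vocab_size and bcq_n are arbitrary placeholders, not values from a real config.

import torch

# Illustrative only: one BCQ-style word-selection step with dummy logits.
vocab_size, bcq_n = 100, 10
prior_logits = torch.randn(vocab_size)   # prior network logits for the next word
q_logits = torch.randn(vocab_size)       # Q network logits for the next word

# Sample bcq_n candidate words from the prior distribution
prior_dist = torch.distributions.Categorical(logits=prior_logits)
sampled_idxs = prior_dist.sample((bcq_n,))

# Keep the sampled candidate with the highest Q value
best_word = sampled_idxs[q_logits[sampled_idxs].argmax()]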
Example #3
class Chatbot(ABC):
    def __init__(self,
                 id,
                 name,
                 checkpoint_path,
                 max_conversation_length=5,
                 max_sentence_length=30,
                 is_test_bot=False,
                 rl=False,
                 safe_mode=True):
        """
        All chatbots should extend this class and be registered with the @registerbot decorator
        :param id: An id string, must be unique!
        :param name: A user-friendly string shown to the end user to identify the chatbot. Should be unique.
        :param checkpoint_path: Directory where the trained model checkpoint is saved.
        :param max_conversation_length: Maximum number of conversation turns to condition on.
        :param max_sentence_length: Maximum number of tokens per sentence.
        :param is_test_bot: If True, this bot can be chosen from the list of
            bots you see at the /dialogadmins screen, but will never be randomly
            assigned to users landing on the home page.
        :param rl: If True, load an RL-trained checkpoint for this model.
        :param safe_mode: If True, filter inappropriate language out of responses.
        """
        self.id = id
        self.name = name
        self.checkpoint_path = checkpoint_path
        self.max_conversation_length = max_conversation_length
        self.max_sentence_length = max_sentence_length
        self.is_test_bot = is_test_bot
        self.safe_mode = safe_mode

        print("\n\nCreating chatbot", name)

        self.config = get_config_from_dir(checkpoint_path,
                                          mode='test',
                                          load_rl_ckpt=rl)
        self.config.beam_size = 5

        print('Loading Vocabulary...')
        self.vocab = Vocab()
        self.vocab.load(self.config.word2id_path, self.config.id2word_path)
        print(f'Vocabulary size: {self.vocab.vocab_size}')

        self.config.vocab_size = self.vocab.vocab_size

        # If checkpoint is for an emotion model, load that pickle file
        emotion_sentences = None
        if self.config.emotion:
            emotion_sentences = load_pickle(self.config.emojis_path)

        # Load infersent embeddings if necessary
        infersent_sentences = None
        if self.config.infersent:
            print('Loading infersent sentence embeddings...')
            infersent_sentences = load_pickle(self.config.infersent_path)
            embedding_size = infersent_sentences[0][0].shape[0]
            self.config.infersent_output_size = embedding_size

        self.data_loader = get_loader(
            sentences=load_pickle(self.config.sentences_path),
            conversation_length=load_pickle(
                self.config.conversation_length_path),
            sentence_length=load_pickle(self.config.sentence_length_path),
            vocab=self.vocab,
            batch_size=self.config.batch_size,
            emojis=emotion_sentences)

        if self.config.model in VariationalModels:
            self.solver = VariationalSolver(self.config,
                                            None,
                                            self.data_loader,
                                            vocab=self.vocab,
                                            is_train=False)
        elif self.config.model == 'Transformer':
            self.solver = ParlAISolver(self.config)
        else:
            self.solver = Solver(self.config,
                                 None,
                                 self.data_loader,
                                 vocab=self.vocab,
                                 is_train=False)

        self.solver.build()

    def handle_messages(self, messages):
        """
        Takes a list of messages, and combines those with magic to return a response string
        :param messages: list of strings
        :return: string
        """
        greetings = [
            "hey , how are you ?", "hi , how 's it going ?",
            "hey , what 's up ?", "hi . how are you ?",
            "hello , how are you doing today ? ",
            "hello . how are things with you ?",
            "hey ! so, tell me about yourself .", "hi . nice to meet you ."
        ]

        # Check for no response
        if len(messages) == 0:
            # Respond with canned greeting response
            return np.random.choice(greetings)

        # Check for overly short intro messages
        if len(messages) < 2 and len(messages[0]) <= 6:  # 6 for "hello."
            first_m = messages[0].lower()
            if 'hi' in first_m or 'hey' in first_m or 'hello' in first_m:
                # Respond with canned greeting response
                return np.random.choice(greetings)

        response = self.solver.generate_response_to_input(
            messages,
            max_conversation_length=self.max_conversation_length,
            emojize=True,
            debug=False)

        # Manually remove inappropriate language from response.
        # WARNING: the following code contains inappropriate language
        if self.safe_mode:
            response = response.replace("f*g", "<unknown>")
            response = response.replace("gays", "<unknown>")
            response = response.replace("c**t", "%@#$")
            response = response.replace("f**k", "%@#$")
            response = response.replace("shit", "%@#$")
            response = response.replace("dyke", "%@#$")
            response = response.replace("hell", "heck")
            response = response.replace("dick", "d***")
            response = response.replace("bitch", "%@#$")

        return response
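
The docstring of Chatbot says that concrete bots extend this class and are registered with the @registerbot decorator. The sketch below shows what such a subclass might look like; the class name, id, and checkpoint path are hypothetical placeholders, and registerbot lives in the surrounding project rather than in this file.

# Illustrative only: a hypothetical concrete bot. The id, name, and
# checkpoint_path below are placeholders; in the real project the class
# would also be decorated with @registerbot from the registration module.
class ExampleHREDBot(Chatbot):
    def __init__(self):
        super().__init__(
            id='example_hred',
            name='Example HRED bot',
            checkpoint_path='~/dialog/model_checkpoints/example_hred/',
            max_conversation_length=5,
            max_sentence_length=30)

# Usage sketch:
# bot = ExampleHREDBot()
# print(bot.handle_messages(["hi , how are you ?"]))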