Example #1
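# Excerpt from a self-play evaluation script; the snippet is truncated at both
# ends. The keyword arguments below are presumably the tail of a
# get_loader(...) call that builds `data_loader`, and load_pickle, Solver,
# VariationalSolver, VariationalModels, pd (pandas) and np (numpy) are assumed
# to be imported by the surrounding module.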
        sentence_length=load_pickle(config.sentence_length_path),
        vocab=vocab,
        batch_size=config.batch_size,
        emojis=emotion_sentences,
        infersent=infersent_sentences)

    if config.model in VariationalModels:
        solver = VariationalSolver(config,
                                   None,
                                   data_loader,
                                   vocab=vocab,
                                   is_train=False)
    else:
        solver = Solver(config, None, data_loader, vocab=vocab, is_train=False)

    solver.build()

    self_play_buffer = pd.DataFrame()

    for i in range(kwargs.sample_conversations):
        messages = solver.self_play(
            conversation_length=kwargs.conversation_length,
            max_sentence_length=kwargs.max_sentence_length,
            max_conversation_length=kwargs.max_conversation_length,
            sample_by=kwargs.sample_by)
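        # Pair each utterance with the one that follows it, so every message
        # has the next turn as its response; the final message gets an
        # empty-string response.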
        responses = messages[1:] + ['']
        conv_df = pd.DataFrame()
        conv_df['Message'] = messages
        conv_df['Response'] = responses
        conv_df['Response Rating'] = 0
        conv_df['ID'] = len(self_play_buffer) + np.arange(len(messages))
Example #2
import os
import pickle
import random
import sys
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.optim as optim

# Project-internal dependencies (module paths omitted; imported from the
# surrounding repo): Vocab, VariationalSolver, get_loader, load_pickle,
# hrl_rewards, discount, normalizeZ, to_var, EOS_ID, TensorboardWriter,
# convert_old_checkpoint_format.


class REINFORCETuner:
    def __init__(self, config, val_config):
        self.config = config
        self.val_config = val_config

        vocab = Vocab()
        vocab.load(config.word2id_path, config.id2word_path)
        self.vocab = vocab
        self.config.vocab_size = vocab.vocab_size

        # To initialize simulated conversations
        self.start_sentences = self.load_sentences(self.config.dataset_dir)
        self.eval_data = self.get_data_loader(train=False)
        self.build_models()

        if self.config.load_rl_ckpt:
            self.load_models()

        self.set_up_optimizers()
        self.set_up_summary()
        self.set_up_logging()

        if self.config.rl_batch_size == self.config.beam_size:
            raise ValueError('Decoding breaks if rl_batch_size == beam_size')

    def build_models(self):
        config = deepcopy(self.config)

        # If loading RL checkpoint, don't try to load the ckpt through Solver
        if self.config.load_rl_ckpt:
            config.checkpoint = None

        if self.config.model != 'VHRED':
            raise ValueError("Only VHRED currently supported")

        print('Building policy network...')
        self.policy_net = VariationalSolver(config,
                                            None,
                                            self.eval_data,
                                            vocab=self.vocab,
                                            is_train=False)
        self.policy_net.build()

        print('Building simulator network...')
        self.simulator_net = VariationalSolver(config,
                                               None,
                                               self.eval_data,
                                               vocab=self.vocab,
                                               is_train=False)
        self.simulator_net.build()
        self.simulator_net.model.eval()

        print('Successfully initialized policy and simulator networks')

    def set_up_optimizers(self):
        self.optimizers = {}
        named_params = list(self.policy_net.model.named_parameters())

        if self.config.vhrl or self.config.reinforce:
            manager_worker_params = [p for name, p in named_params]
            self.optimizers['worker_manager'] = optim.Adam(
                manager_worker_params, lr=self.config.learning_rate)
        elif self.config.decoupled_vhrl:
            # No gradients flow from worker to manager if decoupled
            manager_params = [
                p for name, p in named_params if 'decoder' not in name
            ]
            worker_params = [
                p for name, p in named_params if 'decoder' in name
            ]
            self.optimizers['manager'] = optim.Adam(
                manager_params, lr=self.config.manager_learning_rate)
            self.optimizers['worker'] = optim.Adam(
                worker_params, lr=self.config.worker_learning_rate)

    def train(self):
        """Function to initiate RL training loop.
        """
        if self.config.vhrl:
            print(
                '**Starting VHRL training!**\n'
                f'Will make {self.config.num_steps} joint manager-worker updates'
            )
        elif self.config.reinforce:
            print('**Starting REINFORCE training!**\n'
                  f'Will make {self.config.num_steps} worker updates')
        elif self.config.decoupled_vhrl:
            print(
                '**Starting DECOUPLED VHRL training!**\n'
                f'Will make {self.config.num_steps} alternating manager-worker updates'
            )
        else:
            raise ValueError(
                'Training mode not understood. '
                'Choose from --vhrl | --reinforce | --decoupled_vhrl')
        print('... \n ... \n ...')

        # Starting RL training loop
        while self.step <= self.config.num_steps:
            self.step += 1
            self.train_step()

            if self.step % self.config.print_every_n == 0:
                self.print_summary()

            if self.step % self.config.log_every_n == 0:
                self.write_summary(self.step)

            if self.step % self.config.save_every_n == 0:
                self.save_model(self.step)

    def train_step(self):
        """RL training step.
        Behavior depends on mode --vhrl | --reinforce | --decoupled_vhrl
        """
        conversations, manager_actions, worker_actions = self.run_episodes()
        rewards = self.compute_rewards(conversations, self.config.rewards,
                                       self.config.reward_weights,
                                       self.config.gamma)

        # Collect some logging info
        self.manager_actions_history.append(manager_actions.mean().item())
        self.worker_actions_history.append(worker_actions.mean().item())
        response_lens = [
            len(resp.split()) for conv in conversations for resp in conv[1::2]
        ]
        self.response_len = np.mean(response_lens)
        self.recent_dialog = conversations[0]

        if self.config.vhrl or self.config.reinforce:
            optimizer = self.optimizers['worker_manager']
            alpha, beta = self.config.alpha, self.config.beta

        elif self.config.decoupled_vhrl:
            # Update manager on even steps, update worker on odd steps
            if self.step % 2 == 0:
                update_turn = 'manager'
                alpha, beta = self.config.alpha, 0
            else:
                update_turn = 'worker'
                alpha, beta = 0, self.config.beta

            # Reset optimizer moving averages because the model parameters
            # were changed by a different optimizer in the previous step
            optimizer = self.optimizers[update_turn]
            optimizer.state = defaultdict(dict)

        # We do not distinguish between manager- and worker-level rewards,
        # but this could be added by computing two separate sets of rewards
        loss = (-1 * (alpha * manager_actions + beta * worker_actions) *
                rewards).mean()
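        # A hypothetical variant with level-specific rewards (not implemented
        # here) would weight each action by its own reward tensor, e.g.:
        #   loss = (-(alpha * manager_actions * manager_rewards
        #             + beta * worker_actions * worker_rewards)).mean()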

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.policy_net.model.parameters(),
                                        self.config.gradient_clip)
        optimizer.step()

    def run_episodes(self):
        """Simulate conversations which are interactions for RL training.

        Note we use s+l-1 indexing since the conversations are flattened into
        sentences. To find the output corresponding to a certain conversation
        we need to find where its sentences start (at index s-1), then we find the
        last sentence in that conversation (at l additional steps). We extract
        the output corresponding to that sentence.
        """
        # Initialize interactions with a start sentence from the train set
        conversations = deepcopy(
            random.sample(self.start_sentences, self.config.rl_batch_size))
        manager_actions = []
        worker_actions = []

        simulated_turn = False
        # episode_len policy actions -> 2 * episode_len + 1 total turns
        # (including the start sentence)
        for turn in range(2 * self.config.episode_len):
            batch = self.policy_net.batchify(conversations)
            sentences = batch[0]
            sentence_length = batch[1]
            conversation_length = batch[2]

            net = self.simulator_net if simulated_turn else self.policy_net
            output = net.model(sentences,
                               sentence_length,
                               conversation_length, [],
                               rl_mode=True,
                               decode=True)

            # Index corresponding to the start of each conversation
            start_idx = np.cumsum([0] + conversation_length.tolist()[:-1])
            responses = [
                output[0][s + l - 1].tolist()
                for s, l in zip(start_idx, conversation_length.tolist())
            ]
            decoded = [
                self.policy_net.vocab.decode(resp) for resp in responses
            ]
            for i, conv in enumerate(conversations):
                conv.append(decoded[i])

            if not simulated_turn:
                # Get worker actions
                response_len = self.get_response_len(responses)
                word_probs = [
                    output[6][s + l - 1]
                    for s, l in zip(start_idx, conversation_length.tolist())
                ]
                log_p_words = [
                    torch.sum(torch.log(word_probs[i][:l]))
                    for i, l in enumerate(response_len)
                ]

                # Divide by len to eliminate preference for longer responses
                log_p_words = torch.stack(
                    log_p_words) / torch.cuda.FloatTensor(response_len)
                worker_actions.append(log_p_words)

                # Get manager actions
                log_p_z = torch.stack([
                    output[7][s + l - 1]
                    for s, l in zip(start_idx, conversation_length.tolist())
                ])
                manager_actions.append(log_p_z)

            # Switch speaker
            simulated_turn = not simulated_turn

        return (conversations, torch.stack(manager_actions, dim=1),
                torch.stack(worker_actions, dim=1))

    def compute_rewards(self,
                        conversations,
                        rewards_lst,
                        reward_weights,
                        gamma=0.0):
        supported = {
            'reward_question', 'reward_you', 'reward_toxicity',
            'reward_bot_deepmoji', 'reward_user_deepmoji',
            'reward_conversation_repetition', 'reward_utterance_repetition',
            'reward_infersent_coherence', 'reward_deepmoji_coherence',
            'reward_word2vec_coherence', 'reward_bot_response_length',
            'reward_word_similarity', 'reward_USE_similarity'
        }

        episode_len = self.config.episode_len
        num_convs = self.config.rl_batch_size
        combined_rewards = np.zeros((num_convs, episode_len))
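        # Each configured reward is discounted with gamma, normalized by the
        # repo's normalizeZ helper (presumably a z-score over the batch), and
        # added in with its weight w:
        #   combined_rewards = sum_i  w_i * normalizeZ(discount(r_i, gamma))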

        for r, w in zip(rewards_lst, reward_weights):
            if r not in supported:
                raise NotImplementedError(f'Unsupported reward: {r}')
            reward_func = getattr(hrl_rewards, r)
            rewards = reward_func(conversations)
            discounted = discount(rewards, gamma)
            normalized = normalizeZ(discounted)
            combined_rewards += float(w) * normalized

            self.rewards_history[r].append(rewards.mean().item())

        # [num_convs, num_actions] = [rl_batch_size, episode_len]
        return to_var(torch.FloatTensor(combined_rewards))

    def get_response_len(self, responses):
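        # Response length counts tokens up to and including EOS_ID, e.g. a
        # decoded response [42, 7, EOS_ID, 0, 0] has length 3; responses with
        # no EOS fall back to config.max_unroll.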
        lens = []
        for resp in responses:
            try:
                lens.append(resp.index(EOS_ID) + 1)
            except ValueError:
                lens.append(self.config.max_unroll)
        return lens

    def set_up_summary(self):
        # Initialize counters and summaries
        self.step = 0
        self.manager_actions_history = [0]
        self.worker_actions_history = [0]
        self.rewards_history = {}
        for r in self.config.rewards:
            self.rewards_history[r] = [0]

    def print_summary(self):
        print(f'Summary at update step {self.step}')
        for r in self.config.rewards:
            print(r + ': ')
            print(f'\t Batch: {self.rewards_history[r][-1]:.3f} \n'
                  f'\t Running: {np.mean(self.rewards_history[r]):.3f}')
            print(10 * '-')

        print('log utterance probability (worker actions):')
        print(f'\t Batch: {self.worker_actions_history[-1]:.3f} \n'
              f'\t Running: {np.mean(self.worker_actions_history):.3f}')
        print(10 * '-')
        print('log PDF of z (manager actions):')
        print(f'\t Batch: {self.manager_actions_history[-1]:.3f} \n'
              f'\t Running: {np.mean(self.manager_actions_history):.3f}')
        print(10 * '-')
        print('Response length:')
        print(f'\t Batch: {self.response_len:.0f}')
        print(10 * '-')
        print('Sample dialog:')
        print(self.recent_dialog)
        print(20 * '*')
        print(20 * '*')

        sys.stdout.flush()

    def write_summary(self, t):
        metrics_dict = {
            'step': self.step,
            'response_len': self.response_len,
            'sample_dialog': self.recent_dialog,
            'batch_manager_action': self.manager_actions_history[-1],
            'running_manager_actions': np.mean(self.manager_actions_history),
            'batch_worker_action': self.worker_actions_history[-1],
            'running_worker_actions': np.mean(self.worker_actions_history)
        }

        for r in self.config.rewards:
            metrics_dict['batch-' + r] = self.rewards_history[r][-1]
            metrics_dict['running-' + r] = np.mean(self.rewards_history[r])

        for metric, val in metrics_dict.items():
            if metric not in {'step', 'sample_dialog'}:
                self.writer.update_loss(loss=val, step_i=t, name=metric)

        # Write pandas csv with metrics to save dir
        # (DataFrame.append was removed in pandas 2.0, so use pd.concat)
        self.df = pd.concat([self.df, pd.DataFrame([metrics_dict])],
                            ignore_index=True)
        self.df.to_csv(self.pandas_path)

    def set_up_logging(self):
        # Get save path
        time_now = datetime.now().strftime('%Y-%m-%d_%H;%M;%S')
        default_save_path = Path('model_checkpoints/rl/')

        # Folder for type of RL used
        if self.config.reinforce:
            rl_algorithm = 'reinforce'
        elif self.config.vhrl:
            rl_algorithm = 'vhrl'
        elif self.config.decoupled_vhrl:
            rl_algorithm = 'decoupled_vhrl'
        else:
            raise ValueError(
                'Training mode not understood. '
                'Choose from --vhrl | --reinforce | --decoupled_vhrl')

        # Folder for type of rewards used
        extra_save_dir = self.config.extra_save_dir
        if not extra_save_dir:
            if len(self.config.rewards) == 1:
                extra_save_dir = self.config.rewards[0]
            else:
                reward_names = [
                    r[len('reward_'):] for r in self.config.rewards
                ]
                extra_save_dir = 'reward_' + '-'.join(reward_names)

        # Make save path
        self.save_dir = default_save_path.joinpath(
            self.config.data, rl_algorithm, extra_save_dir,
            self.policy_net.config.model, time_now)

        # Make directory and save config
        print("Saving output to", self.save_dir)
        os.makedirs(self.save_dir, exist_ok=True)
        with open(os.path.join(self.save_dir, 'config.txt'), 'w') as f:
            print(self.config, file=f)

        # Make loggers
        self.writer = TensorboardWriter(self.save_dir)
        self.pandas_path = os.path.join(self.save_dir, "metrics.csv")
        self.df = pd.DataFrame()

    def save_model(self, t):
        """Save parameters to checkpoint"""
        ckpt_path = os.path.join(self.save_dir, f'policy_net{t}.pkl')

        print('%' * 5)
        print('%' * 5)
        print(f'Saving parameters to {ckpt_path}')
        print('%' * 5)
        print('%' * 5)

        torch.save(self.policy_net.model.state_dict(), ckpt_path)

    def load_models(self):
        """Load parameters from RL checkpoint"""
        # If rl_ckpt_epoch is given, load that specific policy_net<epoch>.pkl;
        # otherwise parse the epoch out of the configured checkpoint path
        base_checkpoint_dir = str(Path(self.config.checkpoint).parent)
        if self.config.rl_ckpt_epoch is not None:
            policy_ckpt_path = os.path.join(
                base_checkpoint_dir,
                'policy_net' + str(self.config.rl_ckpt_epoch) + '.pkl')
            self.iter = int(self.config.rl_ckpt_epoch)
        else:
            ckpt_file = self.config.checkpoint.replace(base_checkpoint_dir, '')
            ckpt_file = ckpt_file.replace('/', '')
            ckpt_num = ckpt_file[len('policy_net'):ckpt_file.find('.')]
            policy_ckpt_path = self.config.checkpoint
            self.iter = int(ckpt_num)

        print(f'Loading parameters for policy net from {policy_ckpt_path}')
        policy_ckpt = torch.load(policy_ckpt_path)
        policy_ckpt = convert_old_checkpoint_format(policy_ckpt)
        self.policy_net.model.load_state_dict(policy_ckpt)

        # Ensure weights are initialized to be on the GPU when necessary
        if torch.cuda.is_available():
            print('Converting checkpointed model to cuda tensors')
            self.policy_net.model.cuda()

    def get_data_loader(self, train=True):
        if train:
            sentences_path = self.config.sentences_path
            conversation_length_path = self.config.conversation_length_path
            sentence_length_path = self.config.sentence_length_path
            batch_size = self.config.rl_batch_size

        else:
            sentences_path = self.val_config.sentences_path
            conversation_length_path = self.val_config.conversation_length_path
            sentence_length_path = self.val_config.sentence_length_path
            batch_size = self.config.batch_size

        return get_loader(
            sentences=load_pickle(sentences_path),
            conversation_length=load_pickle(conversation_length_path),
            sentence_length=load_pickle(sentence_length_path),
            vocab=self.vocab,
            batch_size=batch_size)

    def load_sentences(self, data_dir):
        """Function that loads start sentences from train data.
        Used for starting simulated conversations.
        """
        sent_path = data_dir.joinpath('train/raw_sentences.pkl')
        with open(sent_path, 'rb') as f:
            conversations = pickle.load(f)
        sentences = [[conv[0]] for conv in conversations
                     if len(conv[0].split()) > 5 and len(conv[0].split()) < 20]
        return sentences

    def interact(self):
        print("Commencing interaction with bot trained with RL")
        self.policy_net.interact(
            max_sentence_length=self.config.max_sentence_length,
            max_conversation_length=self.config.max_conversation_length,
            sample_by='priority',
            debug=False,
            print_history=True)
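
# Minimal usage sketch (illustrative, not the repo's actual entry point):
# `config` and `val_config` are assumed to come from the project's own
# argument/config parsing, with the fields referenced above.
#
#     tuner = REINFORCETuner(config, val_config)
#     tuner.train()       # run config.num_steps RL updates
#     tuner.interact()    # chat with the tuned policy afterwards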