Example #1
    def __init__(self):
        self.id = 'rl_dbcq'
        self.name = 'RL - DBCQ'
        self.is_test_bot = False
        self.max_conversation_length = 5
        self.safe_mode = True

        # Config for the pretrained prior (an InferSent-conditioned VHRED
        # model trained on the Reddit casual corpus)
        kwargs_dict = {'bcq_n': 5, 'mode': 'test'}
        prior_config = get_config_from_dir(
            base_path + "reddit_casual/infersent_vhred_big_noworddrop",
            **kwargs_dict)

        # Config for the trained RL (Batch Q) checkpoint
        kwargs_dict['load_rl_ckpt'] = True
        rl_config = get_config_from_dir(
            base_path + "rl/techniques/batch_q_mc5", **kwargs_dict)

        self.solver = dbcq.DBCQ(prior_config, rl_config)
Example #2
import argparse
import pickle

from model.models import VariationalModels
# get_config_from_dir and Vocab come from this project's model package;
# their imports are not shown in this snippet.


def load_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--mode', type=str, default='test')  # or valid
    kwargs = parser.parse_args()

    config = get_config_from_dir(kwargs.checkpoint, mode=kwargs.mode)
    print(config)

    print('Loading Vocabulary...')
    vocab = Vocab()
    vocab.load(config.word2id_path, config.id2word_path)
    print(f'Vocabulary size: {vocab.vocab_size}')

    config.vocab_size = vocab.vocab_size

    emotion_sentences = None
    if config.emotion:
        emotion_sentences = load_pickle(config.emojis_path)

    # Load infersent embeddings if necessary
    infersent_sentences = None
Example #3
        "Type of sampling for generated responses. Can be None, 'priority', or 'length'"
    )

    parser.add_argument('--conversation_length', type=int, default=10)
    parser.add_argument('--sample_conversations', type=int, default=100)

    kwargs = parser.parse_args()
    kw_dict = vars(kwargs)

    checkpoint_pieces = kwargs.checkpoint.split('/')
    if len(checkpoint_pieces) < 2:
        raise ValueError('Checkpoint name does not follow expected format.')
    bot_checkpoint_name = checkpoint_pieces[-2] + '/' + checkpoint_pieces[-1]
    (cur_bot_id, cur_bot_ind) = BOT_DICT[bot_checkpoint_name]

    config = get_config_from_dir(kwargs.checkpoint, **kw_dict)
    config.beam_size = 5

    print('Loading Vocabulary...')
    vocab = Vocab()
    vocab.load(config.word2id_path, config.id2word_path)
    print(f'Vocabulary size: {vocab.vocab_size}')

    config.vocab_size = vocab.vocab_size

    # If checkpoint is for an emotion model, load that pickle file
    emotion_sentences = None
    if config.emotion:
        emotion_sentences = load_pickle(config.emojis_path)

    # Load infersent embeddings if necessary
Example #4
    # Default rewards if the user does not provide any
    if not kwargs_dict['rewards']:
        kwargs_dict['rewards'] = ['reward_conversation_length']
    if not kwargs_dict['reward_weights']:
        kwargs_dict['reward_weights'] = [1.0] * len(kwargs_dict['rewards'])

    # Only one param necessary to invoke model averaging
    if kwargs_dict['model_averaging']:
        kwargs_dict['kl_control'] = True
        kwargs_dict['kl_calc'] = 'sample'

    if kwargs_dict['rl_mode'] == 'interact':
        kwargs_dict['beam_size'] = 5

    # Train config
    kwargs_dict['mode'] = 'train'
    config = get_config_from_dir(kwargs_dict['checkpoint'], **kwargs_dict)

    # Val config
    kwargs_dict['mode'] = 'valid'
    val_config = get_config_from_dir(kwargs_dict['checkpoint'], **kwargs_dict)

    bqt = BatchQ(config, val_config=val_config)

    if config.rl_mode == 'train':
        bqt.q_learn()
    elif config.rl_mode == 'interact':
        bqt.interact()
    else:
        print("Error, can't understand rl_mode:", config.rl_mode)
Example #5
    def __init__(self,
                 id,
                 name,
                 checkpoint_path,
                 max_conversation_length=5,
                 max_sentence_length=30,
                 is_test_bot=False,
                 rl=False,
                 safe_mode=True):
        """
        All chatbots should extend this class and be registered with the @registerbot decorator
        :param id: An id string, must be unique!
        :param name: A user-friendly string shown to the end user to identify the chatbot. Should be unique.
        :param checkpoint_path: Directory where the trained model checkpoint is saved.
        :param max_conversation_length: Maximum number of conversation turns to condition on.
        :param max_sentence_length: Maximum number of tokens per sentence.
        :param is_test_bot: If True, this bot can be chosen from the list of
            bots shown on the /dialogadmins screen, but will never be randomly
            assigned to users landing on the home page.
        :param rl: If True, load the RL checkpoint for this model.
        """
        self.id = id
        self.name = name
        self.checkpoint_path = checkpoint_path
        self.max_conversation_length = max_conversation_length
        self.max_sentence_length = max_sentence_length
        self.is_test_bot = is_test_bot
        self.safe_mode = safe_mode

        print("\n\nCreating chatbot", name)

        self.config = get_config_from_dir(checkpoint_path,
                                          mode='test',
                                          load_rl_ckpt=rl)
        self.config.beam_size = 5

        print('Loading Vocabulary...')
        self.vocab = Vocab()
        self.vocab.load(self.config.word2id_path, self.config.id2word_path)
        print(f'Vocabulary size: {self.vocab.vocab_size}')

        self.config.vocab_size = self.vocab.vocab_size

        # If checkpoint is for an emotion model, load that pickle file
        emotion_sentences = None
        if self.config.emotion:
            emotion_sentences = load_pickle(self.config.emojis_path)

        # Load infersent embeddings if necessary
        infersent_sentences = None
        if self.config.infersent:
            print('Loading infersent sentence embeddings...')
            infersent_sentences = load_pickle(self.config.infersent_path)
            embedding_size = infersent_sentences[0][0].shape[0]
            self.config.infersent_output_size = embedding_size

        self.data_loader = get_loader(
            sentences=load_pickle(self.config.sentences_path),
            conversation_length=load_pickle(
                self.config.conversation_length_path),
            sentence_length=load_pickle(self.config.sentence_length_path),
            vocab=self.vocab,
            batch_size=self.config.batch_size,
            emojis=emotion_sentences)

        if self.config.model in VariationalModels:
            self.solver = VariationalSolver(self.config,
                                            None,
                                            self.data_loader,
                                            vocab=self.vocab,
                                            is_train=False)
        elif self.config.model == 'Transformer':
            self.solver = ParlAISolver(self.config)
        else:
            self.solver = Solver(self.config,
                                 None,
                                 self.data_loader,
                                 vocab=self.vocab,
                                 is_train=False)

        self.solver.build()
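
The constructor above handles all of the config, vocabulary, data-loader, and solver setup, so a concrete bot only needs to pass its identity and checkpoint directory to this base class and register itself with the @registerbot decorator mentioned in the docstring. A minimal sketch of such a subclass follows; the base class name ChatBot, the availability of @registerbot in scope, and the checkpoint directory are assumptions for illustration, not names confirmed by the snippet:

# Hypothetical subclass; ChatBot stands in for the base class defined above,
# and the checkpoint path is a placeholder.
@registerbot
class ExampleHREDBot(ChatBot):
    def __init__(self):
        super().__init__(
            id='example_hred',
            name='Example - HRED',
            checkpoint_path='checkpoints/reddit_casual/hred',
            max_conversation_length=5,
            is_test_bot=True)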
Example #6
        if debug:
            print('\n**All generated responses:**')
            for gen in generated_sentences:
                print(detokenize(self.vocab.decode(list(gen))))
        
        gen_response = self.pretrained_prior.select_best_generated_response(
            generated_sentences, sample_by, beam_size=self.rl_config.beam_size)

        decoded_response = self.vocab.decode(list(gen_response))
        decoded_response = detokenize(decoded_response)

        if emojize:
            # Prepend emojis inferred from the user's most recent utterance
            inferred_emojis = self.pretrained_prior.botmoji.emojize_text(
                raw_text_sentences[-1], 5, 0.07)
            decoded_response = inferred_emojis + " " + decoded_response
        
        return decoded_response


if __name__ == '__main__':

    kwargs_dict = parse_config_args()

    kwargs_dict['mode'] = 'valid'
    prior_config = get_config_from_dir(kwargs_dict['checkpoint'], **kwargs_dict)

    kwargs_dict['load_rl_ckpt'] = True
    rl_config = get_config_from_dir(kwargs_dict['q_checkpoint'], **kwargs_dict)

    dbcq = DBCQ(prior_config, rl_config)
    dbcq.interact()