Example #1
    def get_data_loader(self, train=True):
        """Return a data loader for either the training or validation split."""
        if train:
            # Training split: paths from the main config, RL batch size.
            sentences_path = self.config.sentences_path
            conversation_length_path = self.config.conversation_length_path
            sentence_length_path = self.config.sentence_length_path
            batch_size = self.config.rl_batch_size

        else:
            # Validation split: paths from the validation config.
            sentences_path = self.val_config.sentences_path
            conversation_length_path = self.val_config.conversation_length_path
            sentence_length_path = self.val_config.sentence_length_path
            batch_size = self.config.batch_size

        return get_loader(
            sentences=load_pickle(sentences_path),
            conversation_length=load_pickle(conversation_length_path),
            sentence_length=load_pickle(sentence_length_path),
            vocab=self.vocab,
            batch_size=batch_size)
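
load_pickle appears in every example here but is not defined in these snippets. Below is a minimal sketch of what such a helper typically looks like; it is an assumption, and the project's own utility may differ.

    import pickle

    def load_pickle(path):
        """Deserialize and return the object stored in a pickle file."""
        with open(path, 'rb') as f:
            return pickle.load(f)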
Example #2
    def get_data_loader(self):
        """Build a data loader from the prior config, including emoji and
        InferSent features when they are enabled."""
        # If checkpoint is for an emotion model, load that pickle file
        emotion_sentences = None
        if self.prior_config.emotion:
            emotion_sentences = load_pickle(self.prior_config.emojis_path)

        # Load infersent embeddings if necessary
        infersent_sentences = None
        if self.prior_config.infersent:
            print('Loading infersent sentence embeddings...')
            infersent_sentences = load_pickle(self.prior_config.infersent_path)
            embedding_size = infersent_sentences[0][0].shape[0]
            self.prior_config.infersent_output_size = embedding_size

        return get_loader(
            sentences=load_pickle(self.prior_config.sentences_path),
            conversation_length=load_pickle(
                self.prior_config.conversation_length_path),
            sentence_length=load_pickle(self.prior_config.sentence_length_path),
            vocab=self.vocab,
            batch_size=self.prior_config.batch_size,
            emojis=emotion_sentences,
            infersent=infersent_sentences)
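
The embedding_size lookup above implies a particular layout for the InferSent pickle: a list of conversations, each a list of per-sentence embedding vectors. The sketch below illustrates that assumed layout; the 4096 dimensionality is only illustrative (the standard InferSent size) and is not taken from the source.

    import numpy as np

    # Assumed layout: one list per conversation, one embedding vector per sentence.
    infersent_sentences = [
        [np.zeros(4096), np.zeros(4096)],  # conversation 0: two sentences
        [np.zeros(4096)],                  # conversation 1: one sentence
    ]

    # Mirrors the lookup above: dimensionality of the first sentence embedding.
    embedding_size = infersent_sentences[0][0].shape[0]
    print(embedding_size)  # -> 4096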
Example #3
    # If checkpoint is for an emotion model, load that pickle file
    emotion_sentences = None
    if config.emotion:
        emotion_sentences = load_pickle(config.emojis_path)

    # Load infersent embeddings if necessary
    infersent_sentences = None
    if config.infersent:
        print('Loading infersent sentence embeddings...')
        infersent_sentences = load_pickle(config.infersent_path)
        embedding_size = infersent_sentences[0][0].shape[0]
        config.infersent_output_size = embedding_size

    data_loader = get_loader(
        sentences=load_pickle(config.sentences_path),
        conversation_length=load_pickle(config.conversation_length_path),
        sentence_length=load_pickle(config.sentence_length_path),
        vocab=vocab,
        batch_size=config.batch_size,
        emojis=emotion_sentences,
        infersent=infersent_sentences)

    if config.model in VariationalModels:
        solver = VariationalSolver(config,
                                   None,
                                   data_loader,
                                   vocab=vocab,
                                   is_train=False)
    else:
        solver = Solver(config, None, data_loader, vocab=vocab, is_train=False)

    solver.build()
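
The branch above picks a solver class with a membership test on config.model against VariationalModels. The sketch below only illustrates that dispatch pattern; the collection contents and model names are assumptions, not taken from the source.

    # Assumed contents, purely to illustrate the membership test.
    VariationalModels = ['VHRED', 'VHCR']

    def pick_solver_name(model_name):
        """Mirror the dispatch above, returning the chosen solver name."""
        return 'VariationalSolver' if model_name in VariationalModels else 'Solver'

    print(pick_solver_name('VHRED'))  # -> VariationalSolver
    print(pick_solver_name('HRED'))   # -> Solver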
Example #4
    def __init__(self,
                 id,
                 name,
                 checkpoint_path,
                 max_conversation_length=5,
                 max_sentence_length=30,
                 is_test_bot=False,
                 rl=False,
                 safe_mode=True):
        """
        All chatbots should extend this class and be registered with the @registerbot decorator
        :param id: An id string, must be unique!
        :param name: A user-friendly string shown to the end user to identify the chatbot. Should be unique.
        :param checkpoint_path: Directory where the trained model checkpoint is saved.
        :param max_conversation_length: Maximum number of conversation turns to condition on.
        :param max_sentence_length: Maximum number of tokens per sentence.
        :param is_test_bot: If True, this bot can be chosen from the list of
            bots shown on the /dialogadmins screen, but it will never be
            randomly assigned to users landing on the home page.
        :param rl: If True, load a checkpoint fine-tuned with reinforcement
            learning (passed to get_config_from_dir as load_rl_ckpt).
        :param safe_mode: Whether to run the bot in safe mode.
        """
        self.id = id
        self.name = name
        self.checkpoint_path = checkpoint_path
        self.max_conversation_length = max_conversation_length
        self.max_sentence_length = max_sentence_length
        self.is_test_bot = is_test_bot
        self.safe_mode = safe_mode

        print("\n\nCreating chatbot", name)

        self.config = get_config_from_dir(checkpoint_path,
                                          mode='test',
                                          load_rl_ckpt=rl)
        self.config.beam_size = 5

        print('Loading Vocabulary...')
        self.vocab = Vocab()
        self.vocab.load(self.config.word2id_path, self.config.id2word_path)
        print(f'Vocabulary size: {self.vocab.vocab_size}')

        self.config.vocab_size = self.vocab.vocab_size

        # If checkpoint is for an emotion model, load that pickle file
        emotion_sentences = None
        if self.config.emotion:
            emotion_sentences = load_pickle(self.config.emojis_path)

        # Load infersent embeddings if necessary
        infersent_sentences = None
        if self.config.infersent:
            print('Loading infersent sentence embeddings...')
            infersent_sentences = load_pickle(self.config.infersent_path)
            embedding_size = infersent_sentences[0][0].shape[0]
            self.config.infersent_output_size = embedding_size

        self.data_loader = get_loader(
            sentences=load_pickle(self.config.sentences_path),
            conversation_length=load_pickle(
                self.config.conversation_length_path),
            sentence_length=load_pickle(self.config.sentence_length_path),
            vocab=self.vocab,
            batch_size=self.config.batch_size,
            emojis=emotion_sentences)

        if self.config.model in VariationalModels:
            self.solver = VariationalSolver(self.config,
                                            None,
                                            self.data_loader,
                                            vocab=self.vocab,
                                            is_train=False)
        elif self.config.model == 'Transformer':
            self.solver = ParlAISolver(self.config)
        else:
            self.solver = Solver(self.config,
                                 None,
                                 self.data_loader,
                                 vocab=self.vocab,
                                 is_train=False)

        self.solver.build()
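
Per the docstring, concrete chatbots extend this class and register themselves with the @registerbot decorator. The sketch below illustrates that pattern with stand-ins so it runs on its own; the registry, the decorator body, the subclass name, and the argument values are all assumptions, not the project's actual implementation.

    # Stand-in registry and decorator; the real project provides its own.
    BOT_REGISTRY = {}

    def registerbot(cls):
        """Record the bot class so it can be looked up later."""
        BOT_REGISTRY[cls.__name__] = cls
        return cls

    @registerbot
    class ExampleHREDBot:  # a real bot would extend the base class shown above
        def __init__(self):
            # A real subclass would delegate to the parent __init__ above, e.g.
            # super().__init__(id='example_hred', name='Example HRED',
            #                  checkpoint_path='checkpoints/hred/')  # hypothetical
            self.id = 'example_hred'

    print(sorted(BOT_REGISTRY))  # -> ['ExampleHREDBot']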