コード例 #1
0
ファイル: hrl_tune.py プロジェクト: stjordanis/neural_chat
    def __init__(self, config, val_config):
        """Store the configs and run every set-up step of the object.

        The constructor loads the vocabulary, records its size on the
        config, loads the start sentences and evaluation data, builds the
        models (optionally restoring them when ``config.load_rl_ckpt`` is
        set) and finally sets up optimizers, summaries and logging.

        :param config: Main configuration object; ``vocab_size`` is written
            onto it here.
        :param val_config: Validation configuration, stored unchanged.
        :raises ValueError: If ``config.rl_batch_size`` equals
            ``config.beam_size``.
        """
        self.config, self.val_config = config, val_config

        # Vocabulary comes from the two mapping files named in the config.
        self.vocab = Vocab()
        self.vocab.load(self.config.word2id_path, self.config.id2word_path)
        self.config.vocab_size = self.vocab.vocab_size

        # Sentences used to initialize simulated conversations.
        self.start_sentences = self.load_sentences(self.config.dataset_dir)
        self.eval_data = self.get_data_loader(train=False)

        self.build_models()
        if self.config.load_rl_ckpt:
            self.load_models()

        self.set_up_optimizers()
        self.set_up_summary()
        self.set_up_logging()

        batch_equals_beam = (
            self.config.rl_batch_size == self.config.beam_size)
        if batch_equals_beam:
            raise ValueError('Decoding breaks if batch_size == beam_size')
コード例 #2
0
    with open(path, 'rb') as f:
        return pickle.load(f)


# Script entry point: parse the command line, build the config for the chosen
# checkpoint/mode, load the vocabulary, and load the optional emotion and
# infersent pickles that the config asks for.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NOTE(review): --checkpoint defaults to None and is handed straight to
    # get_config_from_dir below; confirm that helper tolerates None.
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--mode', type=str, default='test')  # or valid
    kwargs = parser.parse_args()

    # Build the config from the checkpoint location and mode given on the CLI.
    config = get_config_from_dir(kwargs.checkpoint, mode=kwargs.mode)
    print(config)

    print('Loading Vocabulary...')
    vocab = Vocab()
    vocab.load(config.word2id_path, config.id2word_path)
    print(f'Vocabulary size: {vocab.vocab_size}')

    # Record the vocabulary size on the config for later consumers.
    config.vocab_size = vocab.vocab_size

    # Emotion data is only loaded when the config enables it.
    emotion_sentences = None
    if config.emotion:
        emotion_sentences = load_pickle(config.emojis_path)

    # Load infersent embeddings if necessary
    infersent_sentences = None
    if config.infersent:
        print('Loading infersent sentence embeddings...')
        infersent_sentences = load_pickle(config.infersent_path)
        # The output size is read off the first stored embedding's leading
        # dimension rather than being configured by hand.
        embedding_size = infersent_sentences[0][0].shape[0]
        config.infersent_output_size = embedding_size
コード例 #3
0
class Chatbot(ABC):
    """Base class that wraps a trained solver behind ``handle_messages``."""

    # Canned replies used when there is no real conversation to condition on.
    _GREETINGS = [
        "hey , how are you ?", "hi , how 's it going ?",
        "hey , what 's up ?", "hi . how are you ?",
        "hello , how are you doing today ? ",
        "hello . how are things with you ?",
        "hey ! so, tell me about yourself .", "hi . nice to meet you ."
    ]

    # (banned substring, replacement) pairs applied anywhere in the response.
    # WARNING: the following table contains inappropriate language
    _SUBSTRING_CENSORS = (
        ("f*g", "<unknown>"),
        ("gays", "<unknown>"),
        ("c**t", "%@#$"),
        ("f**k", "%@#$"),
        ("shit", "%@#$"),
        ("dyke", "%@#$"),
        ("dick", "d***"),
        ("bitch", "%@#$"),
    )

    def __init__(self,
                 id,
                 name,
                 checkpoint_path,
                 max_conversation_length=5,
                 max_sentence_length=30,
                 is_test_bot=False,
                 rl=False,
                 safe_mode=True):
        """
        All chatbots should extend this class and be registered with the @registerbot decorator
        :param id: An id string, must be unique!
        :param name: A user-friendly string shown to the end user to identify the chatbot. Should be unique.
        :param checkpoint_path: Directory where the trained model checkpoint is saved.
        :param max_conversation_length: Maximum number of conversation turns to condition on.
        :param max_sentence_length: Maximum number of tokens per sentence.
        :param is_test_bot: If True, this bot can be chosen from the list of
            bots you see at /dialogadmins screen, but will never be randomly
            assigned to users landing on the home page.
        :param rl: Forwarded to ``get_config_from_dir`` as ``load_rl_ckpt``.
        :param safe_mode: If True, ``handle_messages`` censors a fixed list
            of words in every generated response.
        """
        self.id = id
        self.name = name
        self.checkpoint_path = checkpoint_path
        self.max_conversation_length = max_conversation_length
        self.max_sentence_length = max_sentence_length
        self.is_test_bot = is_test_bot
        self.safe_mode = safe_mode

        print("\n\nCreating chatbot", name)

        self.config = get_config_from_dir(checkpoint_path,
                                          mode='test',
                                          load_rl_ckpt=rl)
        # Beam size is fixed here regardless of what the checkpoint stored.
        self.config.beam_size = 5

        print('Loading Vocabulary...')
        self.vocab = Vocab()
        self.vocab.load(self.config.word2id_path, self.config.id2word_path)
        print(f'Vocabulary size: {self.vocab.vocab_size}')

        self.config.vocab_size = self.vocab.vocab_size

        # If checkpoint is for an emotion model, load that pickle file
        emotion_sentences = None
        if self.config.emotion:
            emotion_sentences = load_pickle(self.config.emojis_path)

        # Load infersent embeddings if necessary. They are only used here to
        # derive ``infersent_output_size``.
        # NOTE(review): unlike ``emotion_sentences`` these are not passed to
        # ``get_loader`` below -- confirm that is intentional.
        infersent_sentences = None
        if self.config.infersent:
            print('Loading infersent sentence embeddings...')
            infersent_sentences = load_pickle(self.config.infersent_path)
            embedding_size = infersent_sentences[0][0].shape[0]
            self.config.infersent_output_size = embedding_size

        self.data_loader = get_loader(
            sentences=load_pickle(self.config.sentences_path),
            conversation_length=load_pickle(
                self.config.conversation_length_path),
            sentence_length=load_pickle(self.config.sentence_length_path),
            vocab=self.vocab,
            batch_size=self.config.batch_size,
            emojis=emotion_sentences)

        # Pick the solver implementation that matches the configured model.
        if self.config.model in VariationalModels:
            self.solver = VariationalSolver(self.config,
                                            None,
                                            self.data_loader,
                                            vocab=self.vocab,
                                            is_train=False)
        elif self.config.model == 'Transformer':
            self.solver = ParlAISolver(self.config)
        else:
            self.solver = Solver(self.config,
                                 None,
                                 self.data_loader,
                                 vocab=self.vocab,
                                 is_train=False)

        self.solver.build()

    @staticmethod
    def _censor(response):
        """Return ``response`` with the safe-mode word list replaced.

        Most entries are replaced wherever they occur. "hell" is replaced
        only as a whole word, because a plain substring replacement turns
        ordinary words such as "hello" into "hecko".
        """
        import re

        for banned, substitute in Chatbot._SUBSTRING_CENSORS:
            response = response.replace(banned, substitute)
        return re.sub(r"\bhell\b", "heck", response)

    def handle_messages(self, messages):
        """
        Takes a list of messages and returns a response string.

        An empty conversation, or a single very short greeting, is answered
        with a randomly chosen canned greeting; everything else is sent to
        the solver. When ``safe_mode`` is on the solver's response is
        censored before being returned.

        :param messages: list of strings
        :return: string
        """
        # Check for no response
        if len(messages) == 0:
            # Respond with canned greeting response
            return np.random.choice(self._GREETINGS)

        # Check for overly short intro messages
        if len(messages) < 2 and len(messages[0]) <= 6:  # 6 for "hello."
            first_m = messages[0].lower()
            if any(word in first_m for word in ('hi', 'hey', 'hello')):
                # Respond with canned greeting response
                return np.random.choice(self._GREETINGS)

        response = self.solver.generate_response_to_input(
            messages,
            max_conversation_length=self.max_conversation_length,
            emojize=True,
            debug=False)

        # Manually remove inappropriate language from response.
        if self.safe_mode:
            response = self._censor(response)

        return response
コード例 #4
0
ファイル: dbcq.py プロジェクト: stjordanis/neural_chat
class DBCQ:
    """Generate responses by combining a pretrained prior with a Q network.

    At every decoding step the prior network proposes ``rl_config.bcq_n``
    sampled candidate words per beam, and the candidate whose Q-network logit
    is highest is written into the beam.
    """

    def __init__(self, prior_config, rl_config, beam_size=5):
        """Load the vocabulary and evaluation data, then build both networks.

        :param prior_config: Config of the pretrained prior model; also the
            source of the vocabulary and dataset paths.
        :param rl_config: Config of the Q network; its ``checkpoint`` is the
            file the Q-network weights are loaded from.
        :param beam_size: Number of beams, written onto ``rl_config``.
        """
        self.prior_config = prior_config
        self.rl_config = rl_config
        self.rl_config.beam_size = beam_size

        print('Loading Vocabulary...')
        self.vocab = Vocab()
        self.vocab.load(prior_config.word2id_path, prior_config.id2word_path)
        self.prior_config.vocab_size = self.vocab.vocab_size
        self.rl_config.vocab_size = self.vocab.vocab_size
        print(f'Vocabulary size: {self.vocab.vocab_size}')

        self.eval_data = self.get_data_loader()
        self.build_models()

    def _build_solver(self, config):
        """Construct and build the evaluation-mode solver matching ``config.model``."""
        if config.model in VariationalModels:
            solver_cls = VariationalSolver
        else:
            solver_cls = Solver
        solver = solver_cls(
            config, None, self.eval_data, vocab=self.vocab, is_train=False)
        solver.build()
        return solver

    def build_models(self):
        """Build the Q network and the prior network and freeze both."""
        # The Q network is built from a copy whose checkpoint is cleared;
        # its weights are loaded explicitly by load_q_network() instead.
        rl_config = copy.deepcopy(self.rl_config)
        rl_config.checkpoint = None

        print('Building Q network')
        self.q_net = self._build_solver(rl_config)
        self.load_q_network()

        print('Building prior network')
        self.pretrained_prior = self._build_solver(self.prior_config)

        # Freeze the weights so they stay constant
        for solver in (self.pretrained_prior, self.q_net):
            solver.model.eval()
            for params in solver.model.parameters():
                params.requires_grad = False

    def load_q_network(self):
        """Load parameters from RL checkpoint"""
        print(f'Loading parameters for Q net from {self.rl_config.checkpoint}')
        q_ckpt = torch.load(self.rl_config.checkpoint)
        q_ckpt = convert_old_checkpoint_format(q_ckpt)
        self.q_net.model.load_state_dict(q_ckpt)

        # Ensure weights are initialized to be on the GPU when necessary
        if torch.cuda.is_available():
            print('Converting checkpointed model to cuda tensors')
            self.q_net.model.cuda()

    def get_data_loader(self):
        """Build the data loader from the pickles named in ``prior_config``.

        :return: Whatever ``get_loader`` returns for the prior's dataset.
        """
        # If checkpoint is for an emotion model, load that pickle file
        emotion_sentences = None
        if self.prior_config.emotion:
            emotion_sentences = load_pickle(self.prior_config.emojis_path)

        # Load infersent embeddings if necessary
        infersent_sentences = None
        if self.prior_config.infersent:
            print('Loading infersent sentence embeddings...')
            infersent_sentences = load_pickle(self.prior_config.infersent_path)
            embedding_size = infersent_sentences[0][0].shape[0]
            self.prior_config.infersent_output_size = embedding_size

        return get_loader(
            sentences=load_pickle(self.prior_config.sentences_path),
            conversation_length=load_pickle(
                self.prior_config.conversation_length_path),
            sentence_length=load_pickle(self.prior_config.sentence_length_path),
            vocab=self.vocab,
            batch_size=self.prior_config.batch_size,
            emojis=emotion_sentences,
            infersent=infersent_sentences)

    def interact(self, max_conversation_length=5, sample_by='priority',
                 debug=True):
        """Run a console chat loop until the user types 'quit' or 'exit'.

        :param max_conversation_length: Maximum number of context sentences
            passed to the generator.
        :param sample_by: Forwarded to ``generate_response_to_input``.
        :param debug: Forwarded to ``generate_response_to_input``.
        """
        model_name = self.prior_config.model
        context_sentences = []

        print("Time to start a conversation with the chatbot! It's name is",
              model_name)
        username = input("What is your name? ")

        print("Let's start chatting. You can type 'quit' at any time to quit.")
        utterance = input("Input: ")
        print("\033[1A\033[K")  # Erases last line of output

        while utterance.lower() not in ('quit', 'exit'):
            # A single input line may hold several '/'-separated sentences.
            sentences = utterance.split('/')

            # Code and decode user input to show how it is transformed for model
            coded, lens = self.pretrained_prior.process_user_input(sentences)
            # Echo only non-empty sentences. The previous filter compared the
            # sentence *index* with its length, which dropped valid sentences.
            decoded = [self.vocab.decode(sent)
                       for sent, length in zip(coded, lens) if length > 0]
            print(username + ':', '. '.join(decoded))

            # Append to conversation
            context_sentences.extend(sentences)

            gen_response = self.generate_response_to_input(
                context_sentences, max_conversation_length, sample_by=sample_by,
                debug=debug)

            # Append generated sentences to conversation
            context_sentences.append(gen_response)

            # Print and get next user input
            print("\n" + model_name + ": " + gen_response)
            utterance = input("Input: ")
            print("\033[1A\033[K")

    def process_raw_text_into_input(self, raw_text_sentences,
                                    max_conversation_length=5, debug=False,):
        """Encode raw sentences into the tensors expected by the models.

        :param raw_text_sentences: List of raw text sentences forming the
            conversation so far.
        :param max_conversation_length: Only this many most recent non-empty
            sentences are kept.
        :param debug: If True, print the decoded conversation history.
        :return: Tuple ``(input_sentences, input_sentence_length,
            input_conversation_length)``.
        """
        sentences, lens = self.pretrained_prior.process_user_input(
            raw_text_sentences, self.rl_config.max_sentence_length)

        # Remove any sentences of length 0
        sentences = [sent for sent, length in zip(sentences, lens)
                     if length > 0]
        lens = [length for length in lens if length > 0]

        # Trim conversation to max length
        sentences = sentences[-max_conversation_length:]
        lens = lens[-max_conversation_length:]
        convo_length = len(sentences)

        # Convert to torch variables
        input_sentences = to_var(torch.LongTensor(sentences))
        input_sentence_length = to_var(torch.LongTensor(lens))
        input_conversation_length = to_var(torch.LongTensor([convo_length]))

        if debug:
            print('\n**Conversation history:**')
            for sent in sentences:
                print(self.vocab.decode(list(sent)))

        return (input_sentences, input_sentence_length,
                input_conversation_length)

    def duplicate_context_for_beams(self, sentences, sent_lens, conv_lens,
                                    beams):
        """Assemble the model inputs for the current contents of the beams.

        :param sentences: Encoded context sentences.
        :param sent_lens: Length of each context sentence.
        :param conv_lens: Conversation length tensor.
        :param beams: One row per beam holding the response decoded so far.
        :return: Tuple ``(sentences, sent_lens, conv_lens, targets)``.
        """
        num_beams = len(beams)
        conv_lens = conv_lens.repeat(num_beams)

        # [beam_size * sentences, sentence_len]
        if len(sentences) > 1:
            # For each beam: the context minus its first sentence, followed
            # by that beam's row.
            targets = torch.cat(
                [torch.cat([sentences[1:, :], beam.unsqueeze(0)], 0)
                 for beam in beams], 0)
        else:
            targets = beams

        # HRED
        if self.rl_config.model not in VariationalModels:
            sent_lens = sent_lens.repeat(num_beams)
            return sentences, sent_lens, conv_lens, targets

        # VHRED, VHCR: the beam row is appended to a full copy of the context
        # and counted with the maximum sentence length.
        new_sentences = torch.cat(
            [torch.cat([sentences, beam.unsqueeze(0)], 0)
             for beam in beams], 0)
        new_len = to_var(torch.LongTensor([self.rl_config.max_sentence_length]))
        sent_lens = torch.cat([sent_lens, new_len], 0).repeat(num_beams)
        return new_sentences, sent_lens, conv_lens, targets

    def generate_response_to_input(self, raw_text_sentences,
                                   max_conversation_length=5,
                                   sample_by='priority', emojize=True,
                                   debug=True):
        """Decode a response to the conversation word by word.

        :param raw_text_sentences: List of raw text sentences forming the
            conversation so far.
        :param max_conversation_length: Maximum number of context sentences
            to condition on.
        :param sample_by: Selection strategy forwarded to the prior solver's
            ``select_best_generated_response``.
        :param emojize: If True, prefix the response with the emojis inferred
            for the last input sentence.
        :param debug: If True, print the conversation history and every beam.
        :return: The detokenized response string.
        """
        with torch.no_grad():
            (input_sentences, input_sent_lens,
             input_conv_lens) = self.process_raw_text_into_input(
                raw_text_sentences, debug=debug,
                max_conversation_length=max_conversation_length)

            # Initialize a tensor for beams (every position starts as id 1).
            beams = to_var(torch.LongTensor(
                np.ones((self.rl_config.beam_size,
                         self.rl_config.max_sentence_length))))

            # Create a batch with the context duplicated for each beam
            (sentences, sent_lens,
             conv_lens, targets) = self.duplicate_context_for_beams(
                input_sentences, input_sent_lens, input_conv_lens, beams)

            # Continuously feed beam sentences into networks to sample the next
            # best word, add that to the beam, and continue
            for i in range(self.rl_config.max_sentence_length):
                # Run both models to obtain logits
                prior_output = self.pretrained_prior.model(
                    sentences, sent_lens, conv_lens, targets, rl_mode=True)
                all_prior_logits = prior_output[0]

                q_output = self.q_net.model(
                    sentences, sent_lens, conv_lens, targets, rl_mode=True)
                all_q_logits = q_output[0]

                # Select only those logits for next word
                q_logits = all_q_logits[:, i, :].squeeze()
                prior_logits = all_prior_logits[:, i, :].squeeze()

                # Get prior distribution for next word in each beam
                prior_dists = torch.nn.functional.softmax(prior_logits, 1)

                for b in range(self.rl_config.beam_size):
                    # Sample from the prior bcq_n times for each beam.
                    # sample((n,)) replaces the deprecated sample_n(n).
                    dist = torch.distributions.Categorical(prior_dists[b, :])
                    sampled_idxs = dist.sample((self.rl_config.bcq_n,))

                    # Select sample with highest q value
                    q_vals = torch.stack(
                        [q_logits[b, idx] for idx in sampled_idxs])
                    _, best_word_i = torch.max(q_vals, 0)
                    best_word = sampled_idxs[best_word_i]

                    # Update beams
                    beams[b, i] = best_word

                # Rebuild the model inputs so the next step sees the new word.
                (sentences, sent_lens,
                 conv_lens, targets) = self.duplicate_context_for_beams(
                    input_sentences, input_sent_lens, input_conv_lens, beams)

            generated_sentences = beams.cpu().numpy()

        if debug:
            print('\n**All generated responses:**')
            for gen in generated_sentences:
                print(detokenize(self.vocab.decode(list(gen))))

        gen_response = self.pretrained_prior.select_best_generated_response(
            generated_sentences, sample_by, beam_size=self.rl_config.beam_size)

        decoded_response = self.vocab.decode(list(gen_response))
        decoded_response = detokenize(decoded_response)

        if emojize:
            inferred_emojis = self.pretrained_prior.botmoji.emojize_text(
                raw_text_sentences[-1], 5, 0.07)
            decoded_response = inferred_emojis + " " + decoded_response

        return decoded_response