Ejemplo n.º 1
0
    def __init__(self,
                 corpus_dir,
                 hparams=None,
                 knbase_dir=None,
                 training=True,
                 augment_factor=3,
                 buffer_size=8192):
        """Load the vocabulary plus either the training corpus or the knowledge base.

        In training mode the corpus in ``corpus_dir`` is loaded and converted to
        tokens. In inference mode an id -> token lookup table is built and the
        knowledge base in ``knbase_dir`` is loaded.

        Args:
            corpus_dir: Name of the folder storing corpus files for training. The
                    vocabulary file ``VOCAB_FILE`` is read from this folder in both modes.
            hparams: The object containing the loaded hyper parameters. If None, they are
                    loaded here via ``HParams(corpus_dir)``.
            knbase_dir: Name of the folder storing data files for the knowledge base.
                    Required for inference (``training=False``); unused for training.
            training: Whether to use this object for training.
            augment_factor: Times the training data appears. If 1 or less, no augmentation.
            buffer_size: The buffer size used for mapping process during data processing.

        Raises:
            ValueError: If ``training`` is False and ``knbase_dir`` is None.
        """
        # Validate up front with a real exception: a bare ``assert`` is stripped
        # under ``python -O``, and failing here avoids building lookup tables first.
        if not training and knbase_dir is None:
            raise ValueError("knbase_dir is required when training is False.")

        if hparams is None:
            self.hparams = HParams(corpus_dir).hparams
        else:
            self.hparams = hparams

        self.src_max_len = self.hparams.src_max_len
        self.tgt_max_len = self.hparams.tgt_max_len

        self.training = training
        # NOTE(review): presumably filled in by _convert_to_tokens in training mode — confirm.
        self.text_set = None
        self.id_set = None

        vocab_file = os.path.join(corpus_dir, VOCAB_FILE)
        self.vocab_size, _ = check_vocab(vocab_file)
        # token string -> id; out-of-vocabulary tokens map to unk_id.
        self.vocab_table = lookup_ops.index_table_from_file(
            vocab_file, default_value=self.hparams.unk_id)

        if training:
            self.case_table = prepare_case_table()
            self.reverse_vocab_table = None
            self._load_corpus(corpus_dir, augment_factor)
            self._convert_to_tokens(buffer_size)

            # The knowledge base is only used for inference; keep empty placeholders.
            self.upper_words = {}
            self.stories = {}
            self.jokes = []
        else:
            self.case_table = None
            # id -> token string; unknown ids map to unk_token.
            self.reverse_vocab_table = \
                lookup_ops.index_to_string_table_from_file(vocab_file,
                                                           default_value=self.hparams.unk_token)
            knbs = KnowledgeBase()
            knbs.load_knbase(knbase_dir)
            self.upper_words = knbs.upper_words
            self.stories = knbs.stories
            self.jokes = knbs.jokes
Ejemplo n.º 2
0
class BotPredictor(object):
    """Answer chat questions using a restored inference model and the knowledge base."""

    # Tokens after which the next word is capitalized.
    _SENTENCE_END = ('.', '!', '?')
    # Punctuation tokens that take a space before them but not after them.
    _OPENING_TOKENS = ('(', '[', '{', '``', '$')
    # Prefix marking an output token that is replaced by the result of call_function.
    _FUNC_VAL_PREFIX = '_func_val_'

    def __init__(self, session, corpus_dir, knbase_dir, result_dir, result_file):
        """
        Args:
            session: The TensorFlow session.
            corpus_dir: Name of the folder storing corpus files and vocab information.
            knbase_dir: Name of the folder storing data files for the knowledge base.
            result_dir: The folder containing the trained result files.
            result_file: The file name of the trained model.
        """
        self.session = session

        # Prepare data and hyper parameters
        print("# Prepare dataset placeholder and hyper parameters ...")
        # NOTE(review): TokenizedData is built without knbase_dir here; confirm the
        # TokenizedData version in use does not require it when training=False.
        tokenized_data = TokenizedData(corpus_dir=corpus_dir, training=False)

        self.knowledge_base = KnowledgeBase()
        self.knowledge_base.load_knbase(knbase_dir)

        self.session_data = SessionData()

        self.hparams = tokenized_data.hparams
        # A 1-D batch of input sentences (strings), fed at prediction time.
        self.src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        src_dataset = tf.data.Dataset.from_tensor_slices(self.src_placeholder)
        self.infer_batch = tokenized_data.get_inference_batch(src_dataset)

        # Create model
        print("# Creating inference model ...")
        self.model = ModelCreator(training=False, tokenized_data=tokenized_data,
                                  batch_input=self.infer_batch)
        # Restore model weights
        print("# Restoring model weights ...")
        self.model.saver.restore(session, os.path.join(result_dir, result_file))

        # Initialize the graph's lookup tables (e.g. the vocabulary tables).
        self.session.run(tf.tables_initializer())

    def predict(self, session_id, question, html_format=False):
        """Generate an answer to ``question`` within the chat session ``session_id``.

        The question is first rewritten by ``check_patterns_and_replace``. If a pattern
        matched, the rewritten sentence is tried first and its answer is used only when
        it contained a ``_func_val_`` token; otherwise the original question is fed to
        the model on a second pass.

        Args:
            session_id: Identifier passed to ``SessionData.get_session``.
            question: The user's input text.
            html_format: Passed through to ``call_function`` when building the answer.

        Returns:
            The answer string. The session's ``after_prediction`` hook is called with
            the question and the answer before returning.
        """
        chat_session = self.session_data.get_session(session_id)
        chat_session.before_prediction()  # Reset before each prediction

        if question.strip() == '':
            answer = "Don't you want to say something to me?"
            chat_session.after_prediction(question, answer)
            return answer

        pat_matched, new_sentence, para_list = check_patterns_and_replace(question)
        # Loop-invariant: the model emits byte strings, so compare against encoded EOS.
        eos_token = self.hparams.eos_token.encode("utf-8")

        for pre_time in range(2):
            tokens = nltk.word_tokenize(new_sentence.lower())
            tmp_sentence = [' '.join(tokens).strip()]

            self.session.run(self.infer_batch.initializer,
                             feed_dict={self.src_placeholder: tmp_sentence})

            outputs, _ = self.model.infer(self.session)

            if self.hparams.beam_width > 0:
                outputs = outputs[0]

            outputs = outputs.tolist()[0]

            if eos_token in outputs:
                outputs = outputs[:outputs.index(eos_token)]

            if pat_matched and pre_time == 0:
                out_sentence, if_func_val = self._get_final_output(outputs, chat_session,
                                                                   para_list=para_list,
                                                                   html_format=html_format)
                if if_func_val:
                    chat_session.after_prediction(question, out_sentence)
                    return out_sentence
                else:
                    # The pattern-rewritten input did not yield a function answer;
                    # retry with the original question on the second pass.
                    new_sentence = question
            else:
                out_sentence, _ = self._get_final_output(outputs, chat_session,
                                                         html_format=html_format)
                chat_session.after_prediction(question, out_sentence)
                return out_sentence

    def _get_final_output(self, sentence, chat_session, para_list=None, html_format=False):
        """Turn decoded model tokens into a display sentence.

        Tokens starting with ``_func_val_`` are replaced by the result of
        ``call_function``. Other tokens may be re-cased via the knowledge base's
        ``upper_words`` and are capitalized at the start of a sentence. Spaces are
        re-inserted between tokens, except before closing punctuation and contractions
        and after opening tokens such as ``(`` or ``$``.

        Args:
            sentence: List of UTF-8 ``bytes`` tokens produced by the model.
            chat_session: Chat session passed through to ``call_function``.
            para_list: Parameters from pattern matching, passed to ``call_function``.
            html_format: Passed through to ``call_function``.

        Returns:
            A ``(text, if_func_val)`` tuple; ``if_func_val`` is True when at least one
            ``_func_val_`` token was present.
        """
        sentence = b' '.join(sentence).decode('utf-8')
        if sentence == '':
            return "I don't know what to say.", False

        if_func_val = False
        last_word = None
        word_list = []
        for word in sentence.split(' '):
            word = word.strip()
            if not word:
                continue

            if word.startswith(self._FUNC_VAL_PREFIX):
                if_func_val = True
                word = call_function(word[len(self._FUNC_VAL_PREFIX):],
                                     knowledge_base=self.knowledge_base,
                                     chat_session=chat_session, para_list=para_list,
                                     html_format=html_format)
                if word is None or word == '':
                    continue
            else:
                if word in self.knowledge_base.upper_words:
                    word = self.knowledge_base.upper_words[word]

                if (last_word is None or last_word in self._SENTENCE_END) \
                        and not word[0].isupper():
                    word = word.capitalize()

            needs_space = (not word.startswith('\'') and word != 'n\'t'
                           and (word[0] not in string.punctuation
                                or word in self._OPENING_TOKENS)
                           and last_word not in self._OPENING_TOKENS)
            word_list.append(' ' + word if needs_space else word)
            # Track the bare token, not the space-prefixed one; otherwise an opening
            # token would be stored as e.g. ' (' and the check above could never match.
            last_word = word

        return ''.join(word_list).strip(), if_func_val