def __init__(self, session, corpus_dir, knbase_dir, result_dir, result_file):
    """Build the inference input pipeline and restore a trained model.

    Args:
        session: The TensorFlow session.
        corpus_dir: Name of the folder storing corpus files and vocab information.
        knbase_dir: Name of the folder storing data files for the knowledge base.
        result_dir: The folder containing the trained result files.
        result_file: The file name of the trained model.
    """
    self.session = session

    # Load corpus metadata, the knowledge base, and per-session chat state.
    print("# Prepare dataset placeholder and hyper parameters ...")
    data = TokenizedData(corpus_dir=corpus_dir, training=False)

    self.knowledge_base = KnowledgeBase()
    self.knowledge_base.load_knbase(knbase_dir)

    self.session_data = SessionData()
    self.hparams = data.hparams

    # A feedable string placeholder drives the inference input pipeline.
    self.src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    self.infer_batch = data.get_inference_batch(
        tf.data.Dataset.from_tensor_slices(self.src_placeholder))

    # Build the inference graph ...
    print("# Creating inference model ...")
    self.model = ModelCreator(training=False, tokenized_data=data,
                              batch_input=self.infer_batch)

    # ... then load the trained weights and initialize lookup tables.
    print("# Restoring model weights ...")
    self.model.saver.restore(session, os.path.join(result_dir, result_file))
    self.session.run(tf.tables_initializer())
def __init__(self, corpus_dir, hparams=None, knbase_dir=None, training=True,
             augment_factor=3, buffer_size=8192):
    """
    Args:
        corpus_dir: Name of the folder storing corpus files for training.
        hparams: The object containing the loaded hyper parameters. If None, it will be
            initialized here.
        knbase_dir: Name of the folder storing data files for the knowledge base. Used
            for inference only.
        training: Whether to use this object for training.
        augment_factor: Times the training data appears. If 1 or less, no augmentation.
        buffer_size: The buffer size used for mapping process during data processing.
    Raises:
        ValueError: If training is False and knbase_dir is not given.
    """
    # Fail fast: inference mode cannot work without a knowledge base folder.
    # (An explicit raise instead of assert, which is stripped under python -O.)
    if not training and knbase_dir is None:
        raise ValueError("knbase_dir is required when training is False.")

    if hparams is None:
        self.hparams = HParams(corpus_dir).hparams
    else:
        self.hparams = hparams

    self.src_max_len = self.hparams.src_max_len
    self.tgt_max_len = self.hparams.tgt_max_len

    self.training = training
    self.text_set = None
    self.id_set = None

    vocab_file = os.path.join(corpus_dir, VOCAB_FILE)
    self.vocab_size, _ = check_vocab(vocab_file)
    # Maps token strings to ids; unknown tokens fall back to unk_id.
    self.vocab_table = lookup_ops.index_table_from_file(
        vocab_file, default_value=self.hparams.unk_id)

    if training:
        self.case_table = prepare_case_table()
        self.reverse_vocab_table = None
        self._load_corpus(corpus_dir, augment_factor)
        self._convert_to_tokens(buffer_size)
        # The knowledge base is not used during training.
        self.upper_words = {}
        self.stories = {}
        self.jokes = []
    else:
        self.case_table = None
        # Inference also needs the id -> token direction to decode outputs.
        self.reverse_vocab_table = \
            lookup_ops.index_to_string_table_from_file(
                vocab_file, default_value=self.hparams.unk_token)

        knbs = KnowledgeBase()
        knbs.load_knbase(knbase_dir)
        self.upper_words = knbs.upper_words
        self.stories = knbs.stories
        self.jokes = knbs.jokes
def __init__(self, session, corpus_dir, knbase_dir, result_dir, aiml_dir, result_file):
    """
    Args:
        session: The TensorFlow session.
        corpus_dir: Name of the folder storing corpus files and vocab information.
        knbase_dir: Name of the folder storing data files for the knowledge base.
        result_dir: The folder containing the trained result files.
        aiml_dir: The folder containing the AIML rule files and the cached brain file.
        result_file: The file name of the trained model.
    """
    self.session = session

    # Prepare data and hyper parameters
    print("# Prepare dataset placeholder and hyper parameters ...")
    tokenized_data = TokenizedData(corpus_dir=corpus_dir, training=False)

    self.knowledge_base = KnowledgeBase()
    self.knowledge_base.load_knbase(knbase_dir)

    self.session_data = SessionData()
    self.hparams = tokenized_data.hparams

    # A feedable string placeholder drives the inference input pipeline.
    self.src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    src_dataset = tf.data.Dataset.from_tensor_slices(self.src_placeholder)
    self.infer_batch = tokenized_data.get_inference_batch(src_dataset)

    # Create retrieval (AIML rule-based) model
    self.kmodel = aiml.Kernel()
    brain_file_name = os.path.join(aiml_dir, BRAIN_FILE)

    # Load the cached brain when available; otherwise parse the AIML source
    # files and save the brain so subsequent startups are faster.
    if os.path.exists(brain_file_name):
        print("# Loading from brain file ... ")
        self.kmodel.loadBrain(brain_file_name)
    else:
        print("# Parsing aiml files ...")
        aimls_file_name = os.path.join(aiml_dir, AIMLS_FILE)
        self.kmodel.bootstrap(learnFiles=os.path.abspath(aimls_file_name),
                              commands="load aiml b")
        print("# Saving brain file: " + BRAIN_FILE)
        self.kmodel.saveBrain(brain_file_name)

    # Create generative model
    print("# Creating inference model ...")
    self.model = ModelCreator(training=False, tokenized_data=tokenized_data,
                              batch_input=self.infer_batch)
    # Restore model weights
    print("# Restoring model weights ...")
    self.model.saver.restore(session, os.path.join(result_dir, result_file))

    self.session.run(tf.tables_initializer())
class BotPredictor(object):
    """Runs chat inference: feeds a question through the trained model and
    post-processes the decoder output into a readable answer string.
    """

    def __init__(self, session, corpus_dir, knbase_dir, result_dir, result_file):
        """
        Args:
            session: The TensorFlow session.
            corpus_dir: Name of the folder storing corpus files and vocab information.
            knbase_dir: Name of the folder storing data files for the knowledge base.
            result_dir: The folder containing the trained result files.
            result_file: The file name of the trained model.
        """
        self.session = session

        # Prepare data and hyper parameters
        print("# Prepare dataset placeholder and hyper parameters ...")
        tokenized_data = TokenizedData(corpus_dir=corpus_dir, training=False)

        self.knowledge_base = KnowledgeBase()
        self.knowledge_base.load_knbase(knbase_dir)

        self.session_data = SessionData()
        self.hparams = tokenized_data.hparams

        # A feedable string placeholder drives the inference input pipeline.
        self.src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        src_dataset = tf.data.Dataset.from_tensor_slices(self.src_placeholder)
        self.infer_batch = tokenized_data.get_inference_batch(src_dataset)

        # Create model
        print("# Creating inference model ...")
        self.model = ModelCreator(training=False, tokenized_data=tokenized_data,
                                  batch_input=self.infer_batch)
        # Restore model weights
        print("# Restoring model weights ...")
        self.model.saver.restore(session, os.path.join(result_dir, result_file))

        self.session.run(tf.tables_initializer())

    def predict(self, session_id, question, html_format=False):
        """Return the bot's answer for a question within a chat session.

        Args:
            session_id: Key used to look up the per-user chat session state.
            question: The raw question string from the user.
            html_format: Passed through to output post-processing;
                presumably controls HTML vs. plain-text rendering — confirm
                against call_function.
        Returns:
            The final answer string.
        """
        chat_session = self.session_data.get_session(session_id)
        chat_session.before_prediction()  # Reset before each prediction

        # Canned reply for an empty/whitespace-only question.
        if question.strip() == '':
            answer = "Don't you want to say something to me?"
            chat_session.after_prediction(question, answer)
            return answer

        # NOTE(review): check_patterns_and_replace appears to substitute
        # recognized patterns in the question and return the extracted
        # parameters — confirm against its definition.
        pat_matched, new_sentence, para_list = check_patterns_and_replace(question)

        # At most two passes: if the pattern-substituted sentence does not
        # yield a function-valued answer, retry once with the original
        # question.
        for pre_time in range(2):
            # Lower-case and tokenize, then rejoin so the model sees
            # space-separated tokens.
            tokens = nltk.word_tokenize(new_sentence.lower())
            tmp_sentence = [' '.join(tokens[:]).strip()]

            # Re-initialize the input pipeline with this single sentence.
            self.session.run(self.infer_batch.initializer,
                             feed_dict={self.src_placeholder: tmp_sentence})

            outputs, _ = self.model.infer(self.session)

            if self.hparams.beam_width > 0:
                outputs = outputs[0]  # presumably the top beam — confirm

            # Truncate the token list at the first end-of-sentence marker.
            eos_token = self.hparams.eos_token.encode("utf-8")
            outputs = outputs.tolist()[0]

            if eos_token in outputs:
                outputs = outputs[:outputs.index(eos_token)]

            if pat_matched and pre_time == 0:
                out_sentence, if_func_val = self._get_final_output(
                    outputs, chat_session, para_list=para_list,
                    html_format=html_format)
                if if_func_val:
                    # The answer came from a _func_val_ call — accept it.
                    chat_session.after_prediction(question, out_sentence)
                    return out_sentence
                else:
                    # Fall back to the original, unsubstituted question.
                    new_sentence = question
            else:
                out_sentence, _ = self._get_final_output(
                    outputs, chat_session, html_format=html_format)
                chat_session.after_prediction(question, out_sentence)
                return out_sentence

    def _get_final_output(self, sentence, chat_session, para_list=None, html_format=False):
        """Convert decoder output tokens into a formatted answer string.

        Args:
            sentence: Iterable of byte-string tokens from the decoder.
            chat_session: The current chat session, forwarded to call_function.
            para_list: Parameters extracted by pattern matching, if any.
            html_format: Forwarded to call_function for output rendering.
        Returns:
            A tuple (answer_string, if_func_val) where if_func_val is True
            when any token was resolved through a _func_val_ function call.
        """
        sentence = b' '.join(sentence).decode('utf-8')
        if sentence == '':
            return "I don't know what to say.", False

        if_func_val = False
        last_word = None
        word_list = []
        for word in sentence.split(' '):
            word = word.strip()
            if not word:
                continue
            if word.startswith('_func_val_'):
                # Token encodes a function call; strip the 10-char
                # '_func_val_' prefix and dispatch.
                if_func_val = True
                word = call_function(word[10:], knowledge_base=self.knowledge_base,
                                     chat_session=chat_session, para_list=para_list,
                                     html_format=html_format)
                # A function may produce nothing; skip the token then.
                if word is None or word == '':
                    continue
            else:
                # Restore casing for words the knowledge base says should
                # be upper-cased (e.g. proper nouns).
                if word in self.knowledge_base.upper_words:
                    word = self.knowledge_base.upper_words[word]

                # Capitalize the first word of the reply and of each
                # sentence following terminal punctuation.
                if (last_word is None or last_word in ['.', '!', '?']) and \
                        not word[0].isupper():
                    word = word.capitalize()

            # Insert a space before the word unless it is a contraction
            # fragment, closing punctuation, or follows an opening bracket
            # or quote.
            if not word.startswith('\'') and word != 'n\'t' \
                and (word[0] not in string.punctuation or word in ['(', '[', '{', '``', '$']) \
                and last_word not in ['(', '[', '{', '``', '$']:
                word = ' ' + word

            word_list.append(word)
            last_word = word

        return ''.join(word_list).strip(), if_func_val