def interactive(self):
    from prompt_toolkit import prompt
    from prompt_toolkit.history import FileHistory
    from topic_model.lda import TopicInferer, DEFAULT_SEPARATOR
    from util.nlp import NLPToolkit

    nlp_toolkit = NLPToolkit()
    self._pre_model_creation()
    topic_inferer = TopicInferer(self.config.lda_model_dir)

    infer_model = taware_helper.create_infer_model(taware_model.TopicAwareSeq2SeqModel, self.config)
    config_proto = models.model_helper.get_config_proto(self.config.log_device)

    with tf.Session(graph=infer_model.graph, config=config_proto) as sess:
        ckpt = tf.train.latest_checkpoint(self.config.model_dir)
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")

        log.print_out("# Start decoding")

        sentence = prompt(">>> ",
                          history=FileHistory(os.path.join(self.config.model_dir, ".chat_history"))).strip()
        while sentence:
            utterance = ' '.join(nlp_toolkit.tokenize(sentence)).lower()
            topic_words = topic_inferer.from_collection(
                [utterance],
                dialogue_as_doc=True,
                words_per_topic=self.config.topic_words_per_utterance)
            iterator_feed_dict = {
                infer_model.src_placeholder: [utterance + DEFAULT_SEPARATOR + " ".join(topic_words)],
                infer_model.batch_size_placeholder: 1,
            }
            sess.run(infer_model.iterator.initializer, feed_dict=iterator_feed_dict)
            output, _ = loaded_infer_model.decode(sess)
            if self.config.beam_width > 0:
                # get the top translation.
                output = output[0]
            resp = ncm_utils.get_translation(output, sent_id=0)
            log.print_out(resp + b"\n")
            sentence = prompt(">>> ",
                              history=FileHistory(os.path.join(self.config.model_dir, ".chat_history"))).strip()
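# Illustrative sketch (names and values hypothetical, not part of the model code):
# how the topic-augmented source line above is assembled before being fed to
# infer_model.src_placeholder.
def build_topic_aware_src(utterance, topic_words, separator):
    # e.g. build_topic_aware_src("how is the weather", ["rain", "forecast"], " <sep> ")
    #   -> "how is the weather <sep> rain forecast"
    return utterance + separator + " ".join(topic_words)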
def interactive(self):
    from prompt_toolkit import prompt
    from prompt_toolkit.history import FileHistory
    from util.nlp import NLPToolkit

    nlp_toolkit = NLPToolkit()

    infer_model = vanilla_helper.create_infer_model(self.config)
    config_proto = model_helper.get_config_proto(self.config.log_device)

    with tf.Session(graph=infer_model.graph, config=config_proto) as sess:
        ckpt = tf.train.latest_checkpoint(self.config.model_dir)
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")

        log.print_out("# Start decoding")

        sentence = prompt(">>> ",
                          history=FileHistory(os.path.join(self.config.model_dir, ".chat_history"))).strip()
        while sentence:
            utterance = ' '.join(nlp_toolkit.tokenize(sentence)).lower()
            iterator_feed_dict = {
                infer_model.src_placeholder: [utterance],
                infer_model.batch_size_placeholder: 1,
            }
            sess.run(infer_model.iterator.initializer, feed_dict=iterator_feed_dict)
            output, _ = loaded_infer_model.decode(sess)
            if self.config.beam_width > 0:
                # get the top translation.
                output = output[0]
            resp = ncm_utils.get_translation(output, sent_id=0)
            log.print_out(resp + b"\n")
            sentence = prompt(">>> ",
                              history=FileHistory(os.path.join(self.config.model_dir, ".chat_history"))).strip()
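# Minimal sketch (assuming ncm_utils.get_translation returns a UTF-8 byte string,
# as suggested by the b"\n" concatenation above): converting the decoded response
# into a plain text string for display or logging.
def response_to_text(resp_bytes):
    return resp_bytes.decode('utf-8', errors='replace').strip()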
def chat(checkpoint, chat_logs_output_file, hparams, scope=None):
    # Containing the graph, model, source placeholder, batch_size placeholder and iterator
    if hparams.architecture == "simple":
        model_creator = SimpleModel
        get_infer_iterator = iterator_utils.get_infer_iterator

        def update_dialogue(so_far, utterance, response):
            # Dialogue so far does not matter
            return utterance
    elif hparams.architecture == "hier":
        model_creator = HierarchicalModel
        # Parse some of the arguments now
        get_infer_iterator = lambda dataset, vocab_table, batch_size, src_reverse, eos, src_max_len: \
            end2end_iterator_utils.get_infer_iterator(dataset, vocab_table, batch_size, src_reverse, eos,
                                                      src_max_len=src_max_len, eou=hparams.eou,
                                                      dialogue_max_len=hparams.dialogue_max_len)

        def update_dialogue(so_far, utterance, response):
            # Dialogue so far is considered
            if response == "":
                return utterance
            return so_far + " " + hparams.eou + " " + response + " " + hparams.eou + " " + utterance
    else:
        raise ValueError("Unknown architecture", hparams.architecture)

    infer_model = create_infer_model(model_creator, get_infer_iterator, hparams=hparams, verbose=False, scope=scope)

    with tf.Session(graph=infer_model.graph, config=utils.get_config_proto()) as sess:
        # Load the model from the checkpoint
        # ToDo: adapt for architectures
        loaded_infer_model = model_helper.load_model(model=infer_model.model, ckpt=checkpoint,
                                                     session=sess, name="infer", verbose=False)
        utils.print_out("I'm Batman!")
        dialogue_so_far = ""
        response = ""
        # Gather all the user's utterances
        user_utterances = []
        # Leave it in chat mode until interrupted
        while True:
            # Read utterance from user.
            utterance, original_utterance = get_user_input(hparams)
            # Check if we should finish
            if utterance == 'end()':
                overall_score(user_utterances, hparams)
                break
            user_utterances.append(original_utterance)
            # Preprocess it into the familiar format for the machine
            utterance = preprocessing_utils.tokenize_line(utterance,
                                                          number_token=hparams.number_token,
                                                          name_token=hparams.name_token,
                                                          gpe_token=hparams.gpe_token)
            dialogue_so_far = update_dialogue(dialogue_so_far, utterance, response)
            # Transform it into a batch of size 1
            batched_dialogue = [dialogue_so_far]
            # Initialize the iterator
            sess.run(infer_model.iterator.initializer,
                     feed_dict={
                         infer_model.src_placeholder: batched_dialogue,
                         infer_model.batch_size_placeholder: 1
                     })
            response = chatbot_utils.decode_utterance(model=loaded_infer_model, sess=sess,
                                                      output_file=chat_logs_output_file,
                                                      bpe_delimiter=hparams.bpe_delimiter,
                                                      beam_width=hparams.beam_width,
                                                      utterance=utterance,
                                                      top_responses=hparams.top_responses,
                                                      eos=hparams.eos,
                                                      number_token=hparams.number_token,
                                                      name_token=hparams.name_token)
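# Worked example of the "hier" update_dialogue above (eou value hypothetical,
# shown here as "</u>"): the history grows by appending the previous machine
# response and the new user utterance, each followed by an end-of-utterance token.
#   update_dialogue("", "hi there", "")
#       -> "hi there"
#   update_dialogue("hi there", "how are you", "i am fine")
#       -> "hi there </u> i am fine </u> how are you"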
def inference(checkpoint, inference_input_file, inference_output_file, hparams, scope=None):
    """Create the responses."""
    # TODO: can add multiple number of workers if so desired. This function is the helper
    # _single_worker_inference in the original model

    # Read the data
    infer_data = load_data(inference_input_file, hparams)

    if hparams.architecture == "simple":
        model_creator = SimpleModel
        get_infer_iterator = iterator_utils.get_infer_iterator
    elif hparams.architecture == "hier":
        model_creator = HierarchicalModel
        # Parse some of the arguments now
        get_infer_iterator = lambda dataset, vocab_table, batch_size, src_reverse, eos, src_max_len: \
            end2end_iterator_utils.get_infer_iterator(dataset, vocab_table, batch_size, src_reverse, eos,
                                                      src_max_len=src_max_len, eou=hparams.eou,
                                                      dialogue_max_len=hparams.dialogue_max_len)
    else:
        raise ValueError("Unknown architecture", hparams.architecture)

    # Containing the graph, model, source placeholder, batch_size placeholder and iterator
    infer_model = create_infer_model(model_creator, get_infer_iterator, hparams, scope)
    # ToDo: adapt for architectures
    with tf.Session(graph=infer_model.graph, config=utils.get_config_proto()) as sess:
        # Load the model from the checkpoint
        loaded_infer_model = model_helper.load_model(model=infer_model.model, ckpt=checkpoint,
                                                     session=sess, name="infer")
        # Initialize the iterator
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder: hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Starting Decoding")
        # Decode only a specific set of indices
        if hparams.inference_indices:
            _decode_inference_indices(model=loaded_infer_model,
                                      sess=sess,
                                      output_infer_file=inference_output_file,
                                      output_infer_summary_prefix=inference_output_file,
                                      inference_indices=hparams.inference_indices,
                                      eos=hparams.eos,
                                      bpe_delimiter=hparams.bpe_delimiter,
                                      number_token=hparams.number_token,
                                      name_token=hparams.name_token)
        else:
            chatbot_utils.decode_and_evaluate(name="infer",
                                              model=loaded_infer_model,
                                              sess=sess,
                                              output_file=inference_output_file,
                                              reference_file=None,
                                              metrics=hparams.metrics,
                                              bpe_delimiter=hparams.bpe_delimiter,
                                              beam_width=hparams.beam_width,
                                              eos=hparams.eos,
                                              number_token=hparams.number_token,
                                              name_token=hparams.name_token)
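# Hypothetical sketch of the load_data helper used above (the real implementation
# lives elsewhere in the repo): it is assumed here to return one source dialogue
# per line of the inference input file.
def load_data_sketch(inference_input_file):
    with open(inference_input_file, encoding='utf-8') as f:
        return [line.strip() for line in f]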
def test(self):
    start_test_time = time.time()

    assert self.config.n_responses >= 1

    if self.config.beam_width > 0:
        assert self.config.n_responses <= self.config.beam_width
    else:
        assert self.config.n_responses == 1

    self._pre_model_creation()

    infer_model = self._get_model_helper().create_infer_model(self.config)
    latest_ckpt = tf.train.latest_checkpoint(self.config.get_infer_model_dir())

    with tf.Session(
            config=model_helper.get_config_proto(self.config.log_device),
            graph=infer_model.graph) as infer_sess:
        loaded_infer_model = model_helper.load_model(
            infer_model.model, latest_ckpt, infer_sess, "infer")

        log.print_out("# Start decoding")
        log.print_out(" beam width: {}".format(self.config.beam_width))
        log.print_out(" length penalty: {}".format(self.config.length_penalty_weight))
        log.print_out(" sampling temperature: {}".format(self.config.sampling_temperature))
        log.print_out(" num responses per test instance: {}".format(self.config.n_responses))

        feed_dict = {
            infer_model.src_placeholder: self._load_data(self.config.test_data),
            infer_model.batch_size_placeholder: self.config.infer_batch_size,
        }

        if self.config.sampling_temperature > 0:
            label = "%s_t%.1f" % (
                fs.file_name(self.config.test_data), self.config.sampling_temperature)
        else:
            label = "%s_bw%d_lp%.1f" % (
                fs.file_name(self.config.test_data), self.config.beam_width, self.config.length_penalty_weight)

        self._decode_and_evaluate(loaded_infer_model, infer_sess, feed_dict,
                                  label=label,
                                  num_responses_per_input=self.config.n_responses)
    log.print_time("# Decoding done", start_test_time)

    eval_model = self._get_model_helper().create_eval_model(self.config)
    with tf.Session(
            config=model_helper.get_config_proto(self.config.log_device),
            graph=eval_model.graph) as eval_sess:
        loaded_eval_model = model_helper.load_model(
            eval_model.model, latest_ckpt, eval_sess, "eval")

        log.print_out("# Compute Perplexity")

        dev_eval_iterator_feed_dict = {
            eval_model.eval_file_placeholder: self.config.test_data
        }
        eval_sess.run(eval_model.iterator.initializer, feed_dict=dev_eval_iterator_feed_dict)
        model_helper.compute_perplexity(loaded_eval_model, eval_sess, "test")

    log.print_time("# Test finished", start_test_time)
def interactive(self, scope=None):
    import platform

    from prompt_toolkit import prompt
    from prompt_toolkit.history import FileHistory
    from topic_model.lda import TopicInferer, DEFAULT_SEPARATOR
    from util.nlp import NLPToolkit

    nlp_toolkit = NLPToolkit()
    __os = platform.system()

    self.config.infer_batch_size = 1
    self._pre_model_creation()

    if self.config.lda_model_dir is not None:
        topic_inferer = TopicInferer(self.config.lda_model_dir)
    else:
        topic_inferer = None

    infer_model = self._get_model_helper().create_infer_model(self.config, scope)

    with tf.Session(
            config=model_helper.get_config_proto(self.config.log_device),
            graph=infer_model.graph) as sess:
        latest_ckpt = tf.train.latest_checkpoint(self.config.model_dir)
        loaded_infer_model = model_helper.load_model(
            infer_model.model, latest_ckpt, sess, "infer")

        log.print_out("# Start decoding")

        if __os == 'Windows':
            sentence = input("> ").strip()
        else:
            sentence = prompt(">>> ",
                              history=FileHistory(os.path.join(self.config.model_dir, ".chat_history"))).strip()

        conversation = [vocab.EOS] * (self.config.num_turns - 1)

        while sentence:
            current_utterance = ' '.join(nlp_toolkit.tokenize(sentence)).lower()
            conversation.append(current_utterance)
            conversation.pop(0)

            feedable_context = "\t".join(conversation)
            if topic_inferer is None:
                iterator_feed_dict = {
                    infer_model.src_placeholder: [feedable_context],
                    infer_model.batch_size_placeholder: 1,
                }
            else:
                topic_words = topic_inferer.from_collection(
                    [feedable_context],
                    dialogue_as_doc=True,
                    words_per_topic=self.config.topic_words_per_utterance)
                iterator_feed_dict = {
                    infer_model.src_placeholder: [feedable_context + DEFAULT_SEPARATOR + " ".join(topic_words)],
                    infer_model.batch_size_placeholder: 1,
                }

            sess.run(infer_model.iterator.initializer, feed_dict=iterator_feed_dict)
            output, infer_summary = loaded_infer_model.decode(sess)
            if self.config.beam_width > 0 and self._consider_beam():
                # get the top translation.
                output = output[0]
            resp = ncm_utils.get_translation(output, sent_id=0)
            log.print_out(resp + b"\n")

            if __os == 'Windows':
                sentence = input("> ").strip()
            else:
                sentence = prompt(">>> ",
                                  history=FileHistory(os.path.join(self.config.model_dir, ".chat_history"))).strip()

    print("Bye!!!")
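# Illustrative alternative (not part of the model code) for the rolling context
# maintained above with append/pop: a fixed-length deque drops the oldest turn
# automatically once num_turns - 1 entries are held.
from collections import deque

def make_context_window(num_turns, pad_token):
    return deque([pad_token] * (num_turns - 1), maxlen=num_turns - 1)

# Usage sketch: context = make_context_window(config.num_turns, vocab.EOS);
# context.append(current_utterance) followed by "\t".join(context) would replace
# the append/pop pair in the loop above.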
def test(self):
    start_test_time = time.time()

    assert self.config.n_responses >= 1

    if self.config.beam_width > 0:
        assert self.config.n_responses <= self.config.beam_width
    else:
        assert self.config.n_responses == 1

    self._pre_model_creation()

    infer_model = taware_helper.create_infer_model(taware_model.TopicAwareSeq2SeqModel, self.config)
    config_proto = models.model_helper.get_config_proto(self.config.log_device)
    ckpt = tf.train.latest_checkpoint(self.config.get_infer_model_dir())

    with tf.Session(graph=infer_model.graph, config=config_proto) as infer_sess:
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, infer_sess, "infer")

        log.print_out("# Start decoding")
        log.print_out(" beam width: {}".format(self.config.beam_width))
        log.print_out(" length penalty: {}".format(self.config.length_penalty_weight))
        log.print_out(" sampling temperature: {}".format(self.config.sampling_temperature))
        log.print_out(" num responses per test instance: {}".format(self.config.n_responses))

        feed_dict = {
            infer_model.src_placeholder: self._load_data(self.config.test_data),
            infer_model.batch_size_placeholder: self.config.infer_batch_size,
        }
        infer_sess.run(infer_model.iterator.initializer, feed_dict=feed_dict)

        if self.config.sampling_temperature > 0:
            label = "%s_t%.1f" % (
                fs.file_name(self.config.test_data), self.config.sampling_temperature)
        else:
            label = "%s_bw%d_lp%.1f" % (
                fs.file_name(self.config.test_data), self.config.beam_width, self.config.length_penalty_weight)

        out_file = os.path.join(self.config.model_dir, "output_{}".format(label))

        eval_metric.decode_and_evaluate(
            "test",
            loaded_infer_model,
            infer_sess,
            out_file,
            ref_file=None,
            metrics=self.config.metrics,
            beam_width=self.config.beam_width,
            num_translations_per_input=self.config.n_responses)
    log.print_time("# Decoding done", start_test_time)

    eval_model = taware_helper.create_eval_model(taware_model.TopicAwareSeq2SeqModel, self.config)
    with tf.Session(
            config=models.model_helper.get_config_proto(self.config.log_device),
            graph=eval_model.graph) as eval_sess:
        loaded_eval_model = model_helper.load_model(
            eval_model.model, ckpt, eval_sess, "eval")

        log.print_out("# Compute Perplexity")

        feed_dict = {
            eval_model.eval_file_placeholder: self.config.test_data
        }
        eval_sess.run(eval_model.iterator.initializer, feed_dict=feed_dict)
        model_helper.compute_perplexity(loaded_eval_model, eval_sess, "test")

    log.print_time("# Test finished", start_test_time)
args['length_penalty_weight'] = 0.8
config = Config(**args)

stop_words = load_stop_words(config.model_dir)
word2label = load_keyword_label(config.model_dir)
label2topn = load_top_words(config.model_dir)

model = model_factory.create_model(config)
infer_model = model.create_infer_model_graph()
config_proto = model_helper.get_config_proto(config.log_device)

sess = tf.InteractiveSession(graph=infer_model.graph, config=config_proto)
ckpt = tf.train.latest_checkpoint(config.model_dir)
loaded_infer_model = model_helper.load_model(infer_model.model, ckpt, sess, "infer")


def extract_keyword(query):
    tags = jieba.analyse.extract_tags(query, topK=10)
    keyword = []
    for w in tags:
        # Skip purely numeric tags, purely alphabetic tags, and stop words.
        if re.search(r'^\d+$', w):
            continue
        if re.search(r'^[a-zA-Z]+$', w):
            continue
        if w in stop_words:
            continue
        keyword.append(w)
    if not keyword:
        return None, None