Example #1
    def interactive(self):
        from prompt_toolkit import prompt
        from prompt_toolkit.history import FileHistory
        from topic_model.lda import TopicInferer, DEFAULT_SEPARATOR
        from util.nlp import NLPToolkit

        nlp_toolkit = NLPToolkit()

        self._pre_model_creation()
        topic_inferer = TopicInferer(self.config.lda_model_dir)

        infer_model = taware_helper.create_infer_model(taware_model.TopicAwareSeq2SeqModel, self.config)
        config_proto = models.model_helper.get_config_proto(self.config.log_device)

        with tf.Session(graph=infer_model.graph, config=config_proto) as sess:
            ckpt = tf.train.latest_checkpoint(self.config.model_dir)
            loaded_infer_model = model_helper.load_model(
                infer_model.model, ckpt, sess, "infer")

            log.print_out("# Start decoding")

            sentence = prompt(">>> ", history=FileHistory(os.path.join(self.config.model_dir, ".chat_history"))).strip()

            while sentence:
                utterance = ' '.join(nlp_toolkit.tokenize(sentence)).lower()
                topic_words = topic_inferer.from_collection([utterance],
                                                            dialogue_as_doc=True,
                                                            words_per_topic=self.config.topic_words_per_utterance)

                iterator_feed_dict = {
                    infer_model.src_placeholder: [utterance + DEFAULT_SEPARATOR + " ".join(topic_words)],
                    infer_model.batch_size_placeholder: 1,
                }
                sess.run(infer_model.iterator.initializer, feed_dict=iterator_feed_dict)
                output, _ = loaded_infer_model.decode(sess)

                if self.config.beam_width > 0:
                    # get the top translation.
                    output = output[0]

                resp = ncm_utils.get_translation(output, sent_id=0)

                log.print_out(resp + b"\n")

                sentence = prompt(">>> ",
                                  history=FileHistory(os.path.join(self.config.model_dir, ".chat_history"))).strip()
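In Example #1, the string fed to infer_model.src_placeholder is the tokenized utterance followed by the inferred topic words, joined by DEFAULT_SEPARATOR. The lines below only illustrate that feed format; the "|" shown is a placeholder, since the real separator is defined in topic_model.lda.

# Illustration only; "|" stands in for topic_model.lda.DEFAULT_SEPARATOR.
utterance = "how was your day ?"
topic_words = ["work", "coffee", "meeting"]
src = utterance + "|" + " ".join(topic_words)
# src == "how was your day ?|work coffee meeting"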
Example #2
    def interactive(self):
        from prompt_toolkit import prompt
        from prompt_toolkit.history import FileHistory
        from util.nlp import NLPToolkit

        nlp_toolkit = NLPToolkit()

        infer_model = vanilla_helper.create_infer_model(self.config)
        config_proto = model_helper.get_config_proto(self.config.log_device)

        with tf.Session(graph=infer_model.graph, config=config_proto) as sess:
            ckpt = tf.train.latest_checkpoint(self.config.model_dir)
            loaded_infer_model = model_helper.load_model(
                infer_model.model, ckpt, sess, "infer")

            log.print_out("# Start decoding")

            sentence = prompt(">>> ",
                              history=FileHistory(
                                  os.path.join(self.config.model_dir,
                                               ".chat_history"))).strip()

            while sentence:
                utterance = ' '.join(nlp_toolkit.tokenize(sentence)).lower()

                iterator_feed_dict = {
                    infer_model.src_placeholder: [utterance],
                    infer_model.batch_size_placeholder: 1,
                }

                sess.run(infer_model.iterator.initializer,
                         feed_dict=iterator_feed_dict)
                output, _ = loaded_infer_model.decode(sess)

                if self.config.beam_width > 0:
                    # get the top translation.
                    output = output[0]

                resp = ncm_utils.get_translation(output, sent_id=0)

                log.print_out(resp + b"\n")

                sentence = prompt(">>> ",
                                  history=FileHistory(
                                      os.path.join(self.config.model_dir,
                                                   ".chat_history"))).strip()
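Examples #1 and #2 issue the same prompt(...) call once before the loop and again at the end of every iteration. If that duplication were a concern, the read loop could be wrapped in a small generator; the helper below is a hypothetical refactoring, not part of the original code, and assumes os is imported at module level.

def read_sentences(model_dir):
    """Yield stripped user inputs until an empty line is entered (hypothetical helper)."""
    from prompt_toolkit import prompt
    from prompt_toolkit.history import FileHistory

    history = FileHistory(os.path.join(model_dir, ".chat_history"))
    while True:
        sentence = prompt(">>> ", history=history).strip()
        if not sentence:
            return
        yield sentence

# The while loops above would then read: for sentence in read_sentences(self.config.model_dir): ...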
Example #3
def chat(checkpoint, chat_logs_output_file, hparams, scope=None):
    # The infer model built below holds the graph, model, source placeholder, batch_size placeholder, and iterator

    if hparams.architecture == "simple":
        model_creator = SimpleModel
        get_infer_iterator = iterator_utils.get_infer_iterator

        def update_dialogue(so_far, utterance, response):
            # Dialogue so far does not matter
            return utterance

    elif hparams.architecture == "hier":
        model_creator = HierarchicalModel
        # Parse some of the arguments now
        get_infer_iterator = lambda dataset, vocab_table, batch_size, src_reverse, eos, src_max_len: \
            end2end_iterator_utils.get_infer_iterator(dataset, vocab_table, batch_size, src_reverse, eos,
                                                      src_max_len=src_max_len, eou=hparams.eou,
                                                      dialogue_max_len=hparams.dialogue_max_len)

        def update_dialogue(so_far, utterance, response):
            # Dialogue so far is considered
            if response == "":
                return utterance
            return so_far + " " + hparams.eou + " " + response + " " + hparams.eou + " " + utterance
    else:
        raise ValueError("Unknown architecture", hparams.architecture)

    infer_model = create_infer_model(model_creator,
                                     get_infer_iterator,
                                     hparams=hparams,
                                     verbose=False,
                                     scope=scope)

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        # Load the model from the checkpoint
        # ToDo: adapt for architectures
        loaded_infer_model = model_helper.load_model(model=infer_model.model,
                                                     ckpt=checkpoint,
                                                     session=sess,
                                                     name="infer",
                                                     verbose=False)
        utils.print_out("I'm Batman!")
        dialogue_so_far = ""
        response = ""
        # Gather all the user's utterances
        user_utterances = []
        # Leave it in chat mode until interrupted
        while True:
            # Read utterance from user.
            utterance, original_utterance = get_user_input(hparams)
            # Check if we should finish
            if utterance == 'end()':
                overall_score(user_utterances, hparams)
                break
            user_utterances.append(original_utterance)
            # Preprocess it into the familiar format for the machine
            utterance = preprocessing_utils.tokenize_line(
                utterance,
                number_token=hparams.number_token,
                name_token=hparams.name_token,
                gpe_token=hparams.gpe_token)
            dialogue_so_far = update_dialogue(dialogue_so_far, utterance,
                                              response)
            # Transform it into a batch of size 1
            batched_dialogue = [dialogue_so_far]
            # Initialize the iterator
            sess.run(infer_model.iterator.initializer,
                     feed_dict={
                         infer_model.src_placeholder: batched_dialogue,
                         infer_model.batch_size_placeholder: 1
                     })

            response = chatbot_utils.decode_utterance(
                model=loaded_infer_model,
                sess=sess,
                output_file=chat_logs_output_file,
                bpe_delimiter=hparams.bpe_delimiter,
                beam_width=hparams.beam_width,
                utterance=utterance,
                top_responses=hparams.top_responses,
                eos=hparams.eos,
                number_token=hparams.number_token,
                name_token=hparams.name_token)
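For the "hier" architecture, update_dialogue threads the running conversation through hparams.eou markers, while the "simple" architecture keeps only the latest utterance. A minimal trace of the "hier" branch, assuming hparams.eou is the literal token "</u>" (the actual token comes from hparams):

# Hypothetical trace of the "hier" update_dialogue, assuming eou == "</u>".
so_far = update_dialogue("", "hello there", "")
# -> "hello there"
so_far = update_dialogue(so_far, "how are you ?", "hi !")
# -> "hello there </u> hi ! </u> how are you ?"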
Example #4
def inference(checkpoint,
              inference_input_file,
              inference_output_file,
              hparams,
              scope=None):
    """Create the responses."""
    # TODO: multiple workers could be added if desired. This function corresponds to the
    # helper _single_worker_inference in the original model.

    # Read the data
    infer_data = load_data(inference_input_file, hparams)

    if hparams.architecture == "simple":
        model_creator = SimpleModel
        get_infer_iterator = iterator_utils.get_infer_iterator

    elif hparams.architecture == "hier":
        model_creator = HierarchicalModel
        # Parse some of the arguments now
        get_infer_iterator = lambda dataset, vocab_table, batch_size, src_reverse, eos, src_max_len: \
            end2end_iterator_utils.get_infer_iterator(dataset, vocab_table, batch_size, src_reverse, eos,
                                                      src_max_len=src_max_len, eou=hparams.eou,
                                                      dialogue_max_len=hparams.dialogue_max_len)
    else:
        raise ValueError("Unknown architecture", hparams.architecture)

    # Containing the graph, model, source placeholder, batch_size placeholder and iterator
    infer_model = create_infer_model(model_creator, get_infer_iterator,
                                     hparams, scope)
    # ToDo: adapt for architectures
    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        # Load the model from the checkpoint
        loaded_infer_model = model_helper.load_model(model=infer_model.model,
                                                     ckpt=checkpoint,
                                                     session=sess,
                                                     name="infer")
        # Initialize the iterator
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Starting Decoding")
        # Decode only a specific set of indices
        if hparams.inference_indices:
            _decode_inference_indices(
                model=loaded_infer_model,
                sess=sess,
                output_infer_file=inference_output_file,
                output_infer_summary_prefix=inference_output_file,
                inference_indices=hparams.inference_indices,
                eos=hparams.eos,
                bpe_delimiter=hparams.bpe_delimiter,
                number_token=hparams.number_token,
                name_token=hparams.name_token)
        else:
            chatbot_utils.decode_and_evaluate(
                name="infer",
                model=loaded_infer_model,
                sess=sess,
                output_file=inference_output_file,
                reference_file=None,
                metrics=hparams.metrics,
                bpe_delimiter=hparams.bpe_delimiter,
                beam_width=hparams.beam_width,
                eos=hparams.eos,
                number_token=hparams.number_token,
                name_token=hparams.name_token)
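A hypothetical invocation of inference; the checkpoint path, file names, and hparams object are placeholders for whatever the surrounding training script provides.

# Placeholder paths and hparams; only meant to show the call shape.
inference(checkpoint="model_dir/translate.ckpt-12000",
          inference_input_file="data/test_sources.txt",
          inference_output_file="model_dir/responses.txt",
          hparams=hparams)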
Example #5
    def test(self):
        start_test_time = time.time()

        assert self.config.n_responses >= 1

        if self.config.beam_width > 0:
            assert self.config.n_responses <= self.config.beam_width
        else:
            assert self.config.n_responses == 1

        self._pre_model_creation()

        infer_model = self._get_model_helper().create_infer_model(self.config)

        latest_ckpt = tf.train.latest_checkpoint(self.config.get_infer_model_dir())
        with tf.Session(
                config=model_helper.get_config_proto(self.config.log_device),
                graph=infer_model.graph) as infer_sess:
            loaded_infer_model = model_helper.load_model(
                infer_model.model, latest_ckpt, infer_sess, "infer")

            log.print_out("# Start decoding")
            log.print_out("  beam width: {}".format(self.config.beam_width))
            log.print_out("  length penalty: {}".format(self.config.length_penalty_weight))
            log.print_out("  sampling temperature: {}".format(self.config.sampling_temperature))
            log.print_out("  num responses per test instance: {}".format(self.config.n_responses))

            feed_dict = {
                infer_model.src_placeholder: self._load_data(self.config.test_data),
                infer_model.batch_size_placeholder: self.config.infer_batch_size,
            }

            if self.config.sampling_temperature > 0:
                label = "%s_t%.1f" % (
                    fs.file_name(self.config.test_data), self.config.sampling_temperature)
            else:
                label = "%s_bw%d_lp%.1f" % (
                    fs.file_name(self.config.test_data), self.config.beam_width, self.config.length_penalty_weight)

            self._decode_and_evaluate(loaded_infer_model, infer_sess, feed_dict,
                                      label=label,
                                      num_responses_per_input=self.config.n_responses)
        log.print_time("# Decoding done", start_test_time)

        eval_model = self._get_model_helper().create_eval_model(self.config)
        with tf.Session(
                config=model_helper.get_config_proto(self.config.log_device),
                graph=eval_model.graph) as eval_sess:
            loaded_eval_model = model_helper.load_model(
                eval_model.model, latest_ckpt, eval_sess, "eval")

            log.print_out("# Compute Perplexity")

            dev_eval_iterator_feed_dict = {
                eval_model.eval_file_placeholder: self.config.test_data
            }

            eval_sess.run(eval_model.iterator.initializer, feed_dict=dev_eval_iterator_feed_dict)
            model_helper.compute_perplexity(loaded_eval_model, eval_sess, "test")

        log.print_time("# Test finished", start_test_time)
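In Example #5 the label encodes the decoding configuration into the output name, with one format for sampling and another for beam search. A small illustration of the two branches, assuming fs.file_name("data/test.txt") returns "test":

# Hypothetical inputs; only the format strings are taken from the code above.
"%s_t%.1f" % ("test", 0.8)            # -> "test_t0.8"       (sampling_temperature > 0)
"%s_bw%d_lp%.1f" % ("test", 10, 1.0)  # -> "test_bw10_lp1.0" (beam-search branch)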
Example #6
    def interactive(self, scope=None):
        import platform
        from prompt_toolkit import prompt
        from prompt_toolkit.history import FileHistory
        from topic_model.lda import TopicInferer, DEFAULT_SEPARATOR
        from util.nlp import NLPToolkit

        nlp_toolkit = NLPToolkit()

        __os = platform.system()

        self.config.infer_batch_size = 1
        self._pre_model_creation()

        if self.config.lda_model_dir is not None:
            topic_inferer = TopicInferer(self.config.lda_model_dir)
        else:
            topic_inferer = None

        infer_model = self._get_model_helper().create_infer_model(self.config, scope)

        with tf.Session(
                config=model_helper.get_config_proto(self.config.log_device), graph=infer_model.graph) as sess:
            latest_ckpt = tf.train.latest_checkpoint(self.config.model_dir)
            loaded_infer_model = model_helper.load_model(
                infer_model.model, latest_ckpt, sess, "infer")

            log.print_out("# Start decoding")

            if __os == 'Windows':
                sentence = input("> ").strip()
            else:
                sentence = prompt(">>> ",
                                  history=FileHistory(os.path.join(self.config.model_dir, ".chat_history"))).strip()
            conversation = [vocab.EOS] * (self.config.num_turns - 1)

            while sentence:
                current_utterance = ' '.join(nlp_toolkit.tokenize(sentence)).lower()
                conversation.append(current_utterance)
                conversation.pop(0)

                feedable_context = "\t".join(conversation)
                if topic_inferer is None:
                    iterator_feed_dict = {
                        infer_model.src_placeholder: [feedable_context],
                        infer_model.batch_size_placeholder: 1,
                    }
                else:
                    topic_words = topic_inferer.from_collection([feedable_context],
                                                                dialogue_as_doc=True,
                                                                words_per_topic=self.config.topic_words_per_utterance)
                    iterator_feed_dict = {
                        infer_model.src_placeholder:
                            [feedable_context + DEFAULT_SEPARATOR + " ".join(topic_words)],
                        infer_model.batch_size_placeholder: 1,
                    }

                sess.run(infer_model.iterator.initializer, feed_dict=iterator_feed_dict)
                output, infer_summary = loaded_infer_model.decode(sess)
                if self.config.beam_width > 0 and self._consider_beam():
                    # get the top translation.
                    output = output[0]

                resp = ncm_utils.get_translation(output, sent_id=0)

                log.print_out(resp + b"\n")

                if __os == 'Windows':
                    sentence = input("> ").strip()
                else:
                    sentence = prompt(">>> ",
                                      history=FileHistory(os.path.join(self.config.model_dir, ".chat_history"))).strip()

        print("Bye!!!")
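Example #6 keeps a rolling window of the last num_turns - 1 utterances by appending to a list and popping its head. An equivalent window can be expressed with collections.deque(maxlen=...); the lines below are only an illustrative alternative, with num_turns hard-coded as an assumption, not the original implementation.

from collections import deque

# Illustrative alternative to the list-based window; num_turns is a placeholder value.
num_turns = 4
conversation = deque([vocab.EOS] * (num_turns - 1), maxlen=num_turns - 1)
conversation.append("hello there")      # the oldest entry is dropped automatically
feedable_context = "\t".join(conversation)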
Example #7
    def test(self):
        start_test_time = time.time()

        assert self.config.n_responses >= 1

        if self.config.beam_width > 0:
            assert self.config.n_responses <= self.config.beam_width
        else:
            assert self.config.n_responses == 1

        self._pre_model_creation()

        infer_model = taware_helper.create_infer_model(taware_model.TopicAwareSeq2SeqModel, self.config)

        config_proto = models.model_helper.get_config_proto(self.config.log_device)

        ckpt = tf.train.latest_checkpoint(self.config.get_infer_model_dir())
        with tf.Session(graph=infer_model.graph, config=config_proto) as infer_sess:
            loaded_infer_model = model_helper.load_model(
                infer_model.model, ckpt, infer_sess, "infer")

            log.print_out("# Start decoding")
            log.print_out("  beam width: {}".format(self.config.beam_width))
            log.print_out("  length penalty: {}".format(self.config.length_penalty_weight))
            log.print_out("  sampling temperature: {}".format(self.config.sampling_temperature))
            log.print_out("  num responses per test instance: {}".format(self.config.n_responses))

            feed_dict = {
                infer_model.src_placeholder: self._load_data(self.config.test_data),
                infer_model.batch_size_placeholder: self.config.infer_batch_size,
            }

            infer_sess.run(infer_model.iterator.initializer, feed_dict=feed_dict)

            if self.config.sampling_temperature > 0:
                label = "%s_t%.1f" % (
                    fs.file_name(self.config.test_data), self.config.sampling_temperature)
            else:
                label = "%s_bw%d_lp%.1f" % (
                    fs.file_name(self.config.test_data), self.config.beam_width, self.config.length_penalty_weight)

            out_file = os.path.join(self.config.model_dir, "output_{}".format(label))

            eval_metric.decode_and_evaluate(
                "test",
                loaded_infer_model,
                infer_sess,
                out_file,
                ref_file=None,
                metrics=self.config.metrics,
                beam_width=self.config.beam_width,
                num_translations_per_input=self.config.n_responses)
        log.print_time("# Decoding done", start_test_time)

        eval_model = taware_helper.create_eval_model(taware_model.TopicAwareSeq2SeqModel, self.config)
        with tf.Session(
                config=models.model_helper.get_config_proto(self.config.log_device), graph=eval_model.graph) as eval_sess:
            loaded_eval_model = model_helper.load_model(
                eval_model.model, ckpt, eval_sess, "eval")

            log.print_out("# Compute Perplexity")

            feed_dict = {
                eval_model.eval_file_placeholder: self.config.test_data
            }

            eval_sess.run(eval_model.iterator.initializer, feed_dict=feed_dict)

            model_helper.compute_perplexity(loaded_eval_model, eval_sess, "test")

        log.print_time("# Test finished", start_test_time)
Example #8
args['length_penalty_weight'] = 0.8
config = Config(**args)

stop_words = load_stop_words(config.model_dir)
word2label = load_keyword_label(config.model_dir)
label2topn = load_top_words(config.model_dir)

model = model_factory.create_model(config)

infer_model = model.create_infer_model_graph()
config_proto = model_helper.get_config_proto(config.log_device)

sess = tf.InteractiveSession(graph=infer_model.graph, config=config_proto)

ckpt = tf.train.latest_checkpoint(config.model_dir)
loaded_infer_model = model_helper.load_model(infer_model.model, ckpt, sess, "infer")


def extract_keyword(query):
    tags = jieba.analyse.extract_tags(query, topK=10)
    keyword = []
    for w in tags:
        # Drop purely numeric tokens, purely alphabetic tokens, and stop words
        if re.search(r'^\d+$', w):
            continue
        if re.search(r'^[a-zA-Z]+$', w):
            continue
        if w in stop_words:
            continue
        keyword.append(w)
    if not keyword:
        return None, None