示例#1
0
    def _FillInputQueue(self):
        """Fill the input queue line by line with ``ModelInput`` elements.

        Each record produced by ``self._TextGenerator`` is unpacked as
        ``(index_id, target, sentence)``:

        * ``sentence`` is stripped and converted to word ids with
          ``data.GetWordIds``.
        * ``target`` is converted with ``int()``.

        Examples shorter than ``self._hps.min_input_len`` are dropped.
        Examples longer than ``self._hps.enc_timesteps`` are dropped when
        ``self._truncate_input`` is false, and truncated otherwise. The real
        (unpadded) length is recorded, the ids are right-padded with the PAD
        token id up to ``enc_timesteps``, and the result is put on
        ``self._input_queue``.

        In 'train' mode ``data.ExampleGen`` is called without a second
        argument. In any other mode it is passed ``1``, presumably a single
        pass over the data (TODO confirm against ``data.ExampleGen``). The
        method returns once the generator is exhausted.
        """
        pad_id = self._vocab.WordToId(parameter_config.PAD_TOKEN)
        pattern = os.path.join(self._data_path, '*')
        if self._hps.mode == 'train':
            input_gen = self._TextGenerator(data.ExampleGen(pattern))
        else:
            input_gen = self._TextGenerator(data.ExampleGen(pattern, 1))

        while True:
            try:
                # The builtin next() works on Python 2 and 3;
                # generator.next() is Python 2 only.
                (index_id, target, sentence) = next(input_gen)
            except (GeneratorExit, StopIteration):
                break

            enc_inputs = data.GetWordIds(sentence.strip(), self._vocab)
            target = int(target)

            # Filter out too-short input.
            if len(enc_inputs) < self._hps.min_input_len:
                continue

            if len(enc_inputs) > self._hps.enc_timesteps:
                if not self._truncate_input:
                    # Not truncating: throw out too-long input.
                    continue
                # Truncating: keep only the first enc_timesteps ids.
                enc_inputs = enc_inputs[:self._hps.enc_timesteps]

            # Record the real length before padding is appended.
            enc_input_len = len(enc_inputs)

            # Right-pad with the PAD id up to the fixed encoder length.
            num_pad = self._hps.enc_timesteps - enc_input_len
            enc_inputs.extend([pad_id] * num_pad)

            element = ModelInput(index_id, target, enc_inputs, enc_input_len)
            self._input_queue.put(element)
    def _FillInputQueue(self):
        """Fill input queue with ModelInput.

        Loops forever, reading ``(article, abstract)`` pairs from
        ``self._TextGenerator``. Each pair is processed as follows:

        * Both texts are split into sentences with ``data.ToSentences(...,
          include_token=False)``, and each sentence is stripped.
        * The first ``self._max_article_sentences`` article sentences become
          the encoder ids.
        * The first ``self._max_abstract_sentences`` abstract sentences become
          the decoder ids, prefixed with the sentence-start id, which serves
          as the decoder GO symbol.
        * Examples where either side is shorter than
          ``self._hps.min_input_len`` are dropped.
        * Over-long examples are dropped (``self._truncate_input`` false) or
          truncated to ``enc_timesteps`` / ``dec_timesteps``.
        * Targets are the decoder ids shifted left by one plus the
          sentence-end id.
        * The encoder side is padded with the PAD id; both decoder sides are
          padded with the sentence-end id.

        Exhaustion of the generator is not handled here: ``next`` lets
        ``StopIteration`` propagate out of this method.
        """
        start_id = self._vocab.WordToId(data.SENTENCE_START)
        end_id = self._vocab.WordToId(data.SENTENCE_END)
        pad_id = self._vocab.WordToId(data.PAD_TOKEN)
        input_gen = self._TextGenerator(data.ExampleGen(self._data_path))
        while True:
            # The builtin next() works on Python 2 and 3;
            # generator.next() is Python 2 only.
            (article, abstract) = next(input_gen)
            article_sentences = [
                sent.strip()
                for sent in data.ToSentences(article, include_token=False)
            ]
            abstract_sentences = [
                sent.strip()
                for sent in data.ToSentences(abstract, include_token=False)
            ]

            enc_inputs = []
            # Use the <s> as the <GO> symbol for decoder inputs.
            dec_inputs = [start_id]

            # Convert the first N sentences to word ids. Iterating over a
            # slice avoids the Python-2-only xrange index loop.
            for sent in article_sentences[:self._max_article_sentences]:
                enc_inputs += data.GetWordIds(sent, self._vocab)
            for sent in abstract_sentences[:self._max_abstract_sentences]:
                dec_inputs += data.GetWordIds(sent, self._vocab)

            # Filter out too-short input.
            if (len(enc_inputs) < self._hps.min_input_len
                    or len(dec_inputs) < self._hps.min_input_len):
                tf.logging.warning(
                    'Drop an example - too short.\nenc:%d\ndec:%d',
                    len(enc_inputs), len(dec_inputs))
                continue

            # If we're not truncating input, throw out too-long input.
            if not self._truncate_input:
                if (len(enc_inputs) > self._hps.enc_timesteps
                        or len(dec_inputs) > self._hps.dec_timesteps):
                    tf.logging.warning(
                        'Drop an example - too long.\nenc:%d\ndec:%d',
                        len(enc_inputs), len(dec_inputs))
                    continue
            # If we are truncating input, do so if necessary.
            else:
                if len(enc_inputs) > self._hps.enc_timesteps:
                    enc_inputs = enc_inputs[:self._hps.enc_timesteps]
                if len(dec_inputs) > self._hps.dec_timesteps:
                    dec_inputs = dec_inputs[:self._hps.dec_timesteps]

            # targets is dec_inputs without <s> at the beginning, plus </s>
            # at the end, so len(targets) == len(dec_inputs).
            targets = dec_inputs[1:]
            targets.append(end_id)

            # Now len(enc_inputs) <= enc_timesteps and
            # len(targets) == len(dec_inputs) <= dec_timesteps.
            enc_input_len = len(enc_inputs)
            dec_output_len = len(targets)

            # Pad to the fixed lengths: encoder with <PAD>, decoder with </s>.
            enc_inputs.extend(
                [pad_id] * (self._hps.enc_timesteps - len(enc_inputs)))
            dec_inputs.extend(
                [end_id] * (self._hps.dec_timesteps - len(dec_inputs)))
            targets.extend(
                [end_id] * (self._hps.dec_timesteps - len(targets)))

            element = ModelInput(enc_inputs, dec_inputs, targets,
                                 enc_input_len, dec_output_len,
                                 ' '.join(article_sentences),
                                 ' '.join(abstract_sentences))
            self._input_queue.put(element)
示例#3
0
    def _FillInputQueue(self):
        """Fills input queue with ModelInput.

        Loops forever, reading ``(source, targets)`` pairs from
        ``self._TextGenerator``; only the first reference, ``targets[0]``, is
        used. Each pair is processed as follows:

        * The source becomes input-vocabulary ids.
        * The target is converted twice:

          - to output-vocabulary ids (generation side), and
          - with ``data.GetWordIndices(..., position_based_indexing=True)``
            to indices into the source (copy side).

        * Examples whose input or output is shorter than
          ``self._config.min_input_len`` are dropped.
        * Over-long examples are dropped or truncated depending on
          ``self._truncate_input``.
        * Generation targets are the decoder inputs shifted left by one plus
          the end-of-sentence id. Copy targets are shifted the same way and
          terminated with ``len(enc_inputs)``.
        * Sequences are padded to ``max_input_len`` / ``max_output_len``, and
          a ``ModelInput`` is put on ``self._input_queue``.
        """
        # Input sequences are padded with the input-vocabulary PAD id.
        pad_id = self._input_vocab.WordToId(data.PAD_TOKEN)
        # Output sequences are padded with the output-vocabulary end id.
        end_id = self._output_vocab.WordToId(data.SENTENCE_END)

        input_gen = self._TextGenerator(data.ExampleGen(self._data_path))
        while True:
            (source, targets) = next(input_gen)
            # Only the first reference target is used.
            target = targets[0]

            # Convert to ids: vocabulary ids for the generator and
            # source-position indices for the copy mechanism.
            enc_inputs = data.GetWordIds(source, self._input_vocab)
            dec_inputs_gen = data.GetWordIds(target, self._output_vocab)
            dec_inputs_cop = data.GetWordIndices(target,
                                                 source,
                                                 self._input_vocab,
                                                 position_based_indexing=True)

            # Filter out too-short input
            if len(enc_inputs) < self._config.min_input_len:
                tf.logging.warning(
                    'Drop an example - input to short: %d (min: %d)',
                    len(enc_inputs), self._config.min_input_len)
                continue

            if len(dec_inputs_gen) < self._config.min_input_len:
                # Report the output length (the input length was logged
                # here before, which made the message misleading).
                tf.logging.warning(
                    'Drop an example - output to short: %d (min: %d)',
                    len(dec_inputs_gen), self._config.min_input_len)
                continue

            # If we're not truncating input, throw out too-long input
            if not self._truncate_input:
                if len(enc_inputs) > self._config.max_input_len:
                    tf.logging.warning(
                        'Drop an example - input to long: %d (max: %d)',
                        len(enc_inputs), self._config.max_input_len)
                    continue
                if len(dec_inputs_gen) > self._config.max_output_len:
                    tf.logging.warning(
                        'Drop an example - output to long: %d (max: %d)',
                        len(dec_inputs_gen), self._config.max_output_len)
                    continue
            # If we are truncating input, do so if necessary
            else:
                if len(enc_inputs) > self._config.max_input_len:
                    enc_inputs = enc_inputs[:self._config.max_input_len]
                    # Copy positions that point past the truncated input are
                    # mapped to 0.
                    # NOTE(review): whether the bound should be `<` or `<=`
                    # depends on whether GetWordIndices positions are 0- or
                    # 1-based; with `<=`, pos == max_input_len coincides with
                    # end_position below — confirm.
                    dec_inputs_cop = [
                        pos if pos <= self._config.max_input_len else 0
                        for pos in dec_inputs_cop
                    ]
                if len(dec_inputs_gen) > self._config.max_output_len:
                    dec_inputs_gen = dec_inputs_gen[:self._config.
                                                    max_output_len]
                    dec_inputs_cop = dec_inputs_cop[:self._config.
                                                    max_output_len]

            # dec_targets_gen is dec_inputs_gen without its first element,
            # plus </s> at the end.
            # NOTE(review): no start id is prepended to dec_inputs_gen here,
            # so this assumes the target text itself begins with a start
            # token — confirm.
            dec_targets_gen = dec_inputs_gen[1:]
            dec_targets_gen.append(end_id)

            # dec_targets_cop is dec_inputs_cop without its first element,
            # plus the end position (one past the last encoder input).
            dec_targets_cop = dec_inputs_cop[1:]
            end_position = len(enc_inputs)
            dec_targets_cop.append(end_position)

            enc_input_len = len(enc_inputs)
            # Presumably equal to len(dec_targets_cop), assuming
            # GetWordIndices yields one entry per target token — confirm.
            dec_output_len = len(dec_targets_gen)

            # Pad to the fixed lengths.
            max_output_len = self._config.max_output_len
            enc_inputs.extend(
                [pad_id] * (self._config.max_input_len - len(enc_inputs)))
            dec_inputs_gen.extend(
                [end_id] * (max_output_len - len(dec_inputs_gen)))
            dec_targets_gen.extend(
                [end_id] * (max_output_len - len(dec_targets_gen)))
            dec_targets_cop.extend(
                [end_position] * (max_output_len - len(dec_targets_cop)))

            element = ModelInput(enc_inputs, dec_inputs_gen, dec_targets_gen,
                                 dec_targets_cop, enc_input_len,
                                 dec_output_len, source, targets)
            self._input_queue.put(element)
示例#4
0
    def _FillInputQueue(self):
        """Fill input queue with ModelInput.

        Special tokens noted for this reader (presumably the values of the
        corresponding ``data`` module constants — confirm):
            SENTENCE_START = '<s>'
            SENTENCE_END = '</s>'
            UNKNOWN_TOKEN = '<UNK>'
            PAD_TOKEN = '<PAD>'

        Loops forever, reading ``(article, abstract)`` pairs via
        ``six.next``. Each pair is processed as follows:

        * Both texts are split into sentences without the ``<s>``/``</s>``
          markers.
        * The first ``self._max_article_sentences`` article sentences become
          the encoder ids.
        * The first ``self._max_abstract_sentences`` abstract sentences become
          the decoder ids, prefixed with the ``<s>`` id.
        * Too-short examples are dropped.
        * Too-long examples are dropped or truncated depending on
          ``self._truncate_input``.
        * The sequences are padded, and a ``ModelInput`` is put on
          ``self._input_queue``.
        """
        start_id = self._vocab.WordToId(data.SENTENCE_START)
        end_id = self._vocab.WordToId(data.SENTENCE_END)
        pad_id = self._vocab.WordToId(data.PAD_TOKEN)
        input_gen = self._TextGenerator(data.ExampleGen(self._data_path))
        while True:
            (article, abstract) = six.next(input_gen)
            # Split into sentences. With include_token=False the leading <s>
            # and trailing </s> markers of each sentence are removed.
            article_sentences = [
                sent.strip()
                for sent in data.ToSentences(article, include_token=False)
            ]
            abstract_sentences = [
                sent.strip()
                for sent in data.ToSentences(abstract, include_token=False)
            ]

            enc_inputs = []
            # Use the <s> as the <GO> symbol for decoder inputs: prepend <s>
            # to the decoder input.
            dec_inputs = [start_id]

            # Convert the first N sentences to word ids; each sentence becomes
            # a list of ids appended to the running sequence. Iterating over a
            # slice avoids the Python-2-only xrange index loop.
            for sent in article_sentences[:self._max_article_sentences]:
                enc_inputs += data.GetWordIds(sent, self._vocab)
            for sent in abstract_sentences[:self._max_abstract_sentences]:
                dec_inputs += data.GetWordIds(sent, self._vocab)

            # Filter out too-short input.
            if (len(enc_inputs) < self._hps.min_input_len
                    or len(dec_inputs) < self._hps.min_input_len):
                tf.logging.warning(
                    'Drop an example - too short.\nenc:%d\ndec:%d',
                    len(enc_inputs), len(dec_inputs))
                continue

            # If we're not truncating input, throw out too-long input.
            if not self._truncate_input:
                if (len(enc_inputs) > self._hps.enc_timesteps
                        or len(dec_inputs) > self._hps.dec_timesteps):
                    tf.logging.warning(
                        'Drop an example - too long.\nenc:%d\ndec:%d',
                        len(enc_inputs), len(dec_inputs))
                    continue
            # If we are truncating input, do so if necessary.
            else:
                if len(enc_inputs) > self._hps.enc_timesteps:
                    enc_inputs = enc_inputs[:self._hps.enc_timesteps]
                if len(dec_inputs) > self._hps.dec_timesteps:
                    dec_inputs = dec_inputs[:self._hps.dec_timesteps]

            # targets is dec_inputs without <s> at the beginning, plus </s> at
            # the end: the decoder input starts with <s>, the target ends
            # with </s>.
            targets = dec_inputs[1:]
            targets.append(end_id)

            # Now len(enc_inputs) <= enc_timesteps and
            # len(targets) == len(dec_inputs) <= dec_timesteps.
            enc_input_len = len(enc_inputs)
            dec_output_len = len(targets)

            # Pad up to the fixed lengths. enc_inputs contains no <s>/</s> and
            # is padded with <PAD>. dec_inputs is [<s>, ...] and targets is
            # [..., </s>]; both are padded with </s>.
            enc_inputs.extend(
                [pad_id] * (self._hps.enc_timesteps - len(enc_inputs)))
            dec_inputs.extend(
                [end_id] * (self._hps.dec_timesteps - len(dec_inputs)))
            targets.extend(
                [end_id] * (self._hps.dec_timesteps - len(targets)))

            # Put the example on the queue. enc_inputs is the encoder input,
            # dec_inputs the decoder input, targets the decoder output target.
            element = ModelInput(enc_inputs, dec_inputs, targets,
                                 enc_input_len, dec_output_len,
                                 ' '.join(article_sentences),
                                 ' '.join(abstract_sentences))
            self._input_queue.put(element)