Example #1
    def _EvalParallelData(self, features_file, labels_file):
        """ Function for reading small scale parallel data for evaluation.

        Args:
            features_file: The path of features file.
            labels_file: The path of labels file.

        Returns: A list of feeding data.
        """
        eval_features = open_file(features_file, encoding="utf-8")
        if gfile.Exists(labels_file):
            eval_labels = open_file(labels_file, encoding="utf-8")
        else:
            eval_labels = open_file(labels_file + "0", encoding="utf-8")
        ss_buf = []
        tt_buf = []
        ss_str_buf = []
        tt_str_buf = []
        for ss, tt in zip(eval_features, eval_labels):
            ss_str = self._vocab_source.bpe_encode(ss.strip()).split()
            tt_str = self._vocab_target.bpe_encode(tt.strip()).split()
            ss_str_buf.append(ss_str)
            tt_str_buf.append(tt_str)
            ss_buf.append(self._vocab_source.convert_to_idlist(ss.strip()))
            tt_buf.append(self._vocab_target.convert_to_idlist(tt.strip()))
        close_file(eval_features)
        close_file(eval_labels)
        if self._bucketing:
            tlen = numpy.array([len(t) for t in tt_buf])
            tidx = tlen.argsort()
            _ss_buf = [ss_buf[i] for i in tidx]
            _tt_buf = [tt_buf[i] for i in tidx]
            _ss_str_buf = [ss_str_buf[i] for i in tidx]
            _tt_str_buf = [tt_str_buf[i] for i in tidx]
            ss_buf = _ss_buf
            tt_buf = _tt_buf
            ss_str_buf = _ss_str_buf
            tt_str_buf = _tt_str_buf
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            x, len_x = padding_batch_data(
                ss_buf[batch_data_idx:batch_data_idx + self._batch_size],
                self._vocab_source.eos_id)
            y, len_y = padding_batch_data(
                tt_buf[batch_data_idx:batch_data_idx + self._batch_size],
                self._vocab_target.eos_id)
            data.append(
                (ss_str_buf[batch_data_idx:batch_data_idx + self._batch_size],
                 tt_str_buf[batch_data_idx:batch_data_idx + self._batch_size],
                 {
                     self.input_fields[GlobalNames.PH_FEATURE_IDS_NAME]: x,
                     self.input_fields[GlobalNames.PH_FEATURE_LENGTH_NAME]: len_x,
                     self.input_fields[GlobalNames.PH_LABEL_IDS_NAME]: y,
                     self.input_fields[GlobalNames.PH_LABEL_LENGTH_NAME]: len_y
                 }))
            batch_data_idx += self._batch_size
        return data
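
Every example on this page unpacks the result of padding_batch_data(sequences, pad_symbol) into a padded id matrix plus the original sequence lengths. The project's own implementation is not reproduced here; the following is only a minimal sketch consistent with that usage, assuming right-padding with the given symbol id:

import numpy

def padding_batch_data(seqs, pad_id):
    """Hypothetical sketch of padding_batch_data, not the project's code.

    Right-pads a list of id lists with `pad_id` and returns the padded
    (batch_size, max_len) int32 matrix together with the original lengths,
    matching how the examples on this page unpack (x, len_x).
    """
    lengths = [len(s) for s in seqs]
    max_len = max(lengths) if lengths else 0
    padded = numpy.full((len(seqs), max_len), pad_id, dtype=numpy.int32)
    for i, s in enumerate(seqs):
        padded[i, :len(s)] = s
    return padded, lengths
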
Example #2
    def _SmallParallelData(self,
                           features_file,
                           labels_file,
                           maximum_features_length=None,
                           maximum_labels_length=None):
        """ Function for reading small scale parallel data for evaluation.

        Args:
            features_file: The path of features file.
            labels_file: The path of labels file.
            maximum_features_length: The maximum length of feature symbols (especially
              after BPE is applied). If provided, sentences whose number of symbols
              exceeds this value will be ignored.
            maximum_labels_length: The maximum length of label symbols (especially
              after BPE is applied). If provided, sentences whose number of symbols
              exceeds this value will be ignored.

        Returns: A list of feeding data.
        """
        features = open_file(features_file, encoding="utf-8")
        labels = open_file(labels_file[0], encoding="utf-8")

        ss_buf = []
        tt_buf = []
        while True:
            ss = read_line_with_filter(features, maximum_features_length,
                                       self._features_preprocessing_fn)
            tt = read_line_with_filter(labels, maximum_labels_length,
                                       self._labels_preprocessing_fn)
            if ss == "" or tt == "":
                break
            ss_buf.append(ss)
            tt_buf.append(tt)
        close_file(features)
        close_file(labels)
        if self._bucketing:
            tt_buf, ss_buf = do_bucketing(tt_buf, [ss_buf])
            ss_buf = ss_buf[0]
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            x, len_x = padding_batch_data(
                ss_buf[batch_data_idx:batch_data_idx + self._batch_size],
                self._features_padding)
            y, len_y = padding_batch_data(
                tt_buf[batch_data_idx:batch_data_idx + self._batch_size],
                self._labels_padding)
            data.append({
                "feature_ids": x,
                "label_ids": y,
                "feed_dict": {
                    self.input_fields[Constants.FEATURE_IDS_NAME]: x,
                    self.input_fields[Constants.FEATURE_LENGTH_NAME]: len_x,
                    self.input_fields[Constants.LABEL_IDS_NAME]: y,
                    self.input_fields[Constants.LABEL_LENGTH_NAME]: len_y
                }
            })
            batch_data_idx += self._batch_size
        return data
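
Example #2 above delegates the length-based sorting to do_bucketing instead of the explicit argsort used in Example #1. That helper's body is not shown on this page; a plausible sketch, assuming it sorts its first argument by sequence length and reorders each auxiliary list with the same permutation, is:

import numpy

def do_bucketing(pivot, aux_lists):
    # Hypothetical sketch, not the project's implementation: sort `pivot` by
    # sequence length and apply the same ordering to every list in `aux_lists`,
    # so that `tt_buf, ss_buf = do_bucketing(tt_buf, [ss_buf])` keeps the
    # source/target pairs aligned (the caller then takes ss_buf[0]).
    order = numpy.array([len(p) for p in pivot]).argsort()
    sorted_pivot = [pivot[i] for i in order]
    sorted_aux = [[aux[i] for i in order] for aux in aux_lists]
    return sorted_pivot, sorted_aux
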
Example #3
 def _make_inputs(self, features, labels):
     x, len_x = padding_batch_data(features, self._parent._vocab_source.eos_id)
     y, len_y = padding_batch_data(labels, self._parent._vocab_target.eos_id)
     return {
         self._parent.input_fields[GlobalNames.PH_FEATURE_IDS_NAME]: x,
         self._parent.input_fields[GlobalNames.PH_FEATURE_LENGTH_NAME]: len_x,
         self._parent.input_fields[GlobalNames.PH_LABEL_IDS_NAME]: y,
         self._parent.input_fields[GlobalNames.PH_LABEL_LENGTH_NAME]: len_y}
Example #4
 def _make_inputs(self, features, labels):
     x, len_x = padding_batch_data(features,
                                   self._parent._features_padding)
     y, len_y = padding_batch_data(labels, self._parent._labels_padding)
     return {
         self._parent.input_fields[Constants.FEATURE_IDS_NAME]: x,
         self._parent.input_fields[Constants.FEATURE_LENGTH_NAME]: len_x,
         self._parent.input_fields[Constants.LABEL_IDS_NAME]: y,
         self._parent.input_fields[Constants.LABEL_LENGTH_NAME]: len_y
     }
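
The dictionaries built by _make_inputs in Examples #3 and #4, like the "feed_dict" entries produced by the other readers, are meant to be handed directly to tf.Session.run. A hypothetical consumer, where the reader, session, op and batch lists are all assumed to be constructed elsewhere:

def run_eval_step(sess, reader, loss_op, batch_features, batch_labels):
    # Hypothetical helper, not part of the project: `batch_features` and
    # `batch_labels` are lists of token-id lists, and `loss_op` stands in for
    # whatever op the caller actually evaluates.
    feed_dict = reader._make_inputs(batch_features, batch_labels)
    return sess.run(loss_op, feed_dict=feed_dict)
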
Example #5
    def __init__(self, source, vocab_source, batch_size=1, n_words_src=-1):
        # read in batch data
        f_source = open_file(source)

        ss_buf = []
        ss_str_buf = []
        for ss in f_source:
            # ss_str_buf.append(ss.strip())
            ss_str_buf.append(vocab_source.bpe_encode(ss.strip()))
            ss = vocab_source.convert_to_idlist(ss.strip().split(),
                                                n_words_src)
            ss_buf.append(ss)
        f_source.close()

        self.batch_source_buffer = []
        self.batch_source_str_buffer = []

        self.batch_data_idx = 0
        self.batch_size = batch_size
        while self.batch_data_idx < len(ss_buf):
            self.batch_source_buffer.append(
                padding_batch_data(
                    ss_buf[self.batch_data_idx:self.batch_data_idx +
                           batch_size], vocab_source.eos_id))
            self.batch_source_str_buffer.append(
                ss_str_buf[self.batch_data_idx:self.batch_data_idx +
                           batch_size])
            self.batch_data_idx += batch_size
        self.reset()
Example #6
    def _make_feeding_data_from(self, filename):
        """ Processes the data file and return an iterable instance for loop.

        Args:
            filename: A specific data file.

        Returns: An iterable instance that packs the feeding dictionaries
                   for `tf.Session().run` according to `filename`.
        """
        features = open_file(filename, encoding="utf-8")
        str_buf = []
        ss_buf = []
        for ss in features:
            str_buf.append(self._vocab.bpe_encode(ss.strip()))
            ss_buf.append(self._vocab.convert_to_idlist(ss.strip().split(" ")))
        close_file(features)
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            x, len_x = padding_batch_data(
                ss_buf[batch_data_idx: batch_data_idx + self._batch_size],
                self._vocab.eos_id)
            str_x = str_buf[batch_data_idx: batch_data_idx + self._batch_size]
            batch_data_idx += self._batch_size
            data.append((
                str_x, len_x,
                {self.input_fields[GlobalNames.PH_FEATURE_IDS_NAME]: x,
                 self.input_fields[GlobalNames.PH_FEATURE_LENGTH_NAME]: len_x}))
        return data
Example #7
    def __init__(self, source,
                 vocab_source,
                 batch_size=1,
                 n_words_src=-1):
        # read in batch data
        f_source = open_file(source)

        ss_buf = []
        ss_str_buf = []
        for ss in f_source:
            # ss_str_buf.append(ss.strip())
            ss_str_buf.append(vocab_source.bpe_encode(ss.strip()))
            ss = vocab_source.convert_to_idlist(ss.strip().split(), n_words_src)
            ss_buf.append(ss)
        f_source.close()

        self.batch_source_buffer = []
        self.batch_source_str_buffer = []

        self.batch_data_idx = 0
        self.batch_size = batch_size
        while self.batch_data_idx < len(ss_buf):
            self.batch_source_buffer.append(
                padding_batch_data(ss_buf[self.batch_data_idx: self.batch_data_idx + batch_size], vocab_source.eos_id))
            self.batch_source_str_buffer.append(
                ss_str_buf[self.batch_data_idx: self.batch_data_idx + batch_size])
            self.batch_data_idx += batch_size
        self.reset()
Example #8
    def __init__(self,
                 source,
                 target,
                 vocab_source,
                 vocab_target,
                 batch_size=128,
                 n_words_src=-1,
                 n_words_trg=-1):
        # read in batch data
        f_source = open_file(source)
        if gfile.Exists(target):
            f_target = open_file(target)
        else:
            f_target = open_file(target + "0")

        ss_buf = []
        tt_buf = []
        for ss, tt in zip(f_source, f_target):
            ss = vocab_source.convert_to_idlist(ss.strip().split(),
                                                n_words_src)
            tt = vocab_target.convert_to_idlist(tt.strip().split(),
                                                n_words_trg)
            ss_buf.append(ss)
            tt_buf.append(tt)
        f_source.close()
        f_target.close()
        tlen = numpy.array([len(t) for t in tt_buf])
        tidx = tlen.argsort()
        _ss_buf = [ss_buf[i] for i in tidx]
        _tt_buf = [tt_buf[i] for i in tidx]
        ss_buf = _ss_buf
        tt_buf = _tt_buf
        self.batch_source_buffer = []
        self.batch_target_buffer = []
        self.batch_data_idx = 0
        self.batch_size = batch_size
        while self.batch_data_idx < len(ss_buf):
            self.batch_source_buffer.append(
                padding_batch_data(
                    ss_buf[self.batch_data_idx:self.batch_data_idx +
                           batch_size], vocab_source.eos_id))
            self.batch_target_buffer.append(
                padding_batch_data(
                    tt_buf[self.batch_data_idx:self.batch_data_idx +
                           batch_size], vocab_target.eos_id))
            self.batch_data_idx += batch_size
        self.reset()
Example #9
 def _feed_batchs(_start_idx, _inpf):
     if _start_idx * n_samples_per_gpu >= n_samples:
         return 0
     x, x_len = padding_batch_data(
         d[_start_idx * n_samples_per_gpu:(_start_idx + 1) * n_samples_per_gpu], p)
     data["feed_dict"][_inpf[concat_name(n, Constants.IDS_NAME)]] = x
     data["feed_dict"][_inpf[concat_name(n, Constants.LENGTH_NAME)]] = x_len
     return len(x_len)
Example #10
def evaluate_sentences(sources,
                       targets,
                       sess,
                       input_fields,
                       eval_op,
                       vocab_source,
                       vocab_target,
                       n_words_src=-1,
                       n_words_trg=-1):
    """ Evaluates a list of sentences.

    Args:
        sources: A list of strings.
        targets: A list of strings.
        sess: `tf.Session`.
        input_fields: The dictionary of placeholders.
        eval_op: A TensorFlow operation.
        vocab_source: A `Vocab` instance for the source side feature map.
        vocab_target: A `Vocab` instance for the target side feature map.
        n_words_src: An integer number. If provided and > 0, source side
          token ids that exceed this value will be mapped to the UNK id.
        n_words_trg: An integer number. If provided and > 0, target side
          token ids that exceed this value will be mapped to the UNK id.

    Returns: Results of `eval_op`.
    """
    sources = [
        vocab_source.convert_to_idlist(re.split(r"\s+", snt.strip()),
                                       n_words_src) for snt in sources
    ]
    targets = [
        vocab_target.convert_to_idlist(re.split(r"\s+", snt.strip()),
                                       n_words_trg) for snt in targets
    ]
    ph_x = input_fields[Constants.FEATURE_IDS_NAME]
    ph_x_len = input_fields[Constants.FEATURE_LENGTH_NAME]
    ph_y = input_fields[Constants.LABEL_IDS_NAME]
    ph_y_len = input_fields[Constants.LABEL_LENGTH_NAME]
    x, len_x = padding_batch_data(sources, vocab_source.pad_id)
    y, len_y = padding_batch_data(targets, vocab_target.pad_id)
    feed_dict = {ph_x: x, ph_x_len: len_x, ph_y: y, ph_y_len: len_y}
    return _evaluate(sess, feed_dict, eval_op)
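
A hypothetical call site for evaluate_sentences; the session, placeholder dictionary, evaluation op and the two Vocab instances are assumed to have been built elsewhere, and the sentence pair is dummy data:

def score_pair(sess, input_fields, eval_op, vocab_source, vocab_target):
    # Hypothetical wrapper, not part of the project: score one dummy
    # source/target pair with the evaluation op.
    return evaluate_sentences(sources=["a small test ."],
                              targets=["ein kleiner Test ."],
                              sess=sess,
                              input_fields=input_fields,
                              eval_op=eval_op,
                              vocab_source=vocab_source,
                              vocab_target=vocab_target)
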
Example #11
    def __init__(self, source, target,
                 vocab_source, vocab_target,
                 batch_size=128,
                 n_words_src=-1,
                 n_words_trg=-1):
        # read in batch data
        f_source = open_file(source)
        if gfile.Exists(target):
            f_target = open_file(target)
        else:
            f_target = open_file(target + "0")

        ss_buf = []
        tt_buf = []
        for ss, tt in zip(f_source, f_target):
            ss = vocab_source.convert_to_idlist(ss.strip().split(), n_words_src)
            tt = vocab_target.convert_to_idlist(tt.strip().split(), n_words_trg)
            ss_buf.append(ss)
            tt_buf.append(tt)
        f_source.close()
        f_target.close()
        tlen = numpy.array([len(t) for t in tt_buf])
        tidx = tlen.argsort()
        _ss_buf = [ss_buf[i] for i in tidx]
        _tt_buf = [tt_buf[i] for i in tidx]
        ss_buf = _ss_buf
        tt_buf = _tt_buf
        self.batch_source_buffer = []
        self.batch_target_buffer = []
        self.batch_data_idx = 0
        self.batch_size = batch_size
        while self.batch_data_idx < len(ss_buf):
            self.batch_source_buffer.append(
                padding_batch_data(ss_buf[self.batch_data_idx: self.batch_data_idx + batch_size], vocab_source.eos_id))
            self.batch_target_buffer.append(
                padding_batch_data(tt_buf[self.batch_data_idx: self.batch_data_idx + batch_size], vocab_target.eos_id))
            self.batch_data_idx += batch_size
        self.reset()
Example #12
    def _make_feeding_data_from(self,
                                filename,
                                maximum_line_length=None,
                                maximum_encoded_length=None):
        """ Processes the data file and return an iterable instance for loop.

        Args:
            filename: A specific data file.
            maximum_line_length: The maximum sequence length. If provided,
              sentences exceeding this value will be ignored.
            maximum_encoded_length: The maximum length of symbols (especially
              after BPE is applied). If provided, sentences whose number of symbols
              exceeds this value will be ignored.

        Returns: An iterable instance that packs the feeding dictionaries
                   for `tf.Session().run` according to `filename`.
        """
        features = open_file(filename, encoding="utf-8")
        str_buf = []
        ss_buf = []
        for ss in features:
            if maximum_line_length and len(
                    ss.strip().split()) > maximum_line_length:
                continue
            encoded_ss = self._vocab.convert_to_idlist(ss.strip().split())
            if maximum_encoded_length and len(
                    encoded_ss) - 1 > maximum_encoded_length:
                continue
            bpe_ss = self._vocab.bpe_encode(ss.strip())
            str_buf.append(bpe_ss)
            ss_buf.append(encoded_ss)
        close_file(features)
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            x, len_x = padding_batch_data(
                ss_buf[batch_data_idx:batch_data_idx + self._batch_size],
                self._vocab.eos_id)
            str_x = str_buf[batch_data_idx:batch_data_idx + self._batch_size]
            batch_data_idx += self._batch_size
            data.append((str_x, len_x, {
                self.input_fields[GlobalNames.PH_FEATURE_IDS_NAME]: x,
                self.input_fields[GlobalNames.PH_FEATURE_LENGTH_NAME]: len_x
            }))
        return data
Example #13
    def _make_feeding_data_from(self, filename, maximum_length=None):
        """ Processes the data file and return an iterable instance for loop.

        Args:
            filename: A specific data file.
            maximum_length: The maximum length of symbols (especially
              after BPE is applied). If provided, sentences whose number of symbols
              exceeds this value will be ignored.

        Returns: An iterable instance that packs the feeding dictionaries
                   for `tf.Session().run` according to `filename`.
        """
        features = open_file(filename, encoding="utf-8")
        ss_buf = []
        encoded_ss = read_line_with_filter(features, maximum_length,
                                           self._preprocessing_fn)
        while encoded_ss != "":
            ss_buf.append(encoded_ss)
            encoded_ss = read_line_with_filter(features, maximum_length,
                                               self._preprocessing_fn)
        close_file(features)
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            x, len_x = padding_batch_data(
                ss_buf[batch_data_idx:batch_data_idx + self._batch_size],
                self._padding)
            batch_data_idx += self._batch_size
            if "features" in self._data_field_name:
                data.append({
                    "feature_ids": x,
                    "feed_dict": {
                        self.input_fields[Constants.FEATURE_IDS_NAME]: x,
                        self.input_fields[Constants.FEATURE_LENGTH_NAME]: len_x
                    }
                })
            else:
                data.append({
                    "label_ids": x,
                    "feed_dict": {
                        self.input_fields[Constants.LABEL_IDS_NAME]: x,
                        self.input_fields[Constants.LABEL_LENGTH_NAME]: len_x
                    }
                })
        return data
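
Examples #2 and #13 drive their read loops with read_line_with_filter, whose body is not shown on this page. A sketch that matches that usage, assuming it returns "" at end of file and silently skips sentences longer than the given limit:

def read_line_with_filter(fp, maximum_length, preprocessing_fn):
    # Hypothetical sketch, not the project's code: read the next acceptable
    # line from `fp`, returning "" at end of file.  Each line is preprocessed
    # (e.g. converted to an id list) and skipped when it exceeds
    # `maximum_length`, matching the loops in Examples #2 and #13 above.
    while True:
        line = fp.readline()
        if line == "":
            return ""
        encoded = preprocessing_fn(line.strip())
        if maximum_length and len(encoded) > maximum_length:
            continue
        return encoded
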
Example #14
File: decode.py, Project: njufyz/NJUNMT-tf
def infer_sentences(sources,
                    sess,
                    input_fields,
                    prediction_op,
                    vocab_source,
                    alpha=None,
                    top_k=1,
                    n_words_src=-1):
    """ Infers a list of sentences.

    Args:
        sources: A list of strings.
        sess: `tf.Session`.
        input_fields: The dictionary of placeholders.
        prediction_op: A TensorFlow operation for inference.
        vocab_source: A `Vocab` instance for the source side feature map.
        alpha: A scalar number, the length penalty rate. If not provided
          or < 0, each beam is simply averaged by the length of the
          predicted sequence.
        top_k: An integer, the number of predicted sequences to
          return.
        n_words_src: An integer number. If provided and > 0, source side
          token ids that exceed this value will be mapped to the UNK id.

    Returns: A tuple `(predicted_sequences, attention_scores)`.
      `predicted_sequences` is an ndarray of shape
      [`top_k`, max_sequence_length].
      `attention_scores` is None if there is no attention-related
      information in `prediction_op`.
    """
    sources = [
        vocab_source.convert_to_idlist(re.split(r"\s+", snt.strip()),
                                       n_words_src) for snt in sources
    ]
    ph_x = input_fields[GlobalNames.PH_FEATURE_IDS_NAME]
    ph_x_len = input_fields[GlobalNames.PH_FEATURE_LENGTH_NAME]
    x, len_x = padding_batch_data(sources, vocab_source.eos_id)
    feed_dict = {ph_x: x, ph_x_len: len_x}
    return _infer(sess, feed_dict, prediction_op, len(sources), alpha, top_k)
Example #15
    def next(self):
        if self.end_of_data:
            self.end_of_data = False
            self.reset()
            raise StopIteration

        source = []
        target = []

        assert len(self.source_buffer) == len(self.target_buffer), 'Buffer size mismatch'

        if len(self.source_buffer) == 0:
            cnt = 0
            while cnt < self.k:
                ss = self.source.readline()
                if ss == "":
                    break
                tt = self.target.readline()
                if tt == "":
                    break

                ss = ss.strip().split()
                tt = tt.strip().split()
                if len(ss) > self.maxlen_src or len(tt) > self.maxlen_trg:
                    continue

                cnt += 1
                self.source_buffer.append(ss)
                self.target_buffer.append(tt)

            # sort by target buffer
            tlen = numpy.array([len(t) for t in self.target_buffer])
            tidx = tlen.argsort()

            _sbuf = [self.source_buffer[i] for i in tidx]
            _tbuf = [self.target_buffer[i] for i in tidx]

            self.source_buffer = _sbuf
            self.target_buffer = _tbuf

            if len(self.source_buffer) == 0 or len(self.target_buffer) == 0:
                self.end_of_data = False
                self.reset()
                raise StopIteration

        try:
            while True:
                # read source
                try:
                    ss = self.source_buffer.pop(0)
                except IndexError:
                    break
                ss = self.vocab_source.convert_to_idlist(ss, self.n_words_src)

                # read target
                tt = self.target_buffer.pop(0)
                tt = self.vocab_target.convert_to_idlist(tt, self.n_words_trg)

                source.append(ss)
                target.append(tt)

                if len(source) >= self.batch_size or \
                                len(target) >= self.batch_size:
                    break
        except IOError:
            self.end_of_data = True

        if len(source) <= 0 or len(target) <= 0:
            self.end_of_data = False
            self.reset()
            raise StopIteration

        return padding_batch_data(source, self.vocab_source.eos_id), \
               padding_batch_data(target, self.vocab_target.eos_id)
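
Because next() in Example #15 raises StopIteration at the end of an epoch and returns ((x, len_x), (y, len_y)) batch tuples, the reader is typically driven by a plain for loop. A hypothetical training loop, assuming the class also defines __iter__ (not shown on this page) and that the placeholders and training op already exist:

def train_one_epoch(sess, train_reader, train_op, ph_x, ph_x_len, ph_y, ph_y_len):
    # Hypothetical loop; relies on the reader exposing __iter__ so that the
    # next() method above can drive ordinary iteration.
    for (x, len_x), (y, len_y) in train_reader:
        sess.run(train_op, feed_dict={ph_x: x, ph_x_len: len_x,
                                      ph_y: y, ph_y_len: len_y})
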
Example #16
    def _SmallParallelData(self,
                           features_file,
                           labels_file,
                           maximum_features_length=None,
                           maximum_labels_length=None,
                           maximum_encoded_features_length=None,
                           maximum_encoded_labels_length=None):
        """ Function for reading small scale parallel data.

        Args:
            features_file: The path of features file.
            labels_file: The path of labels file.
            maximum_features_length: The maximum sequence length of the "features" field.
              If provided, sentences exceeding this value will be ignored.
            maximum_labels_length: The maximum sequence length of the "labels" field.
              If provided, sentences exceeding this value will be ignored.
            maximum_encoded_features_length: The maximum length of feature symbols (especially
              after BPE is applied). If provided, sentences whose number of symbols
              exceeds this value will be ignored.
            maximum_encoded_labels_length: The maximum length of label symbols (especially
              after BPE is applied). If provided, sentences whose number of symbols
              exceeds this value will be ignored.

        Returns: A list of feeding data.
        """
        eval_features = open_file(features_file, encoding="utf-8")
        if gfile.Exists(labels_file):
            eval_labels = open_file(labels_file, encoding="utf-8")
        else:
            eval_labels = open_file(labels_file + "0", encoding="utf-8")
        ss_buf = []
        tt_buf = []
        for ss, tt in zip(eval_features, eval_labels):
            if maximum_features_length and len(ss.strip().split()) > maximum_features_length:
                continue
            if maximum_labels_length and len(tt.strip().split()) > maximum_labels_length:
                continue
            encoded_ss = self._vocab_source.convert_to_idlist(ss.strip().split(" "))
            if maximum_encoded_features_length and len(encoded_ss) - 1 > maximum_encoded_features_length:
                continue
            encoded_tt = self._vocab_target.convert_to_idlist(tt.strip().split(" "))
            if maximum_encoded_labels_length and len(encoded_tt) - 1 > maximum_encoded_labels_length:
                continue
            ss_buf.append(encoded_ss)
            tt_buf.append(encoded_tt)
        close_file(eval_features)
        close_file(eval_labels)
        if self._bucketing:
            tlen = numpy.array([len(t) for t in tt_buf])
            tidx = tlen.argsort()
            _ss_buf = [ss_buf[i] for i in tidx]
            _tt_buf = [tt_buf[i] for i in tidx]
            ss_buf = _ss_buf
            tt_buf = _tt_buf
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            x, len_x = padding_batch_data(
                ss_buf[batch_data_idx: batch_data_idx + self._batch_size],
                self._vocab_source.eos_id)
            y, len_y = padding_batch_data(
                tt_buf[batch_data_idx: batch_data_idx + self._batch_size],
                self._vocab_target.eos_id)
            batch_data_idx += self._batch_size
            data.append((len(len_x), {
                self.input_fields[GlobalNames.PH_FEATURE_IDS_NAME]: x,
                self.input_fields[GlobalNames.PH_FEATURE_LENGTH_NAME]: len_x,
                self.input_fields[GlobalNames.PH_LABEL_IDS_NAME]: y,
                self.input_fields[GlobalNames.PH_LABEL_LENGTH_NAME]: len_y}))
        return data
Example #17
 def map_fn(n, d, p):
     x, x_len = padding_batch_data(d, p)
     data[concat_name(n, Constants.IDS_NAME)] = d
     data["feed_dict"][input_fields[concat_name(n, Constants.IDS_NAME)]] = x
     data["feed_dict"][input_fields[concat_name(
         n, Constants.LENGTH_NAME)]] = x_len
Example #18
    def next(self):
        if self.end_of_data:
            self.end_of_data = False
            self.reset()
            raise StopIteration

        source = []
        target = []

        assert len(self.source_buffer) == len(
            self.target_buffer), 'Buffer size mismatch'

        if len(self.source_buffer) == 0:
            cnt = 0
            while cnt < self.k:
                ss = self.source.readline()
                if ss == "":
                    break
                tt = self.target.readline()
                if tt == "":
                    break

                ss = ss.strip().split()
                tt = tt.strip().split()
                if len(ss) > self.maxlen_src or len(tt) > self.maxlen_trg:
                    continue

                cnt += 1
                self.source_buffer.append(ss)
                self.target_buffer.append(tt)

            # sort by target buffer
            tlen = numpy.array([len(t) for t in self.target_buffer])
            tidx = tlen.argsort()

            _sbuf = [self.source_buffer[i] for i in tidx]
            _tbuf = [self.target_buffer[i] for i in tidx]

            self.source_buffer = _sbuf
            self.target_buffer = _tbuf

            if len(self.source_buffer) == 0 or len(self.target_buffer) == 0:
                self.end_of_data = False
                self.reset()
                raise StopIteration

        try:
            while True:
                # read source
                try:
                    ss = self.source_buffer.pop(0)
                except IndexError:
                    break
                ss = self.vocab_source.convert_to_idlist(ss, self.n_words_src)

                # read target
                tt = self.target_buffer.pop(0)
                tt = self.vocab_target.convert_to_idlist(tt, self.n_words_trg)

                source.append(ss)
                target.append(tt)

                if len(source) >= self.batch_size or \
                                len(target) >= self.batch_size:
                    break
        except IOError:
            self.end_of_data = True

        if len(source) <= 0 or len(target) <= 0:
            self.end_of_data = False
            self.reset()
            raise StopIteration

        return padding_batch_data(source, self.vocab_source.eos_id), \
               padding_batch_data(target, self.vocab_target.eos_id)