Example #1
        def _shuffle_and_reopen(self):
            """ shuffle features & labels file. """
            if self._parent._shuffle_every_epoch:
                if not hasattr(self, "_shuffled_features_file"):
                    self._shuffled_features_file = self._features_file.strip().split("/")[-1] \
                                                   + "." + self._parent._shuffle_every_epoch
                    self._shuffled_labels_file = self._labels_file.strip().split("/")[-1] \
                                                 + "." + self._parent._shuffle_every_epoch

                tf.logging.info(
                    "shuffling data\n\t{} ==> {}\n\t{} ==> {}".format(
                        self._features_file, self._shuffled_features_file,
                        self._labels_file, self._shuffled_labels_file))
                shuffle_data(
                    [self._features_file, self._labels_file],
                    [self._shuffled_features_file, self._shuffled_labels_file])
                self._features_file = self._shuffled_features_file
                self._labels_file = self._shuffled_labels_file
                if hasattr(self, "_features"):
                    close_file(self._features)
                    close_file(self._labels)
            elif hasattr(self, "_features"):
                self._features.seek(0)
                self._labels.seek(0)
                return self._features, self._labels
            return open_file(self._features_file), open_file(self._labels_file)
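The `shuffle_data` helper called above is not part of this snippet. A minimal sketch of what such a helper might do, assuming it shuffles the parallel files with one shared permutation so that corresponding lines stay aligned (the implementation details are an assumption, not the project's actual code):

    import random

    def shuffle_data(input_files, output_files):
        # Hypothetical sketch: shuffle parallel files with one shared permutation
        # so that line i of every output file still refers to the same sentence pair.
        corpora = []
        for name in input_files:
            with open(name, encoding="utf-8") as fp:
                corpora.append(fp.readlines())
        assert len(set(len(c) for c in corpora)) == 1, "parallel files must have the same number of lines"
        order = list(range(len(corpora[0])))
        random.shuffle(order)
        for lines, out_name in zip(corpora, output_files):
            with open(out_name, "w", encoding="utf-8") as fw:
                for idx in order:
                    fw.write(lines[idx])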
Example #2
    def _make_feeding_data_from(self, filename):
        """ Processes the data file and return an iterable instance for loop.

        Args:
            filename: A specific data file.

        Returns: An iterable instance that packs feeding dictionary
                   for `tf.Session().run` according to the `filename`.
        """
        features = open_file(filename, encoding="utf-8")
        str_buf = []
        ss_buf = []
        for ss in features:
            str_buf.append(self._vocab.bpe_encode(ss.strip()))
            ss_buf.append(self._vocab.convert_to_idlist(ss.strip().split(" ")))
        close_file(features)
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            x, len_x = padding_batch_data(
                ss_buf[batch_data_idx: batch_data_idx + self._batch_size],
                self._vocab.eos_id)
            str_x = str_buf[batch_data_idx: batch_data_idx + self._batch_size]
            batch_data_idx += self._batch_size
            data.append((
                str_x, len_x,
                {self.input_fields[GlobalNames.PH_FEATURE_IDS_NAME]: x,
                 self.input_fields[GlobalNames.PH_FEATURE_LENGTH_NAME]: len_x}))
        return data
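The `padding_batch_data` helper is also not shown. A minimal sketch, assuming it pads every sequence in the batch to the longest one with a padding id and returns the original lengths as well (which matches how `len_x` is used in these examples):

    import numpy

    def padding_batch_data(seqs, padding_id):
        # Hypothetical sketch: pad a batch of id lists with `padding_id` and
        # return the padded 2-D array together with the original lengths.
        lengths = [len(s) for s in seqs]
        padded = numpy.full((len(seqs), max(lengths)), padding_id, dtype=numpy.int32)
        for i, s in enumerate(seqs):
            padded[i, :len(s)] = s
        return padded, lengths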
Example #3
    def _SmallParallelData(self,
                           features_file,
                           labels_file,
                           input_fields,
                           maximum_features_length=None,
                           maximum_labels_length=None):
        """ Function for reading small scale parallel data for evaluation.

        Args:
            features_file: The path of features file.
            labels_file: The path of labels file.
            input_fields: A dict of placeholders.
            maximum_features_length: The maximum length of feature symbols (especially
              after BPE is applied). If provided, sentences whose number of symbols
              exceeds this value will be ignored.
            maximum_labels_length: The maximum length of label symbols (especially
              after BPE is applied). If provided, sentences whose number of symbols
              exceeds this value will be ignored.

        Returns: A list of feeding data.
        """
        features = open_file(features_file, encoding="utf-8")
        labels = open_file(labels_file[0], encoding="utf-8")

        ss_buf = []
        tt_buf = []
        while True:
            ss = read_line_with_filter(features, maximum_features_length,
                                       self._features_preprocessing_fn)
            tt = read_line_with_filter(labels, maximum_labels_length,
                                       self._labels_preprocessing_fn)
            if ss == "" or tt == "":
                break
            ss_buf.append(ss)
            tt_buf.append(tt)
        close_file(features)
        close_file(labels)
        if self._bucketing:
            tt_buf, ss_buf = do_bucketing(tt_buf, [ss_buf])
            ss_buf = ss_buf[0]
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            data.append(
                pack_feed_dict(
                    name_prefixs=[
                        Constants.FEATURE_NAME_PREFIX,
                        Constants.LABEL_NAME_PREFIX
                    ],
                    origin_datas=[
                        ss_buf[batch_data_idx:batch_data_idx +
                               self._batch_size],
                        tt_buf[batch_data_idx:batch_data_idx +
                               self._batch_size]
                    ],
                    paddings=[self._features_padding, self._labels_padding],
                    input_fields=input_fields))
            batch_data_idx += self._batch_size
        return data
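The call `tt_buf, ss_buf = do_bucketing(tt_buf, [ss_buf])` suggests that the helper sorts the pivot list by sequence length and reorders the other lists with the same permutation, mirroring the inline argsort used in examples #5 and #13. A hedged sketch under that assumption:

    def do_bucketing(pivot, others):
        # Hypothetical sketch: sort `pivot` by sequence length so that batches
        # contain similarly sized sentences, and reorder every list in `others`
        # with the same index order to keep the data parallel.
        order = sorted(range(len(pivot)), key=lambda i: len(pivot[i]))
        sorted_pivot = [pivot[i] for i in order]
        sorted_others = [[lst[i] for i in order] for lst in others]
        return sorted_pivot, sorted_others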
Example #4
    def _SmallParallelData(self,
                           features_file,
                           labels_file,
                           maximum_features_length=None,
                           maximum_labels_length=None):
        """ Function for reading small scale parallel data for evaluation.

        Args:
            features_file: The path of features file.
            labels_file: The path of labels file.
            maximum_features_length: The maximum length of feature symbols (especially
              after BPE is applied). If provided, sentences whose number of symbols
              exceeds this value will be ignored.
            maximum_labels_length: The maximum length of label symbols (especially
              after BPE is applied). If provided, sentences whose number of symbols
              exceeds this value will be ignored.

        Returns: A list of feeding data.
        """
        features = open_file(features_file, encoding="utf-8")
        labels = open_file(labels_file[0], encoding="utf-8")

        ss_buf = []
        tt_buf = []
        while True:
            ss = read_line_with_filter(features, maximum_features_length,
                                       self._features_preprocessing_fn)
            tt = read_line_with_filter(labels, maximum_labels_length,
                                       self._labels_preprocessing_fn)
            if ss == "" or tt == "":
                break
            ss_buf.append(ss)
            tt_buf.append(tt)
        close_file(features)
        close_file(labels)
        if self._bucketing:
            tt_buf, ss_buf = do_bucketing(tt_buf, [ss_buf])
            ss_buf = ss_buf[0]
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            x, len_x = padding_batch_data(
                ss_buf[batch_data_idx:batch_data_idx + self._batch_size],
                self._features_padding)
            y, len_y = padding_batch_data(
                tt_buf[batch_data_idx:batch_data_idx + self._batch_size],
                self._labels_padding)
            data.append({
                "feature_ids": x,
                "label_ids": y,
                "feed_dict": {
                    self.input_fields[Constants.FEATURE_IDS_NAME]: x,
                    self.input_fields[Constants.FEATURE_LENGTH_NAME]: len_x,
                    self.input_fields[Constants.LABEL_IDS_NAME]: y,
                    self.input_fields[Constants.LABEL_LENGTH_NAME]: len_y
                }
            })
            batch_data_idx += self._batch_size
        return data
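A hedged usage sketch of the batches returned by this variant (the names `reader`, `sess` and `loss_op` are assumptions, not part of the project): each batch already carries a ready-made `feed_dict`.

    # Hypothetical evaluation loop over the packed batches.
    for batch in reader._SmallParallelData("dev.src", ["dev.trg"]):
        loss = sess.run(loss_op, feed_dict=batch["feed_dict"])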
Example #5
    def _EvalParallelData(self, features_file, labels_file):
        """ Function for reading small scale parallel data for evaluation.

        Args:
            features_file: The path of features file.
            labels_file: The path of labels file.

        Returns: A list of feeding data.
        """
        eval_features = open_file(features_file, encoding="utf-8")
        if gfile.Exists(labels_file):
            eval_labels = open_file(labels_file, encoding="utf-8")
        else:
            eval_labels = open_file(labels_file + "0", encoding="utf-8")
        ss_buf = []
        tt_buf = []
        ss_str_buf = []
        tt_str_buf = []
        for ss, tt in zip(eval_features, eval_labels):
            ss_str = self._vocab_source.bpe_encode(ss.strip()).split()
            tt_str = self._vocab_target.bpe_encode(tt.strip()).split()
            ss_str_buf.append(ss_str)
            tt_str_buf.append(tt_str)
            ss_buf.append(self._vocab_source.convert_to_idlist(ss.strip()))
            tt_buf.append(self._vocab_target.convert_to_idlist(tt.strip()))
        close_file(eval_features)
        close_file(eval_labels)
        if self._bucketing:
            tlen = numpy.array([len(t) for t in tt_buf])
            tidx = tlen.argsort()
            _ss_buf = [ss_buf[i] for i in tidx]
            _tt_buf = [tt_buf[i] for i in tidx]
            _ss_str_buf = [ss_str_buf[i] for i in tidx]
            _tt_str_buf = [tt_str_buf[i] for i in tidx]
            ss_buf = _ss_buf
            tt_buf = _tt_buf
            ss_str_buf = _ss_str_buf
            tt_str_buf = _tt_str_buf
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            x, len_x = padding_batch_data(
                ss_buf[batch_data_idx:batch_data_idx + self._batch_size],
                self._vocab_source.eos_id)
            y, len_y = padding_batch_data(
                tt_buf[batch_data_idx:batch_data_idx + self._batch_size],
                self._vocab_target.eos_id)
            data.append(
                (ss_str_buf[batch_data_idx:batch_data_idx + self._batch_size],
                 tt_str_buf[batch_data_idx:batch_data_idx + self._batch_size],
                 {
                     self.input_fields[GlobalNames.PH_FEATURE_IDS_NAME]: x,
                     self.input_fields[GlobalNames.PH_FEATURE_LENGTH_NAME]:
                     len_x,
                     self.input_fields[GlobalNames.PH_LABEL_IDS_NAME]: y,
                     self.input_fields[GlobalNames.PH_LABEL_LENGTH_NAME]: len_y
                 }))
            batch_data_idx += self._batch_size
        return data
Example #6
 def _reset(self):
     if self._parent._shuffle_every_epoch:
         close_file(self._features)
         close_file(self._labels)
         self._shuffle()
         self._features = open_file(self._features_file, encoding="utf-8")
         self._labels = open_file(self._labels_file, encoding="utf-8")
     self._features.seek(0)
     self._labels.seek(0)
Example #7
 def reset(self):
     if self.shuffle_every_epoch:
         close_file(self.source)
         close_file(self.target)
         tf.logging.info("shuffling data among epochs")
         shuffle_data([self.source_file, self.target_file],
                      ["./source.shuf." + self.shuffle_every_epoch,
                       "./target.shuf." + self.shuffle_every_epoch])
         self.source = open_file("./source.shuf." + self.shuffle_every_epoch)
         self.target = open_file("./target.shuf." + self.shuffle_every_epoch)
     else:
         self.source.seek(0)
         self.target.seek(0)
Example #8
 def reset(self):
     if self.shuffle_every_epoch:
         close_file(self.source)
         close_file(self.target)
         tf.logging.info("shuffling data among epochs")
         shuffle_data([self.source_file, self.target_file], [
             "./source.shuf." + self.shuffle_every_epoch,
             "./target.shuf." + self.shuffle_every_epoch
         ])
         self.source = open_file("./source.shuf." +
                                 self.shuffle_every_epoch)
         self.target = open_file("./target.shuf." +
                                 self.shuffle_every_epoch)
     else:
         self.source.seek(0)
         self.target.seek(0)
Example #9
    def _make_feeding_data_from(self,
                                filename,
                                maximum_line_length=None,
                                maximum_encoded_length=None):
        """ Processes the data file and return an iterable instance for loop.

        Args:
            filename: A specific data file.
            maximum_line_length: The maximum sequence length. If provided,
              sentences exceeding this value will be ignore.
            maximum_encoded_length: The maximum length of symbols (especially
              after BPE is applied). If provided symbols of one sentence exceeding
              this value will be ignore.

        Returns: An iterable instance that packs feeding dictionary
                   for `tf.Session().run` according to the `filename`.
        """
        features = open_file(filename, encoding="utf-8")
        str_buf = []
        ss_buf = []
        for ss in features:
            if maximum_line_length and len(
                    ss.strip().split()) > maximum_line_length:
                continue
            encoded_ss = self._vocab.convert_to_idlist(ss.strip().split())
            if maximum_encoded_length and len(
                    encoded_ss) - 1 > maximum_encoded_length:
                continue
            bpe_ss = self._vocab.bpe_encode(ss.strip())
            str_buf.append(bpe_ss)
            ss_buf.append(encoded_ss)
        close_file(features)
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            x, len_x = padding_batch_data(
                ss_buf[batch_data_idx:batch_data_idx + self._batch_size],
                self._vocab.eos_id)
            str_x = str_buf[batch_data_idx:batch_data_idx + self._batch_size]
            batch_data_idx += self._batch_size
            data.append((str_x, len_x, {
                self.input_fields[GlobalNames.PH_FEATURE_IDS_NAME]:
                x,
                self.input_fields[GlobalNames.PH_FEATURE_LENGTH_NAME]:
                len_x
            }))
        return data
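The snippet compares `len(encoded_ss) - 1` against the maximum, which suggests that `convert_to_idlist` appends an end-of-sentence id. A hypothetical, minimal stand-in for the vocabulary object under that assumption:

    class ToyVocab(object):
        # Hypothetical, minimal stand-in for the vocabulary used above: it maps
        # tokens to integer ids and appends an end-of-sentence id, which would
        # explain the `len(encoded_ss) - 1` comparison in the snippet.
        def __init__(self, tokens):
            self._word2id = {w: i for i, w in enumerate(tokens)}
            self.unk_id = len(tokens)
            self.eos_id = len(tokens) + 1

        def convert_to_idlist(self, tokens):
            return [self._word2id.get(t, self.unk_id) for t in tokens] + [self.eos_id]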
Example #10
    def _make_feeding_data_from(self, filename, maximum_length=None):
        """ Processes the data file and return an iterable instance for loop.

        Args:
            filename: A specific data file.
            maximum_length: The maximum length of symbols (especially
              after BPE is applied). If provided symbols of one sentence exceeding
              this value will be ignore.

        Returns: An iterable instance that packs feeding dictionary
                   for `tf.Session().run` according to the `filename`.
        """
        features = open_file(filename, encoding="utf-8")
        ss_buf = []
        encoded_ss = read_line_with_filter(features, maximum_length,
                                           self._preprocessing_fn)
        while encoded_ss != "":
            ss_buf.append(encoded_ss)
            encoded_ss = read_line_with_filter(features, maximum_length,
                                               self._preprocessing_fn)
        close_file(features)
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            x, len_x = padding_batch_data(
                ss_buf[batch_data_idx:batch_data_idx + self._batch_size],
                self._padding)
            batch_data_idx += self._batch_size
            if "features" in self._data_field_name:
                data.append({
                    "feature_ids": x,
                    "feed_dict": {
                        self.input_fields[Constants.FEATURE_IDS_NAME]: x,
                        self.input_fields[Constants.FEATURE_LENGTH_NAME]: len_x
                    }
                })
            else:
                data.append({
                    "label_ids": x,
                    "feed_dict": {
                        self.input_fields[Constants.LABEL_IDS_NAME]: x,
                        self.input_fields[Constants.LABEL_LENGTH_NAME]: len_x
                    }
                })
        return data
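`read_line_with_filter` is used here as the reading primitive but is not shown. One plausible sketch, assuming it returns the next preprocessed line, skips lines whose encoded length exceeds the maximum, and returns an empty string at end of file (these semantics are an assumption inferred from the loop above):

    def read_line_with_filter(fp, maximum_length, preprocessing_fn):
        # Hypothetical sketch: return the next acceptable line from `fp` after
        # applying `preprocessing_fn`; skip lines whose encoded length exceeds
        # `maximum_length`; return "" once the end of file is reached.
        for line in fp:
            encoded = preprocessing_fn(line.strip()) if preprocessing_fn else line.strip()
            if maximum_length and len(encoded) > maximum_length:
                continue
            return encoded
        return ""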
Example #11
    def _make_feeding_data_from(self,
                                filename,
                                input_fields,
                                maximum_length=None):
        """ Processes the data file and return an iterable instance for loop.

        Args:
            filename: A specific data file.
            input_fields: A dict of placeholders.
            maximum_length: The maximum length of symbols (especially
              after BPE is applied). If provided symbols of one sentence exceeding
              this value will be ignore.

        Returns: An iterable instance that packs feeding dictionary
                   for `tf.Session().run` according to the `filename`.
        """
        features = open_file(filename, encoding="utf-8")
        ss_buf = []
        encoded_ss = read_line_with_filter(features, maximum_length,
                                           self._preprocessing_fn)
        while encoded_ss != "":
            ss_buf.append(encoded_ss)
            encoded_ss = read_line_with_filter(features, maximum_length,
                                               self._preprocessing_fn)
        close_file(features)
        data = []
        batch_data_idx = 0
        name_prefix = Constants.FEATURE_NAME_PREFIX \
            if "features" in self._data_field_name else Constants.LABEL_NAME_PREFIX

        while batch_data_idx < len(ss_buf):
            data.append(
                pack_feed_dict(
                    name_prefixs=name_prefix,
                    origin_datas=ss_buf[batch_data_idx:batch_data_idx +
                                        self._batch_size],
                    paddings=self._padding,
                    input_fields=input_fields))
            batch_data_idx += self._batch_size
        return data
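`pack_feed_dict` appears to bundle the padding and feed-dict construction that examples #4 and #10 do by hand. A heavily hedged sketch: the "<prefix>_ids" / "<prefix>_length" key scheme and the returned dictionary layout are assumptions, not the project's API.

    def pack_feed_dict(name_prefixs, origin_datas, paddings, input_fields):
        # Hypothetical sketch: pad each slice of data and map it onto the
        # matching placeholders; accepts a single prefix or a list of prefixes.
        if not isinstance(name_prefixs, (list, tuple)):
            name_prefixs = [name_prefixs]
            origin_datas = [origin_datas]
            paddings = [paddings]
        ret = {"feed_dict": {}}
        for prefix, data, padding in zip(name_prefixs, origin_datas, paddings):
            ids, lengths = padding_batch_data(data, padding)  # helper sketched after example #2
            ret[prefix + "_ids"] = ids
            ret["feed_dict"][input_fields[prefix + "_ids"]] = ids
            ret["feed_dict"][input_fields[prefix + "_length"]] = lengths
        return ret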
Example #12
    def reset(self,
              do_shuffle=False,
              shuffle_to_file=None,
              argsort_index=None):
        """ Resets this reader and shuffle (if needed).

        Args:
            do_shuffle: Whether to shuffle data.
            shuffle_to_file: A string.
            argsort_index: A list of integers

        Returns: The `argsort_index` if do shuffling.
        """
        # TODO
        self._data_index = 0
        if self._filename is not None:
            self._data.seek(0)
        if do_shuffle:
            if self._filename is None:  # list of data
                _ = shuffle_to_file
                if not argsort_index:
                    argsort_index = numpy.arange(len(self._data))
                    numpy.random.shuffle(argsort_index)
                self._data = self._data[argsort_index]  # do shuffle
            else:  # from file
                assert shuffle_to_file, ("`shuffle_to_file` must be provided.")
                tf.logging.info("shuffling data:\t{} ==> {}".format(
                    self._filename, shuffle_to_file))
                data_list = self._data.readlines()
                close_file(self._data)
                if argsort_index is None:
                    argsort_index = numpy.arange(len(data_list))
                    numpy.random.shuffle(argsort_index)
                with open_file(shuffle_to_file, "utf-8", "w") as fw:
                    for idx in argsort_index:
                        fw.write(data_list[idx].strip() + "\n")
                del data_list[:]
                self._data = open_file(shuffle_to_file, "utf-8", "r")
        return argsort_index
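Since the method returns the permutation it used, two file-backed readers can be shuffled in unison by reusing it. A hypothetical usage sketch (the reader names and file names are assumptions):

    # Shuffle source and target readers with the same order so lines stay parallel.
    perm = src_reader.reset(do_shuffle=True, shuffle_to_file="train.src.shuf")
    trg_reader.reset(do_shuffle=True, shuffle_to_file="train.trg.shuf", argsort_index=perm)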
Example #13
    def _SmallParallelData(self,
                           features_file,
                           labels_file,
                           maximum_features_length=None,
                           maximum_labels_length=None,
                           maximum_encoded_features_length=None,
                           maximum_encoded_labels_length=None):
        """ Function for reading small scale parallel data.

        Args:
            features_file: The path of features file.
            labels_file: The path of labels file.
            maximum_features_length: The maximum sequence length of the "features" field.
              If provided, sentences exceeding this value will be ignored.
            maximum_labels_length: The maximum sequence length of the "labels" field.
              If provided, sentences exceeding this value will be ignored.
            maximum_encoded_features_length: The maximum length of feature symbols (especially
              after BPE is applied). If provided, sentences whose number of symbols
              exceeds this value will be ignored.
            maximum_encoded_labels_length: The maximum length of label symbols (especially
              after BPE is applied). If provided, sentences whose number of symbols
              exceeds this value will be ignored.

        Returns: A list of feeding data.
        """
        eval_features = open_file(features_file, encoding="utf-8")
        if gfile.Exists(labels_file):
            eval_labels = open_file(labels_file, encoding="utf-8")
        else:
            eval_labels = open_file(labels_file + "0", encoding="utf-8")
        ss_buf = []
        tt_buf = []
        for ss, tt in zip(eval_features, eval_labels):
            if maximum_features_length and len(ss.strip().split()) > maximum_features_length:
                continue
            if maximum_labels_length and len(tt.strip().split()) > maximum_labels_length:
                continue
            encoded_ss = self._vocab_source.convert_to_idlist(ss.strip().split(" "))
            if maximum_encoded_features_length and len(encoded_ss) - 1 > maximum_encoded_features_length:
                continue
            encoded_tt = self._vocab_target.convert_to_idlist(tt.strip().split(" "))
            if maximum_encoded_labels_length and len(encoded_tt) - 1 > maximum_encoded_labels_length:
                continue
            ss_buf.append(encoded_ss)
            tt_buf.append(encoded_tt)
        close_file(eval_features)
        close_file(eval_labels)
        if self._bucketing:
            tlen = numpy.array([len(t) for t in tt_buf])
            tidx = tlen.argsort()
            _ss_buf = [ss_buf[i] for i in tidx]
            _tt_buf = [tt_buf[i] for i in tidx]
            ss_buf = _ss_buf
            tt_buf = _tt_buf
        data = []
        batch_data_idx = 0
        while batch_data_idx < len(ss_buf):
            x, len_x = padding_batch_data(
                ss_buf[batch_data_idx: batch_data_idx + self._batch_size],
                self._vocab_source.eos_id)
            y, len_y = padding_batch_data(
                tt_buf[batch_data_idx: batch_data_idx + self._batch_size],
                self._vocab_target.eos_id)
            batch_data_idx += self._batch_size
            data.append((len(len_x), {
                self.input_fields[GlobalNames.PH_FEATURE_IDS_NAME]: x,
                self.input_fields[GlobalNames.PH_FEATURE_LENGTH_NAME]: len_x,
                self.input_fields[GlobalNames.PH_LABEL_IDS_NAME]: y,
                self.input_fields[GlobalNames.PH_LABEL_LENGTH_NAME]: len_y}))
        return data
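Each tuple returned here carries its batch size (`len(len_x)`) first, which makes a size-weighted average straightforward. A hypothetical usage sketch (`reader`, `sess` and `loss_op` are assumed names):

    # Size-weighted mean loss over the evaluation batches.
    total_loss, total_samples = 0.0, 0
    for n_samples, feed_dict in reader._SmallParallelData("dev.src", "dev.trg"):
        total_loss += sess.run(loss_op, feed_dict=feed_dict) * n_samples
        total_samples += n_samples
    mean_loss = total_loss / total_samples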
Example #14
 def close(self):
     """ Closes this reader.  """
     self._data_index = 0
     if self._filename is not None:
         close_file(self._data)