def build(self, auto_shard=False, map_func=None, map_output_dtypes=None,
              shuffle=True) -> tf.data.Dataset:
        """ Reads data from files and build the tf dataset.

        Every entry of `self._path` that is a directory named "<src>2<trg>" or
        "<src>-<trg>" is read as an endlessly repeating stream of TF records
        holding int64 "feature"/"label" sequences. When
        `self._auto_switch_langs` is set, each such directory also provides the
        reversed direction (feature and label swapped) unless a real directory
        for that direction exists. The returned dataset yields one element per
        step from a stream chosen by `self._data_sampler()`, or uniformly at
        random when no sampler is set.

        Args:
            auto_shard: Not forwarded to the per-pair readers (they are always
                built with auto_shard=False); the returned dataset always uses
                `AutoShardPolicy.DATA`.
            map_func: A function mapping a dataset element to another dataset element.
            map_output_dtypes: A list/tuple of dtypes after applying `map_func`.
            shuffle: Whether to shuffle the TF records files.

        Returns: A tf.data.Dataset.

        Raises:
            ValueError: If `self._path` does not exist or contains no
                language-pair directory.
        """
        if not tf.io.gfile.exists(self._path):
            raise ValueError(f"Fail to find data path: {self._path}.")

        def _load(path, src_lang, trg_lang, swap):
            # Endless numpy iterator over every TF record file in `path`.
            # With `swap`, "feature" and "label" are exchanged so the same
            # files serve the reversed translation direction.
            extra_kwargs = {}
            if swap:
                extra_kwargs["feature_name_mapping"] = {"label": "feature", "feature": "label"}
            return load_tfrecords(
                file_path=os.path.join(path, "*"),
                shuffle=shuffle, deterministic=(not shuffle),
                auto_shard=False, map_func=map_func,
                name_to_features={"feature": tf.io.VarLenFeature(tf.int64),
                                  "label": tf.io.VarLenFeature(tf.int64)},
                auxiliary_elements={"src_lang": src_lang, "trg_lang": trg_lang},
                **extra_kwargs).repeat().as_numpy_iterator()

        # Each tuple holds: (dir name, dir path, src lang, trg lang, reversed dir name).
        pairs = []
        for path in tf.io.gfile.glob(os.path.join(self._path, "*")):
            lang_pair = path.strip().split("/")[-1]
            sep = "2"
            langs = lang_pair.strip().split(sep)
            if len(langs) == 1:
                sep = "-"
                langs = lang_pair.strip().split(sep)
            # Check the length before indexing `langs`; entries such as
            # "README" contain neither separator and must simply be skipped.
            if tf.io.gfile.isdir(path) and len(langs) == 2:
                pairs.append((lang_pair, path, langs[0], langs[1],
                              langs[1] + sep + langs[0]))
            else:
                logging.info(f"Ignore {path}.")

        lang2numpy_iter = {}
        for lang_pair, path, src_lang, trg_lang, _ in pairs:
            lang2numpy_iter[lang_pair] = _load(path, src_lang, trg_lang, swap=False)
        if self._auto_switch_langs:
            for _, path, src_lang, trg_lang, reversed_lang_pair in pairs:
                # A real directory for the reversed direction takes precedence
                # over the synthesized (swapped) stream instead of being
                # overwritten by it.
                if reversed_lang_pair not in lang2numpy_iter:
                    lang2numpy_iter[reversed_lang_pair] = _load(
                        path, trg_lang, src_lang, swap=True)
        if not lang2numpy_iter:
            raise ValueError(f"Fail to find any language-pair directory in: {self._path}.")

        keys = list(lang2numpy_iter.keys())

        def gen():
            while True:
                if self._data_sampler is None:
                    choice = random.choice(keys)
                else:
                    choice = self._data_sampler()
                yield next(lang2numpy_iter[choice])

        # TODO: to see https://www.tensorflow.org/tutorials/structured_data/imbalanced_data
        # and https://www.tensorflow.org/api_docs/python/tf/data/experimental/sample_from_datasets
        dataset = tf.data.Dataset.from_generator(gen, output_types=map_output_dtypes)
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
        dataset = dataset.with_options(options)
        return dataset
Ejemplo n.º 2
0
 def gen():
     """Yields this shard's TF records as numpy/python values.

     Reads shard `shard_id` of `total_shards` without shuffling, converts
     every record via `to_numpy_or_python_type(..., bytes_as_str=True)` and
     passes it through `map_func` when one is given.
     """
     records = load_tfrecords(self._data_path,
                              shuffle=False,
                              auto_shard=False,
                              name_to_features=self.fields,
                              sharding_index=shard_id,
                              num_shards=total_shards)
     for record in records:
         sample = to_numpy_or_python_type(record, bytes_as_str=True)
         yield sample if map_func is None else map_func(sample)
Ejemplo n.º 3
0
 def gen():
     """Yields audio/transcript records as numpy/python values.

     The configured feature/transcript keys are renamed to "audio" and
     "transcript"; each record is converted via
     `to_numpy_or_python_type(..., bytes_as_str=True)` and then passed
     through `map_func` when one is given.
     """
     renamed = {
         self._feature_key: "audio",
         self._transcript_key: "transcript",
     }
     stream = load_tfrecords(self._data_path,
                             shuffle=self._shuffle_dataset and shuffle,
                             deterministic=(not shuffle),
                             auto_shard=auto_shard,
                             name_to_features=self.fields,
                             feature_name_mapping=renamed)
     for raw in stream:
         item = to_numpy_or_python_type(raw, bytes_as_str=True)
         if map_func is None:
             yield item
         else:
             yield map_func(item)
Ejemplo n.º 4
0
 def gen():
     """Yields this shard's audio/transcript/translation records.

     Reads shard `shard_id` of `total_shards` without shuffling, renames the
     configured keys to "audio", "transcript" and "translation", converts
     every record via `to_numpy_or_python_type(..., bytes_as_str=True)` and
     passes it through `map_func` when one is given.
     """
     renamed = {
         self._feature_key: "audio",
         self._transcript_key: "transcript",
         self._translation_key: "translation",
     }
     shard = load_tfrecords(self._data_path,
                            shuffle=False,
                            auto_shard=False,
                            name_to_features=self.fields,
                            sharding_index=shard_id,
                            num_shards=total_shards,
                            feature_name_mapping=renamed)
     for raw in shard:
         item = to_numpy_or_python_type(raw, bytes_as_str=True)
         yield item if map_func is None else map_func(item)
Ejemplo n.º 5
0
    def build(self,
              auto_shard=False,
              map_func=None,
              map_output_dtypes=None,
              shuffle=True) -> tf.data.Dataset:
        """Reads audio/transcript/translation TF records into a tf dataset.

        The records are first loaded with `map_func` applied inside the
        tf.data pipeline. If that raises AttributeError, they are instead
        streamed through a Python generator that converts each record via
        `to_numpy_or_python_type(..., bytes_as_str=True)`, applies `map_func`
        (if any) in Python, and is wrapped by
        `tf.data.Dataset.from_generator`.

        Args:
            auto_shard: Forwarded to `load_tfrecords`.
            map_func: Optional function applied to every element.
            map_output_dtypes: Element dtypes after `map_func`; used only by
                the from_generator fallback.
            shuffle: Shuffling is requested only when both this flag and
                `self._shuffle_dataset` are truthy; `deterministic=not shuffle`
                is passed to `load_tfrecords`.

        Returns: A tf.data.Dataset.
        """

        def rename_map():
            # Maps the configured record keys onto canonical field names.
            return {
                self._feature_key: "audio",
                self._transcript_key: "transcript",
                self._translation_key: "translation",
            }

        def apply_map(element):
            # Identity when no map_func was given.
            return element if map_func is None else map_func(element)

        try:
            return load_tfrecords(self._data_path,
                                  shuffle=self._shuffle_dataset and shuffle,
                                  deterministic=(not shuffle),
                                  map_func=apply_map,
                                  auto_shard=auto_shard,
                                  name_to_features=self.fields,
                                  feature_name_mapping=rename_map())
        except AttributeError:
            logging.info(
                "Call Dataset.from_generator for AudioTripleTFRecordDataset")

            def gen():
                records = load_tfrecords(self._data_path,
                                         shuffle=self._shuffle_dataset and shuffle,
                                         deterministic=(not shuffle),
                                         auto_shard=auto_shard,
                                         name_to_features=self.fields,
                                         feature_name_mapping=rename_map())
                for record in records:
                    yield apply_map(
                        to_numpy_or_python_type(record, bytes_as_str=True))

            return tf.data.Dataset.from_generator(
                gen, output_types=map_output_dtypes)
Ejemplo n.º 6
0
    def build(self,
              auto_shard=False,
              map_func=None,
              map_output_dtypes=None,
              shuffle=True) -> tf.data.Dataset:
        """Builds a tf dataset directly from the TF record files.

        Args:
            auto_shard: Whether to automatically shard the dataset
                (forwarded to `load_tfrecords`).
            map_func: A function mapping a dataset element to another dataset
                element (forwarded to `load_tfrecords`).
            map_output_dtypes: Unused here; accepted only to keep the common
                `build` signature.
            shuffle: Shuffling is requested only when both this flag and
                `self._shuffle_dataset` are truthy; `deterministic=not shuffle`
                is passed to `load_tfrecords`.

        Returns: A tf.data.Dataset.
        """
        del map_output_dtypes  # Intentionally unused.
        data_path = self._data_path
        should_shuffle = self._shuffle_dataset and shuffle
        return load_tfrecords(
            data_path,
            shuffle=should_shuffle,
            deterministic=not shuffle,
            auto_shard=auto_shard,
            map_func=map_func,
            name_to_features=self.fields,
        )