Esempio n. 1
0
    def _read_txt_dataset(self,
                          data_files,
                          shuffle=False,
                          repeat=True,
                          **kwargs):
        log.info('reading raw files from %s' % '\n'.join(data_files))
        dataset = Dataset.from_list(data_files)
        if repeat:
            dataset = dataset.repeat()
        if shuffle:
            dataset = dataset.shuffle(buffer_size=len(data_files))

        fn = partial(_interleave_func,
                     map_fn=lambda filename: Dataset.from_file(filename),
                     cycle_length=len(data_files),
                     block_length=1)
        dataset = dataset.apply(fn)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=1000)

        def _parse_txt_file(
                record_str):  # function that takes python_str as input
            features = record_str.strip(b'\n').split(b'\t')
            ret = [
                column.raw_to_instance(feature)
                for feature, column in zip(features, self._columns)
            ]
            return ret

        dataset = dataset.map(_parse_txt_file)
        return dataset
Esempio n. 2
0
    def _read_gz_dataset(self,
                         gz_files,
                         shuffle=False,
                         repeat=True,
                         shard=False,
                         **kwargs):
        if len(gz_files) == 0:
            raise ValueError('reading gz from empty file list: %s' % gz_files)
        log.info('reading gz from %s' % '\n'.join(gz_files))
        dataset = Dataset.from_list(gz_files)
        if repeat:
            dataset = dataset.repeat()

        if shard and distribution.status.mode == distribution.DistributionMode.NCCL:
            log.info('Apply dataset sharding in distribution env')
            train_ds = train_ds.shard(distribution.status.num_replica,
                                      distribution.status.replica_id)

        if shuffle:
            dataset = dataset.shuffle(buffer_size=len(gz_files))
        fn = partial(
            _interleave_func,
            map_fn=lambda filename: Dataset.from_record_file(filename),
            cycle_length=len(gz_files),
            block_length=1)
        dataset = dataset.apply(fn)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=1000)

        def _parse_gz(record_str):  # function that takes python_str as input
            ex = example_pb2.Example()
            ex.ParseFromString(record_str)
            ret = []
            fea_dict = ex.features.feature
            for c in self._columns:
                ins = c.proto_to_instance(fea_dict[c.name])
                ret.append(ins)
            return ret

        dataset = dataset.map(_parse_gz)
        return dataset