def _read_txt_dataset(self, data_files, shuffle=False, repeat=True, **kwargs):
    log.info('reading raw files from %s' % '\n'.join(data_files))
    dataset = Dataset.from_list(data_files)
    if repeat:
        dataset = dataset.repeat()
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(data_files))

    # Interleave the per-file datasets round-robin (block_length=1) so that
    # consecutive records come from different files.
    fn = partial(
        _interleave_func,
        map_fn=lambda filename: Dataset.from_file(filename),
        cycle_length=len(data_files),
        block_length=1)
    dataset = dataset.apply(fn)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=1000)

    def _parse_txt_file(record_str):
        """Parse one tab-separated bytes line into per-column instances."""
        features = record_str.strip(b'\n').split(b'\t')
        ret = [
            column.raw_to_instance(feature)
            for feature, column in zip(features, self._columns)
        ]
        return ret

    dataset = dataset.map(_parse_txt_file)
    return dataset
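# Usage sketch (hedged; `reader` and the file names are hypothetical, standing
# in for an instance of this class whose `_columns` describe one tab-separated
# field each):
#
#     ds = reader._read_txt_dataset(
#         ['train.part-0.txt', 'train.part-1.txt'], shuffle=True)
#
# With cycle_length equal to the number of files and block_length=1, one
# record is drawn from each file in turn; together with the two shuffles this
# mixes records both across and within files.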
def _read_stdin_dataset(self, encoding='utf8', shuffle=False, **kwargs):
    log.info('reading raw files from stdin')

    def _gen():
        if six.PY3:
            source = sys.stdin.buffer  # read bytes, not str, under Python 3
        else:
            source = sys.stdin
        while True:
            line = source.readline()
            if len(line) == 0:  # EOF
                break
            yield line,  # yield a 1-tuple per record

    dataset = Dataset.from_generator_func(_gen)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=1000)

    def _parse_stdin(record_str):
        """Parse one tab-separated bytes line into per-column instances."""
        features = record_str.strip(b'\n').split(b'\t')
        ret = [
            column.raw_to_instance(feature)
            for feature, column in zip(features, self._columns)
        ]
        return ret

    dataset = dataset.map(_parse_stdin)
    return dataset
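# Usage sketch (hedged; the script name is hypothetical). Records are piped in
# as raw bytes and parsed with the same tab-separated convention as
# `_read_txt_dataset`:
#
#     $ cat train.tsv | python infer.py
#
# `sys.stdin.buffer` is used under Python 3 so `readline()` returns bytes,
# matching the b'\n' / b'\t' handling in `_parse_stdin`; an empty read signals
# EOF and ends the generator.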
def _read_gz_dataset(self,
                     gz_files,
                     shuffle=False,
                     repeat=True,
                     shard=False,
                     **kwargs):
    if len(gz_files) == 0:
        raise ValueError('reading gz from empty file list: %s' % gz_files)
    log.info('reading gz from %s' % '\n'.join(gz_files))
    dataset = Dataset.from_list(gz_files)
    if repeat:
        dataset = dataset.repeat()

    if shard and distribution.status.mode == distribution.DistributionMode.NCCL:
        log.info('Apply dataset sharding in distribution env')
        # Shard at file granularity: each replica keeps every
        # `num_replica`-th file.
        dataset = dataset.shard(distribution.status.num_replica,
                                distribution.status.replica_id)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(gz_files))

    fn = partial(
        _interleave_func,
        map_fn=lambda filename: Dataset.from_record_file(filename),
        cycle_length=len(gz_files),
        block_length=1)
    dataset = dataset.apply(fn)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=1000)

    def _parse_gz(record_str):
        """Deserialize one protobuf `Example` record into per-column instances."""
        ex = example_pb2.Example()
        ex.ParseFromString(record_str)
        ret = []
        fea_dict = ex.features.feature
        for c in self._columns:
            ins = c.proto_to_instance(fea_dict[c.name])
            ret.append(ins)
        return ret

    dataset = dataset.map(_parse_gz)
    return dataset
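# Decoding sketch for a single record (hedged; assumes `example_pb2` follows
# the standard `Example`/`Features` proto layout, which is what `_parse_gz`
# relies on):
#
#     ex = example_pb2.Example()
#     ex.ParseFromString(record_str)            # record_str: serialized bytes
#     feature = ex.features.feature['label']    # features keyed by column name
#
# Note that sharding happens before interleaving, i.e. at file granularity:
# each NCCL replica keeps every `num_replica`-th file rather than every
# `num_replica`-th record, so the input files should be balanced in size.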