def get_dataset(directory, file_type, num_parallel_reads=1, shuffle=True):
    """Build a `tf.data.Dataset` from the files found under `directory`.

    The directory may be a GCS bucket or a local path; files are discovered
    with `tf.io.gfile.glob` using the pattern `*.{file_type}`.

    :param directory: Either a bucket or a file
    :param file_type: Currently supports "json" files or "tfrecords"
    :param num_parallel_reads: The number of parallel reads
    :param shuffle: Defaults to True
    :return: a `tf.data.Dataset`
    """
    glob_pattern = os.path.join(directory, f'*.{file_type}')
    matched_files = tf.io.gfile.glob(glob_pattern)
    logger.debug(matched_files)

    if file_type in ('json', 'jsonl'):
        # Line-delimited JSON: one example per text line.
        dataset = tf.data.TextLineDataset(matched_files,
                                          num_parallel_reads=num_parallel_reads)
        if shuffle:
            dataset = dataset.shuffle(100)
        # NOTE(review): `decode_json` is resolved elsewhere in this module;
        # confirm it wraps the eager-only `_parse_json` (e.g. via tf.py_function).
        return dataset.map(decode_json)

    # TFRecord path.
    if shuffle:
        # Shuffle at two levels: the file order, then a small record buffer.
        dataset = tf.data.Dataset.from_tensor_slices(tf.constant(matched_files))
        dataset = dataset.shuffle(buffer_size=len(matched_files))
        dataset = dataset.interleave(lambda fname: tf.data.TFRecordDataset(fname),
                                     num_parallel_calls=tf.data.experimental.AUTOTUNE,
                                     cycle_length=num_parallel_reads)
        dataset = dataset.shuffle(buffer_size=100)
    else:
        dataset = tf.data.TFRecordDataset(matched_files,
                                          num_parallel_reads=num_parallel_reads)
    return dataset.map(_parse_tf_record)
def _parse_json(example):
    """Decode a single serialized JSON example into an (x, y) tensor pair.

    Relies on `example.numpy()`, so this only works eagerly — presumably it is
    wrapped with `tf.py_function` when used inside a `Dataset.map`; verify
    against the caller.

    :param example: a scalar string tensor containing a JSON object with
        'x' and 'y' keys
    :return: a tuple of `tf.int32` constant tensors built from 'x' and 'y'
    """
    record = json.loads(example.numpy())
    x_tensor = tf.constant(record['x'], dtype=tf.int32)
    y_tensor = tf.constant(record['y'], dtype=tf.int32)
    return x_tensor, y_tensor