Example #1
def _build_instruction_ds(instructions):
    """Create a dataset containing individual instruction for each shard.

  Each instruction is a dict:
  ```
  {
      "filepath": tf.Tensor(shape=(), dtype=tf.string),
      "mask_offset": tf.Tensor(shape=(), dtype=tf.int64),
      "mask": tf.Tensor(shape=(100,), dtype=tf.bool),
  }
  ```

  Args:
    instructions: `list[dict]`, the list of instruction dicts

  Returns:
    instruction_ds: The dataset containing the instructions. The dataset size is
      the number of shards.
  """
    # Transpose the list[dict] into dict[list]
    tensor_inputs = {
        # mask_offset needs to be converted to int64 explicitly
        k: np.array(vals, dtype=np.int64) if k == "mask_offset" else list(vals)
        for k, vals in utils.zip_dict(*instructions)
    }
    return tf.data.Dataset.from_tensor_slices(tensor_inputs)
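
Every example on this page hinges on `utils.zip_dict`. As a point of reference, here is a minimal stand-in (an assumption about its behavior, not the library source) together with the same list[dict] → dict[list] transpose performed above, on hypothetical instruction dicts:
```
import numpy as np


def zip_dict(*dicts):
    """Assumed behavior: iterate over shared keys, yielding (key, values_tuple)."""
    for key in dicts[0]:  # All dicts are assumed to have identical keys
        yield key, tuple(d[key] for d in dicts)


# Hypothetical instruction dicts, shaped like the ones documented above.
instructions = [
    {"filepath": "data-00000-of-00002", "mask_offset": 0, "mask": [True] * 100},
    {"filepath": "data-00001-of-00002", "mask_offset": 37, "mask": [True] * 100},
]
tensor_inputs = {
    k: np.array(vals, dtype=np.int64) if k == "mask_offset" else list(vals)
    for k, vals in zip_dict(*instructions)
}
# tensor_inputs["mask_offset"] -> array([ 0, 37]); the other keys stay lists.
```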
Example #2
 def encode_example(self, example_dict):
     """See base class for details."""
     return {
         k: feature.encode_example(example_value)
         for k, (feature, example_value
                 ) in utils.zip_dict(self._feature_dict, example_dict)
     }
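
For intuition, the comprehension above pairs each per-key encoder with its raw value. A tiny sketch with hypothetical "encoders", reusing the `zip_dict` stand-in from Example #1:
```
feature_dict = {"id": int, "label": str}   # hypothetical per-key encoders
example_dict = {"id": "42", "label": 7}
encoded = {
    k: feature(value)
    for k, (feature, value) in zip_dict(feature_dict, example_dict)
}
# encoded == {"id": 42, "label": "7"}
```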
Example #3
    def encode_example(self, example_dict):
        # Convert nested dict[list] into list[nested dict]
        sequence_elements = _transpose_dict_list(example_dict)

        # If length is static, ensure that the given length matches
        if self._length is not None and len(sequence_elements) != self._length:
            raise ValueError(
                'Input sequence length does not match the defined one. Got {} != '
                '{}'.format(len(sequence_elements), self._length))

        # Empty sequences return empty arrays
        if not sequence_elements:
            return {
                key: np.empty(shape=(0, ),
                              dtype=serialized_info.dtype.as_numpy_dtype)
                for key, serialized_info in self.get_serialized_info().items()
            }

        # Encode each individual element
        sequence_elements = [
            super(SequenceDict, self).encode_example(sequence_elem)
            for sequence_elem in sequence_elements
        ]

        # Then merge the elements back together
        sequence_elements = {
            # Stack along the first dimension
            k: stack_arrays(*elems)
            for k, elems in utils.zip_dict(*sequence_elements)
        }
        return sequence_elements
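
The merge step above relies on `stack_arrays`, which presumably stacks the per-element arrays along a new first dimension. A hedged sketch using `np.stack` as a stand-in, with the `zip_dict` stand-in from Example #1:
```
import numpy as np

encoded_elements = [
    {"frame": np.zeros((2, 2)), "label": np.int64(0)},
    {"frame": np.ones((2, 2)), "label": np.int64(1)},
]
merged = {
    k: np.stack(elems)  # stand-in for stack_arrays(*elems)
    for k, elems in zip_dict(*encoded_elements)
}
# merged["frame"].shape == (2, 2, 2); merged["label"].shape == (2,)
```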
Example #4
  def encode_sample(self, sample_dict):
    """See base class for details."""
    # Flatten dict matching the tf-example specs
    # Use NonMutableDict to ensure there is no collision between feature keys
    tfexample_dict = utils.NonMutableDict()

    # Iterate over sample fields
    for feature_key, (feature, sample_value) in utils.zip_dict(
        self._feature_dict,
        sample_dict
    ):
      # Encode the field with the associated encoder
      encoded_feature = feature.encode_sample(sample_value)

      # Singleton case
      if not feature.specs_keys:
        tfexample_dict[feature_key] = encoded_feature
      # Feature contains sub features
      else:
        _assert_keys_match(encoded_feature.keys(), feature.specs_keys)
        tfexample_dict.update({
            posixpath.join(feature_key, k): encoded_feature[k]
            for k in feature.specs_keys
        })
    return tfexample_dict
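
The `NonMutableDict` above guards against two features serializing to the same flattened key. A hypothetical sketch of that contract (the real class likely differs in details):
```
class NonMutableDict(dict):
    """Dict that raises on key overwrite (sketch of the assumed contract)."""

    def __setitem__(self, key, value):
        if key in self:
            raise ValueError("Try to overwrite existing key: {}".format(key))
        super().__setitem__(key, value)

    def update(self, other):
        for k, v in other.items():
            self[k] = v


d = NonMutableDict()
d["image/height"] = 32
d.update({"image/width": 32})  # ok: new key
# d["image/height"] = 64       # would raise ValueError
```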
Example #5
    def parse_example(self, serialized_example):
        """Deserialize a single `tf.train.Example` proto.

    Usage:
    ```
    ds = tf.data.TFRecordDataset(filepath)
    ds = ds.map(file_adapter.parse_example)
    ```

    Args:
      serialized_example: `tf.Tensor`, the `tf.string` tensor containing the
        serialized proto to decode.

    Returns:
      example: A nested `dict` of `tf.Tensor` values. The structure and tensors
        shape/dtype match the  `example_specs` provided at construction.
    """
        example = tf.io.parse_single_example(
            serialized=serialized_example,
            features=self._build_feature_specs(),
        )
        example = {
            k: _deserialize_single_field(example_data, tensor_info)
            for k, (example_data, tensor_info
                    ) in utils.zip_dict(example, self._flat_example_specs)
        }
        # Reconstruct all nesting
        example = utils.pack_as_nest_dict(example, self._example_specs)
        return example
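
To see what `tf.io.parse_single_example` expects, here is a self-contained round trip with a hand-built proto and a hypothetical feature spec (not the specs `_build_feature_specs()` actually returns):
```
import tensorflow as tf

proto = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),
}))
specs = {"label": tf.io.FixedLenFeature(shape=(), dtype=tf.int64)}
parsed = tf.io.parse_single_example(
    serialized=proto.SerializeToString(), features=specs)
# parsed["label"] is a scalar int64 tf.Tensor with value 3
```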
Example #6
def _dict_to_tf_example(example_dict, tensor_info_dict=None):
    """Builds tf.train.Example from (string -> int/float/str list) dictionary.

  Args:
    example_dict: `dict`, dict of values, tensor,...
    tensor_info_dict: `dict` of `tfds.features.TensorInfo`. If given, performs
      additional checks on the example dict (check dtype, shape, number of
      fields...)
  """
    def serialize_single_field(k, example_data, tensor_info):
        with utils.try_reraise(
                "Error while serializing feature {} ({}): ".format(
                    k, tensor_info)):
            return _item_to_tf_feature(example_data, tensor_info)

    if tensor_info_dict:
        example_dict = {
            k: serialize_single_field(k, example_data, tensor_info)
            for k, (
                example_data,
                tensor_info) in utils.zip_dict(example_dict, tensor_info_dict)
        }
    else:
        example_dict = {
            k: serialize_single_field(k, example_data, None)
            for k, example_data in example_dict.items()
        }

    return tf.train.Example(features=tf.train.Features(feature=example_dict))
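
`_item_to_tf_feature` is not shown here. A minimal sketch of what such a converter could do for the three `tf.train.Feature` list types (an assumption, not the library code; the helper name is hypothetical):
```
import numpy as np
import tensorflow as tf


def to_tf_feature(value):
    """Convert a scalar/array into a tf.train.Feature (hypothetical helper)."""
    values = np.asarray(value).ravel()
    if np.issubdtype(values.dtype, np.integer):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values.tolist()))
    elif np.issubdtype(values.dtype, np.floating):
        return tf.train.Feature(float_list=tf.train.FloatList(value=values.tolist()))
    else:  # Everything else is serialized as bytes
        return tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[tf.compat.as_bytes(v) for v in values]))


example = tf.train.Example(features=tf.train.Features(feature={
    "label": to_tf_feature([1, 2, 3]),
    "score": to_tf_feature(0.5),
    "name": to_tf_feature("cat"),
}))
```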
Example #7
def _read_single_instruction(
    instruction,
    parse_fn, name, path, name2len, name2shard_lengths, shuffle_files):
  """Returns tf.data.Dataset for given instruction.

  Args:
    instruction (ReadInstruction or str): if str, a ReadInstruction will be
      constructed using `ReadInstruction.from_spec(str)`.
    parse_fn (callable): function used to parse each record.
    name (str): name of the dataset.
    path (str): path to directory where to read tfrecords from.
    name2len: dict associating split names to number of examples.
    name2shard_lengths: dict associating split names to shard lengths.
    shuffle_files (bool): Defaults to False. True to shuffle input files.
  """
  if not isinstance(instruction, ReadInstruction):
    instruction = ReadInstruction.from_spec(instruction)
  absolute_instructions = instruction.to_absolute(name2len)
  files = list(itertools.chain.from_iterable([
      _get_dataset_files(name, path, abs_instr, name2shard_lengths)
      for abs_instr in absolute_instructions]))
  if not files:
    msg = 'Instruction "%s" corresponds to no data!' % instruction
    raise AssertionError(msg)

  do_skip = any(f['skip'] > 0 for f in files)
  do_take = any(f['take'] > -1 for f in files)

  # Transpose the list[dict] into dict[list]
  tensor_inputs = {
      # skip/take need to be converted to int64 explicitly
      k: list(vals) if k == 'filename' else np.array(vals, dtype=np.int64)
      for k, vals in utils.zip_dict(*files)
  }

  # Both parallel_reads and block_length have empirically been tested to give
  # good results on imagenet.
  # These values might be changed in the future, after more performance test runs.
  parallel_reads = 16
  block_length = 16

  instruction_ds = tf.data.Dataset.from_tensor_slices(tensor_inputs)

  # If shuffle is True, we shuffle the instructions/shards
  if shuffle_files:
    instruction_ds = instruction_ds.shuffle(len(tensor_inputs['filename']))

  dataset = instruction_ds.interleave(
      functools.partial(_get_dataset_from_filename,
                        do_skip=do_skip, do_take=do_take),
      cycle_length=parallel_reads,
      block_length=block_length,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
      )

  # TODO(pierrot): `parse_example` uses
  # `tf.io.parse_single_example`. It might be faster to use `parse_example`,
  # after batching.
  # https://www.tensorflow.org/api_docs/python/tf/io/parse_example
  return dataset.map(parse_fn)
Example #8
def _filenames_equal(
    left: Dict[str, UrlInfo],
    right: Dict[str, UrlInfo],
) -> bool:
    """Compare filenames."""
    return all(l.filename == r.filename
               for _, (l, r) in utils.zip_dict(left, right))
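
Usage sketch, with a hypothetical `UrlInfo` stand-in that only carries a `filename` attribute (assumes `utils.zip_dict` behaves like the stand-in in Example #1):
```
import collections

UrlInfo = collections.namedtuple("UrlInfo", ["filename"])

left = {"train": UrlInfo("train.tar.gz"), "test": UrlInfo("test.tar.gz")}
right = {"train": UrlInfo("train.tar.gz"), "test": UrlInfo("test_v2.tar.gz")}
assert _filenames_equal(left, left)
assert not _filenames_equal(left, right)
```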
Example #9
 def _stack_nested(sequence_elements):
     if isinstance(sequence_elements[0], dict):
         return {
             # Stack along the first dimension
             k: _stack_nested(sub_sequence)
             for k, sub_sequence in utils.zip_dict(*sequence_elements)
         }
     return stack_arrays(*sequence_elements)
Example #10
def check_splits_equals(splits1, splits2):
    """Check that the two split dicts have the same names and num_shards."""
    if set(splits1) ^ set(splits2):  # Split names should be identical
        return False
    for _, (split1, split2) in utils.zip_dict(splits1, splits2):
        if split1.num_shards != split2.num_shards:
            return False
    return True
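
Usage sketch with a hypothetical `SplitInfo` stand-in exposing `num_shards` (again assuming the `zip_dict` behavior sketched in Example #1):
```
import collections

SplitInfo = collections.namedtuple("SplitInfo", ["num_shards"])

splits_a = {"train": SplitInfo(num_shards=10), "test": SplitInfo(num_shards=1)}
splits_b = {"train": SplitInfo(num_shards=10), "test": SplitInfo(num_shards=2)}
assert check_splits_equals(splits_a, splits_a)
assert not check_splits_equals(splits_a, splits_b)
```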
Example #11
def check_splits_equals(splits1, splits2):
    """Check two split dicts have same name, shard_lengths and num_shards."""
    if set(splits1) ^ set(splits2):  # Split names should be identical
        return False
    for _, (split1, split2) in utils.zip_dict(splits1, splits2):
        if (split1.num_shards != split2.num_shards
                or split1.shard_lengths != split2.shard_lengths):
            return False
    return True
Example #12
def build_dataset(instruction_dicts,
                  dataset_from_file_fn,
                  shuffle_files=False,
                  parallel_reads=64):
    """Constructs a `tf.data.Dataset` from TFRecord files.

  Args:
    instruction_dicts: `list` of {'filepath':, 'mask':}
      containing the information about which files and which examples to use.
      The boolean mask will be repeated and zipped with the examples from
      filepath.
    dataset_from_file_fn: function returning a `tf.data.Dataset` given a
      filename.
    shuffle_files: `bool`, Whether to shuffle the input filenames.
    parallel_reads: `int`, how many files to read in parallel.

  Returns:
    `tf.data.Dataset`
  """
    def instruction_ds_to_file_ds(instruction):
        """Map from instruction to real datasets."""
        examples_ds = dataset_from_file_fn(instruction["filepath"])
        mask_ds = tf.data.Dataset.from_tensor_slices(instruction["mask"])
        mask_ds = mask_ds.repeat(),
        # Zip the mask and real examples
        ds = tf.data.Dataset.zip({
            "example": examples_ds,
            "mask_value": mask_ds,
        })
        # Filter according to the mask (only keep True)
        # Use [0] as the trailing comma above wrapped `mask_ds` in a 1-tuple
        ds = ds.filter(lambda dataset_dict: dataset_dict["mask_value"][0])
        # Only keep the examples
        ds = ds.map(lambda dataset_dict: dataset_dict["example"])
        return ds

    # Transpose the list[dict] into dict[list]
    tensor_inputs = {
        key: list(values)
        for key, values in utils.zip_dict(*instruction_dicts)
    }
    # Skip slicing if all masks are True (No value skipped)
    if all(all(m) for m in tensor_inputs["mask"]):
        tensor_inputs = tensor_inputs["filepath"]
        instruction_ds_to_file_ds = dataset_from_file_fn

    # Dataset of filenames (or file instructions)
    dataset = tf.data.Dataset.from_tensor_slices(tensor_inputs)
    if shuffle_files:
        dataset = dataset.shuffle(len(instruction_dicts))
    # Use interleave to parallel read files and decode records
    dataset = dataset.interleave(
        instruction_ds_to_file_ds,
        cycle_length=parallel_reads,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset
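
The zip/filter/map pattern inside `instruction_ds_to_file_ds` can be exercised on its own. A self-contained sketch assuming only TensorFlow, with a range dataset standing in for the records read from one file:
```
import tensorflow as tf

examples_ds = tf.data.Dataset.range(5)          # stand-in for records of one file
mask_ds = tf.data.Dataset.from_tensor_slices([True, False, True, False, True])
mask_ds = mask_ds.repeat()

ds = tf.data.Dataset.zip({"example": examples_ds, "mask_value": mask_ds})
ds = ds.filter(lambda d: d["mask_value"])       # only keep masked-in examples
ds = ds.map(lambda d: d["example"])
print(list(ds.as_numpy_iterator()))             # [0, 2, 4]
```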
Example #13
 def encode_example(self, example_dict):
   """See base class for details."""
   example = {}
   for k, (feature, example_value) in utils.zip_dict(self._feature_dict,
                                                     example_dict):
     try:
       example[k] = feature.encode_example(example_value)
     except Exception as e:  # pylint: disable=broad-except
       utils.reraise(e, prefix=f'In <{feature.__class__.__name__}>'
                     + f' with name "{k}":\n')
   return example
Example #14
 def _stack_nested(sequence_elements):
     """Recursivelly stack the tensors from the same dict field."""
     if isinstance(sequence_elements[0], dict):
         return {
             # Stack along the first dimension
             k: _stack_nested(sub_sequence)
             for k, sub_sequence in utils.zip_dict(*sequence_elements)
         }
     # Note: As each field can be a nested ragged list, we don't check here
     # that all elements from the list have matching dtype/shape.
     # Checking is done in `example_serializer` when elements
     # are converted to numpy arrays and stacked together.
     return list(sequence_elements)
Example #15
def _dict_to_tf_example(example_dict, tensor_info_dict):
    """Builds tf.train.Example from (string -> int/float/str list) dictionary.

  Args:
    example_dict: `dict`, dict of values, tensor,...
    tensor_info_dict: `dict` of `tfds.features.TensorInfo`

  Returns:
    example_proto: `tf.train.Example`, the encoded example proto.
  """
    def run_with_reraise(fn, k, example_data, tensor_info):
        try:
            return fn(example_data, tensor_info)
        except Exception as e:  # pylint: disable=broad-except
            utils.reraise(
                e,
                f"Error while serializing feature `{k}`: `{tensor_info}`: ",
            )

    if tensor_info_dict:
        # Add the RaggedTensor fields for the nested sequences
        # Nested sequences are encoded as {'flat_values':, 'row_lengths':}, so we
        # need to flatten the example nested dict again.
        # Ex:
        # Input: {'objects/tokens': [[0, 1, 2], [], [3, 4]]}
        # Output: {
        #     'objects/tokens/flat_values': [0, 1, 2, 3, 4],
        #     'objects/tokens/row_lengths_0': [3, 0, 2],
        # }
        example_dict = utils.flatten_nest_dict({
            k: run_with_reraise(_add_ragged_fields, k, example_data,
                                tensor_info)
            for k, (
                example_data,
                tensor_info) in utils.zip_dict(example_dict, tensor_info_dict)
        })
        example_dict = {
            k: run_with_reraise(_item_to_tf_feature, k, item, tensor_info)
            for k, (item, tensor_info) in example_dict.items()
        }
    else:
        # TODO(epot): The following code is only executed in tests and could be
        # cleaned up, as TensorInfo is always passed to _item_to_tf_feature.
        example_dict = {
            k: run_with_reraise(_item_to_tf_feature, k, example_data, None)
            for k, example_data in example_dict.items()
        }

    return tf.train.Example(features=tf.train.Features(feature=example_dict))
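
The ragged encoding described in the comment above can be reproduced with plain Python; a minimal sketch of turning a nested list into flat values plus per-row lengths:
```
nested = [[0, 1, 2], [], [3, 4]]
flat_values = [v for row in nested for v in row]
row_lengths = [len(row) for row in nested]
# flat_values == [0, 1, 2, 3, 4]; row_lengths == [3, 0, 2]
```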
Example #16
def compare_shapes_and_types(tensor_info, element_spec):
    """Compare shapes and types between TensorInfo and Dataset types/shapes."""
    for feature_name, (feature_info,
                       spec) in utils.zip_dict(tensor_info, element_spec):
        if isinstance(spec, tf.data.DatasetSpec):
            # We use _element_spec because element_spec was added in TF2.5+.
            compare_shapes_and_types(feature_info, spec._element_spec)  # pylint: disable=protected-access
        elif isinstance(feature_info, dict):
            compare_shapes_and_types(feature_info, spec)
        else:
            # Some earlier versions of TF don't expose dtype and shape for the
            # RaggedTensorSpec, so we use the protected versions.
            if feature_info.dtype != spec._dtype:  # pylint: disable=protected-access
                raise TypeError(
                    f"Feature {feature_name} has type {feature_info} but expected {spec}"
                )
            utils.assert_shape_match(feature_info.shape, spec._shape)  # pylint: disable=protected-access
Example #17
  def parse_example(self, serialized_example):
    """Deserialize a single `tf.train.Example` proto.

    Usage:
    ```
    ds = tf.data.TFRecordDataset(filepath)
    ds = ds.map(file_adapter.parse_example)
    ```

    Args:
      serialized_example: `tf.Tensor`, the `tf.string` tensor containing the
        serialized proto to decode.

    Returns:
      example: A nested `dict` of `tf.Tensor` values. The structure and tensors
        shape/dtype match the  `example_specs` provided at construction.
    """
    nested_feature_specs = self._build_feature_specs()

    # Because of RaggedTensor specs, feature_specs can be a 2-level nested dict,
    # so we have to wrap `tf.io.parse_single_example` between
    # `flatten_nest_dict`/`pack_as_nest_dict`.
    # {
    #     'video/image': tf.io.FixedLenSequenceFeature(...),
    #     'video/object/bbox': {
    #         'ragged_flat_values': tf.io.FixedLenSequenceFeature(...),
    #         'ragged_row_lengths_0': tf.io.FixedLenSequenceFeature(...),
    #     },
    # }
    flat_feature_specs = utils.flatten_nest_dict(nested_feature_specs)
    example = tf.io.parse_single_example(
        serialized=serialized_example,
        features=flat_feature_specs,
    )
    example = utils.pack_as_nest_dict(example, nested_feature_specs)

    example = {  # pylint:disable=g-complex-comprehension
        k: _deserialize_single_field(example_data, tensor_info)
        for k, (
            example_data,
            tensor_info) in utils.zip_dict(example, self._flat_example_specs)
    }
    # Reconstruct all nesting
    example = utils.pack_as_nest_dict(example, self._example_specs)
    return example
Example #18
def _read_files(
    files: Sequence[FileInstructionDict],
    parse_fn: ParseFn,
    read_config: read_config_lib.ReadConfig,
    shuffle_files: bool,
    num_examples_per_shard: List[int],
) -> tf.data.Dataset:
    """Returns tf.data.Dataset for given file instructions.

  Args:
    files: List[dict(filename, skip, take)], the files information.
      The filenames contain the absolute path, not relative.
      skip/take indicate which examples to read in the shard: `ds.skip().take()`
    parse_fn: function used to parse each record.
    read_config: Additional options to configure the
      input pipeline (e.g. seed, num parallel reads,...).
    shuffle_files: Defaults to False. True to shuffle input files.
    num_examples_per_shard: if defined, sets the cardinality on the
      tf.data.Dataset instance with `tf.data.experimental.assert_cardinality`.

  Returns:
    The dataset object.
  """
    # Optionally apply a transformation to the file instructions.
    # This allows the user to have direct control over the interleave order.
    if read_config.experimental_interleave_sort_fn is not None:
        files = read_config.experimental_interleave_sort_fn(files)

    do_skip = any(f['skip'] > 0 for f in files)
    do_take = any(f['take'] > -1 for f in files)

    # Transpose the list[dict] into dict[list]
    tensor_inputs = {
        # skip/take need to be converted to int64 explicitly
        k: list(vals) if k == 'filename' else np.array(vals, dtype=np.int64)
        for k, vals in utils.zip_dict(*files)
    }

    cycle_length = read_config.interleave_cycle_length
    block_length = read_config.interleave_block_length

    instruction_ds = tf.data.Dataset.from_tensor_slices(tensor_inputs)

    # In a distributed environment, we can shard per-file if a
    # `tf.distribute.InputContext` object is provided (e.g. from
    # `experimental_distribute_datasets_from_function`)
    if (read_config.input_context
            and read_config.input_context.num_input_pipelines > 1):
        if len(files) < read_config.input_context.num_input_pipelines:
            raise ValueError(
                'Cannot shard the pipeline with given `input_context`.'
                '`num_shards={}` but `num_input_pipelines={}`. '
                'This means that some workers won\'t read any data. '
                'To shard the data, you may want to use the subsplit API '
                'instead: https://www.tensorflow.org/datasets/splits'.format(
                    len(files), read_config.input_context.num_input_pipelines))
        instruction_ds = instruction_ds.shard(
            num_shards=read_config.input_context.num_input_pipelines,
            index=read_config.input_context.input_pipeline_id,
        )

    # If shuffle is True, we shuffle the instructions/shards
    if shuffle_files:
        instruction_ds = instruction_ds.shuffle(
            len(tensor_inputs['filename']),
            seed=read_config.shuffle_seed,
            reshuffle_each_iteration=read_config.
            shuffle_reshuffle_each_iteration,
        )

    ds = instruction_ds.interleave(
        functools.partial(_get_dataset_from_filename,
                          do_skip=do_skip,
                          do_take=do_take),
        cycle_length=cycle_length,
        block_length=block_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )

    # If the number of examples read in the tf-record is known, we forward
    # the information to the tf.data.Dataset object.
    # The hasattr check on tf.data.experimental keeps compatibility with TF <= 2.1
    if (num_examples_per_shard and not read_config.input_context
            and  # TODO(epot): Restore cardinality
            hasattr(tf.data.experimental, 'assert_cardinality')):
        # TODO(b/154963426): Replace by per-shard cardinality.
        cardinality = sum(num_examples_per_shard)
        ds = ds.apply(tf.data.experimental.assert_cardinality(cardinality))

    ds = ds.with_options(read_config.options)  # Additional users options

    # TODO(pierrot): `parse_example` uses
    # `tf.io.parse_single_example`. It might be faster to use `parse_example`,
    # after batching.
    # https://www.tensorflow.org/api_docs/python/tf/io/parse_example
    return ds.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
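
The per-file sharding step above can be tried in isolation. A sketch assuming a `tf.distribute.InputContext` with two pipelines, where each worker keeps every second file instruction:
```
import tensorflow as tf

input_context = tf.distribute.InputContext(
    num_input_pipelines=2, input_pipeline_id=0)
filenames = tf.data.Dataset.from_tensor_slices(
    ["shard-0", "shard-1", "shard-2", "shard-3"])
worker_files = filenames.shard(
    num_shards=input_context.num_input_pipelines,
    index=input_context.input_pipeline_id,
)
print(list(worker_files.as_numpy_iterator()))  # [b'shard-0', b'shard-2']
```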
Example #19
def _read_files(files, parse_fn, read_config, shuffle_files, num_examples):
    """Returns tf.data.Dataset for given file instructions.

  Args:
    files: List[dict(filename, skip, take)], the files information.
      The filenames contain the absolute path, not relative.
      skip/take indicate which examples to read in the shard: `ds.skip().take()`
    parse_fn (callable): function used to parse each record.
    read_config: `tfds.ReadConfig`, Additional options to configure the
      input pipeline (e.g. seed, num parallel reads,...).
    shuffle_files (bool): Defaults to False. True to shuffle input files.
    num_examples: `int`, if defined, sets the cardinality on the
      tf.data.Dataset instance with `tf.data.experimental.assert_cardinality`.
  """
    # Optionally apply a transformation to the file instructions.
    # This allows the user to have direct control over the interleave order.
    if read_config.experimental_interleave_sort_fn is not None:
        files = read_config.experimental_interleave_sort_fn(files)

    do_skip = any(f['skip'] > 0 for f in files)
    do_take = any(f['take'] > -1 for f in files)

    # Transpose the list[dict] into dict[list]
    tensor_inputs = {
        # skip/take need to be converted to int64 explicitly
        k: list(vals) if k == 'filename' else np.array(vals, dtype=np.int64)
        for k, vals in utils.zip_dict(*files)
    }

    parallel_reads = read_config.interleave_parallel_reads
    block_length = read_config.interleave_block_length

    instruction_ds = tf.data.Dataset.from_tensor_slices(tensor_inputs)

    # If shuffle is True, we shuffle the instructions/shards
    if shuffle_files:
        instruction_ds = instruction_ds.shuffle(
            len(tensor_inputs['filename']),
            seed=read_config.shuffle_seed,
            reshuffle_each_iteration=read_config.
            shuffle_reshuffle_each_iteration,
        )

    ds = instruction_ds.interleave(
        functools.partial(_get_dataset_from_filename,
                          do_skip=do_skip,
                          do_take=do_take),
        cycle_length=parallel_reads,
        block_length=block_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )

    # If the number of examples read in the tf-record is known, we forward
    # the information to the tf.data.Dataset object.
    # The hasattr check on tf.data.experimental keeps compatibility with TF <= 2.1
    if num_examples and hasattr(tf.data.experimental, 'assert_cardinality'):
        ds = ds.apply(tf.data.experimental.assert_cardinality(num_examples))

    # TODO(tfds): Should merge the default options with read_config to allow users
    # to overwrite the default options.
    ds = ds.with_options(_default_options())  # Default performance options
    ds = ds.with_options(read_config.options)  # Additional users options

    # TODO(pierrot): `parse_example` uses
    # `tf.io.parse_single_example`. It might be faster to use `parse_example`,
    # after batching.
    # https://www.tensorflow.org/api_docs/python/tf/io/parse_example
    return ds.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
Example #20
def _read_single_instruction(instruction, parse_fn, read_config, name, path,
                             name2len, name2shard_lengths, shuffle_files):
    """Returns tf.data.Dataset for given instruction.

  Args:
    instruction (ReadInstruction or str): if str, a ReadInstruction will be
      constructed using `ReadInstruction.from_spec(str)`.
    parse_fn (callable): function used to parse each record.
    read_config: `tfds.ReadConfig`, Additional options to configure the
      input pipeline (e.g. seed, num parallel reads,...).
    name (str): name of the dataset.
    path (str): path to directory where to read tfrecords from.
    name2len: dict associating split names to number of examples.
    name2shard_lengths: dict associating split names to shard lengths.
    shuffle_files (bool): Defaults to False. True to shuffle input files.
  """
    if not isinstance(instruction, ReadInstruction):
        instruction = ReadInstruction.from_spec(instruction)
    absolute_instructions = instruction.to_absolute(name2len)
    files = list(
        itertools.chain.from_iterable([
            _get_dataset_files(name, path, abs_instr, name2shard_lengths)
            for abs_instr in absolute_instructions
        ]))
    if not files:
        msg = 'Instruction "%s" corresponds to no data!' % instruction
        raise AssertionError(msg)
    # Optionally apply a transformation to the file instructions.
    # This allows the user to have direct control over the interleave order.
    if read_config.experimental_interleave_sort_fn is not None:
        files = read_config.experimental_interleave_sort_fn(files)

    do_skip = any(f['skip'] > 0 for f in files)
    do_take = any(f['take'] > -1 for f in files)

    # Transpose the list[dict] into dict[list]
    tensor_inputs = {
        # skip/take need to be converted to int64 explicitly
        k: list(vals) if k == 'filename' else np.array(vals, dtype=np.int64)
        for k, vals in utils.zip_dict(*files)
    }

    parallel_reads = read_config.interleave_parallel_reads
    block_length = read_config.interleave_block_length

    instruction_ds = tf.data.Dataset.from_tensor_slices(tensor_inputs)

    # If shuffle is True, we shuffle the instructions/shards
    if shuffle_files:
        instruction_ds = instruction_ds.shuffle(
            len(tensor_inputs['filename']),
            seed=read_config.shuffle_seed,
            reshuffle_each_iteration=read_config.
            shuffle_reshuffle_each_iteration,
        )

    ds = instruction_ds.interleave(
        functools.partial(_get_dataset_from_filename,
                          do_skip=do_skip,
                          do_take=do_take),
        cycle_length=parallel_reads,
        block_length=block_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )

    # TODO(tfds): Should merge the default options with read_config to allow users
    # to overwrite the default options.
    ds = ds.with_options(_default_options())  # Default performance options
    ds = ds.with_options(read_config.options)  # Additional users options

    # TODO(pierrot): `parse_example` uses
    # `tf.io.parse_single_example`. It might be faster to use `parse_example`,
    # after batching.
    # https://www.tensorflow.org/api_docs/python/tf/io/parse_example
    return ds.map(parse_fn)