Example No. 1
0
def _create_or_validate_filenames_dataset(filenames):
    """Creates (or validates) a dataset of filenames.

  Args:
    filenames: Either a list or dataset of filenames. If it is a list, it is
      converted to a dataset. If it is a dataset, its type and shape are
      validated.

  Returns:
    A dataset of scalar `tf.string` filenames.

  Raises:
    TypeError: If `filenames` is a dataset whose elements are not scalar
      `tf.string`, or a tensor-like value whose dtype is not `tf.string`.
  """
    if isinstance(filenames, dataset_ops.DatasetV2):
        # An existing dataset must already yield scalar tf.string elements.
        if dataset_ops.get_legacy_output_types(filenames) != dtypes.string:
            raise TypeError(
                "`filenames` must be a `tf.data.Dataset` of `tf.string` elements."
            )
        if not dataset_ops.get_legacy_output_shapes(
                filenames).is_compatible_with(tensor_shape.TensorShape([])):
            raise TypeError(
                "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
                "elements.")
    else:
        # Accept os.PathLike values by normalising them to strings first.
        filenames = nest.map_structure(_normalise_fspath, filenames)
        filenames = ops.convert_to_tensor(filenames, dtype_hint=dtypes.string)
        if filenames.dtype != dtypes.string:
            # Original message read "of dtype `tf.string` dtype." — duplicated
            # word removed.
            raise TypeError(
                "`filenames` must be a `tf.Tensor` of dtype `tf.string`."
                " Got {}".format(filenames.dtype))
        # Flatten to rank 1 so each filename becomes one dataset element.
        filenames = array_ops.reshape(filenames, [-1], name="flat_filenames")
        filenames = dataset_ops.TensorSliceDataset(filenames, is_files=True)

    return filenames
Example No. 2
0
def _create_or_validate_filenames_dataset(filenames, name=None):
    """Returns a dataset of filenames, converting or validating as needed.

  Args:
    filenames: A list of filenames, or a dataset of scalar `tf.string`
      filenames. A list is converted to a dataset; a dataset has its element
      type and shape validated.
    name: (Optional.) A name for the tf.data operation.

  Returns:
    A dataset of filenames.

  Raises:
    TypeError: If `filenames` is a dataset whose elements are not scalar
      `tf.string`, or a tensor-like value whose dtype is not `tf.string`.
  """
    if not isinstance(filenames, dataset_ops.DatasetV2):
        # Convert a (possibly os.PathLike) list into a dataset of strings.
        filenames = nest.map_structure(_normalise_fspath, filenames)
        filenames = ops.convert_to_tensor(filenames, dtype_hint=dtypes.string)
        if filenames.dtype != dtypes.string:
            raise TypeError(
                "The `filenames` argument must contain `tf.string` elements. Got "
                f"`{filenames.dtype!r}` elements.")
        flattened = array_ops.reshape(filenames, [-1], name="flat_filenames")
        return dataset_ops.TensorSliceDataset(flattened,
                                              is_files=True,
                                              name=name)

    # Validate an existing dataset: it must yield scalar `tf.string` elements.
    element_type = dataset_ops.get_legacy_output_types(filenames)
    if element_type != dtypes.string:
        raise TypeError(
            "The `filenames` argument must contain `tf.string` elements. Got a "
            f"dataset of `{element_type!r}` elements.")
    element_shape = dataset_ops.get_legacy_output_shapes(filenames)
    if not element_shape.is_compatible_with(tensor_shape.TensorShape([])):
        raise TypeError(
            "The `filenames` argument must contain `tf.string` elements of shape "
            "[] (i.e. scalars). Got a dataset of element shape "
            f"{element_shape!r}.")
    return filenames
Example No. 3
0
 def from_tensor_slices(tensors):
     """Creates a `Dataset` whose elements are slices of the given tensors.

  Args:
    tensors: A nested structure of tensors, each having the same size in the
      0th dimension.

  Returns:
    A `Dataset`.
  """
     # Slice along dimension 0, then wrap the variant in the public type.
     sliced = dataset_ops.TensorSliceDataset(tensors)
     return Dataset(sliced)