# Imports needed by the helpers below (TensorFlow-internal modules).
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops


def _create_or_validate_filenames_dataset(filenames):
  """Creates (or validates) a dataset of filenames.

  Args:
    filenames: Either a list or dataset of filenames. If it is a list, it is
      converted to a dataset. If it is a dataset, its type and shape are
      validated.

  Returns:
    A dataset of filenames.
  """
  if isinstance(filenames, dataset_ops.DatasetV2):
    # Dataset input: validate element type and shape, then return it as-is.
    if dataset_ops.get_legacy_output_types(filenames) != dtypes.string:
      raise TypeError(
          "`filenames` must be a `tf.data.Dataset` of `tf.string` elements.")
    if not dataset_ops.get_legacy_output_shapes(filenames).is_compatible_with(
        tensor_shape.TensorShape([])):
      raise TypeError(
          "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
          "elements.")
  else:
    # Non-dataset input: normalise `os.PathLike` objects to strings
    # (`_normalise_fspath` is defined elsewhere in this module), convert to a
    # string tensor, flatten, and wrap in a dataset of filenames.
    filenames = nest.map_structure(_normalise_fspath, filenames)
    filenames = ops.convert_to_tensor(filenames, dtype_hint=dtypes.string)
    if filenames.dtype != dtypes.string:
      raise TypeError(
          "`filenames` must be a `tf.Tensor` of `tf.string` dtype."
          " Got {}".format(filenames.dtype))
    filenames = array_ops.reshape(filenames, [-1], name="flat_filenames")
    filenames = dataset_ops.TensorSliceDataset(filenames, is_files=True)
  return filenames
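# Illustrative sketch (not part of the original module): exercising the
# dataset branch above via the public `tf` API. A dataset of scalar
# `tf.string` elements passes through unchanged; any other element type hits
# the TypeError raised above.
import tensorflow as tf

good = tf.data.Dataset.from_tensor_slices(["a.tfrecord", "b.tfrecord"])
validated = _create_or_validate_filenames_dataset(good)  # returned unchanged

try:
  _create_or_validate_filenames_dataset(tf.data.Dataset.range(3))  # int64 elements
except TypeError as e:
  print(e)  # `filenames` must be a `tf.data.Dataset` of `tf.string` elements.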
def _create_or_validate_filenames_dataset(filenames, name=None):
  """Creates (or validates) a dataset of filenames.

  Args:
    filenames: Either a list or dataset of filenames. If it is a list, it is
      converted to a dataset. If it is a dataset, its type and shape are
      validated.
    name: (Optional.) A name for the tf.data operation.

  Returns:
    A dataset of filenames.
  """
  if isinstance(filenames, dataset_ops.DatasetV2):
    element_type = dataset_ops.get_legacy_output_types(filenames)
    if element_type != dtypes.string:
      raise TypeError(
          "The `filenames` argument must contain `tf.string` elements. Got a "
          f"dataset of `{element_type!r}` elements.")
    element_shape = dataset_ops.get_legacy_output_shapes(filenames)
    if not element_shape.is_compatible_with(tensor_shape.TensorShape([])):
      raise TypeError(
          "The `filenames` argument must contain `tf.string` elements of "
          "shape [] (i.e. scalars). Got a dataset of element shape "
          f"{element_shape!r}.")
  else:
    filenames = nest.map_structure(_normalise_fspath, filenames)
    filenames = ops.convert_to_tensor(filenames, dtype_hint=dtypes.string)
    if filenames.dtype != dtypes.string:
      raise TypeError(
          "The `filenames` argument must contain `tf.string` elements. Got "
          f"`{filenames.dtype!r}` elements.")
    filenames = array_ops.reshape(filenames, [-1], name="flat_filenames")
    filenames = dataset_ops.TensorSliceDataset(
        filenames, is_files=True, name=name)
  return filenames
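# Illustrative sketch (assumption: called from within this module; the file
# paths are hypothetical). A plain Python list takes the `else` branch above:
# paths are normalised, converted to a string tensor, flattened, and wrapped
# in a TensorSliceDataset marked as file names, with `name` forwarded to the
# tf.data operation.
paths = ["data/part-00000.tfrecord", "data/part-00001.tfrecord"]
filenames_ds = _create_or_validate_filenames_dataset(paths, name="filenames")
for fname in filenames_ds:
  print(fname.numpy())  # b'data/part-00000.tfrecord', b'data/part-00001.tfrecord'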
def from_tensor_slices(tensors):
  """Creates a `Dataset` whose elements are slices of the given tensors.

  Args:
    tensors: A nested structure of tensors, each having the same size in the
      0th dimension.

  Returns:
    A `Dataset`.
  """
  return Dataset(dataset_ops.TensorSliceDataset(tensors))
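# Illustrative sketch (assumption: `Dataset` above is this module's wrapper
# class; the sketch uses the equivalent public `tf.data.Dataset` API). Each
# element of the resulting dataset is one slice along the 0th dimension of
# the input structure.
import tensorflow as tf

features = tf.constant([[1, 2], [3, 4], [5, 6]])  # shape [3, 2]
labels = tf.constant([0, 1, 0])                   # shape [3]
ds = tf.data.Dataset.from_tensor_slices((features, labels))
for x, y in ds:
  print(x.numpy(), y.numpy())  # [1 2] 0, then [3 4] 1, then [5 6] 0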