Example #1
0
def make_batched_features_dataset_v2(file_pattern,
                                     batch_size,
                                     features,
                                     reader=None,
                                     label_key=None,
                                     reader_args=None,
                                     num_epochs=None,
                                     shuffle=True,
                                     shuffle_buffer_size=10000,
                                     shuffle_seed=None,
                                     prefetch_buffer_size=None,
                                     reader_num_threads=None,
                                     parser_num_threads=None,
                                     sloppy_ordering=False,
                                     drop_final_batch=False):
  """Returns a `Dataset` of feature dictionaries from `Example` protos.

  If the `label_key` argument is provided, returns a `Dataset` of tuples
  comprising a feature dictionary and a label.

  Example:

  ```
  serialized_examples = [
    features {
      feature { key: "age" value { int64_list { value: [ 0 ] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
      feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
    },
    features {
      feature { key: "age" value { int64_list { value: [] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
      feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
    }
  ]
  ```

  We can use arguments:

  ```
  features: {
    "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
    "gender": FixedLenFeature([], dtype=tf.string),
    "kws": VarLenFeature(dtype=tf.string),
  }
  ```

  And the expected output for a batch of these two examples is:

  ```python
  {
    "age": [0, -1],
    "gender": ["f", "f"],
    "kws": SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=["code", "art", "sports"],
      dense_shape=[2, 2]),
  }
  ```

  Args:
    file_pattern: List of files or patterns of file paths containing
      `Example` records. See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values. See `tf.io.parse_example`.
    reader: A function or class that can be
      called with a `filenames` tensor and (optional) `reader_args` and returns
      a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
    label_key: (Optional) A string corresponding to the key labels are stored in
      `tf.Examples`. If provided, it must be one of the `features` keys,
      otherwise results in `ValueError`.
    reader_args: Additional arguments to pass to the reader class.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever. Defaults to `None`.
    shuffle: A boolean, indicates whether the input should be shuffled. Defaults
      to `True`.
    shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity
      ensures better shuffling but would increase memory usage and startup time.
    shuffle_seed: Randomization seed to use for shuffling.
    prefetch_buffer_size: Number of feature batches to prefetch in order to
      improve performance. Recommended value is the number of batches consumed
      per training step. Defaults to auto-tune.
    reader_num_threads: Number of threads used to read `Example` records. If >1,
      the results will be interleaved. May also be `AUTOTUNE`, in which case a
      parallel `interleave` is used. Defaults to `1`.
    parser_num_threads: Number of threads to use for parsing `Example` tensors
      into a dictionary of `Feature` tensors. Defaults to `2`.
    sloppy_ordering: If `True`, reading performance will be improved at
      the cost of non-deterministic ordering. If `False`, the order of elements
      produced is deterministic prior to shuffling (elements are still
      randomized if `shuffle=True`. Note that if the seed is set, then order
      of elements after shuffling is deterministic). Defaults to `False`.
    drop_final_batch: If `True`, and the batch size does not evenly divide the
      input dataset size, the final smaller batch will be dropped. Defaults to
      `False`.

  Returns:
    A dataset of `dict` elements, (or a tuple of `dict` elements and label).
    Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.

  Raises:
    TypeError: If `reader` is a `tf.compat.v1.ReaderBase` subclass.
    ValueError: If `label_key` is not one of the `features` keys.
  """
  # Validate arguments before any part of the pipeline is built (including
  # file globbing), so malformed calls fail fast with the documented errors.
  if label_key and label_key not in features:
    raise ValueError(
        "The `label_key` provided (%r) must be one of the `features` keys." %
        label_key)

  if reader is None:
    reader = core_readers.TFRecordDataset
  if isinstance(reader, type) and issubclass(reader, io_ops.ReaderBase):
    raise TypeError("The `reader` argument must return a `Dataset` object. "
                    "`tf.ReaderBase` subclasses are not supported. For "
                    "example, pass `tf.data.TFRecordDataset` instead of "
                    "`tf.TFRecordReader`.")

  if reader_num_threads is None:
    reader_num_threads = 1
  if parser_num_threads is None:
    parser_num_threads = 2
  if prefetch_buffer_size is None:
    prefetch_buffer_size = dataset_ops.AUTOTUNE
  if reader_args is None:
    reader_args = []

  def read_file(filename):
    """Opens one file with the user-supplied reader and extra arguments."""
    return reader(filename, *reader_args)

  # Create dataset of all matching filenames
  dataset = dataset_ops.Dataset.list_files(
      file_pattern, shuffle=shuffle, seed=shuffle_seed)

  # Read `Example` records from files as tensor objects.
  if reader_num_threads == dataset_ops.AUTOTUNE:
    dataset = dataset.interleave(
        read_file, num_parallel_calls=reader_num_threads)
    options = dataset_ops.Options()
    options.experimental_deterministic = not sloppy_ordering
    dataset = dataset.with_options(options)
  else:
    # Read files sequentially (if reader_num_threads=1) or in parallel
    def apply_fn(dataset):
      return core_readers.ParallelInterleaveDataset(
          dataset,
          read_file,
          cycle_length=reader_num_threads,
          block_length=1,
          sloppy=sloppy_ordering,
          buffer_output_elements=None,
          prefetch_input_elements=None)

    dataset = dataset.apply(apply_fn)

  # Extract values if the `Example` tensors are stored as key-value tuples.
  if dataset_ops.get_legacy_output_types(dataset) == (
      dtypes.string, dtypes.string):
    dataset = dataset_ops.MapDataset(
        dataset, lambda _, v: v, use_inter_op_parallelism=False)

  # Apply dataset repeat and shuffle transformations.
  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  dataset = dataset.batch(
      batch_size, drop_remainder=drop_final_batch or num_epochs is None)

  # Parse `Example` tensors to a dictionary of `Feature` tensors.
  dataset = dataset.apply(
      parsing_ops.parse_example_dataset(
          features, num_parallel_calls=parser_num_threads))

  if label_key:
    def split_label(parsed):
      """Moves the label out of the parsed feature dict into its own slot."""
      label = parsed.pop(label_key)
      return parsed, label

    dataset = dataset.map(split_label)

  dataset = dataset.prefetch(prefetch_buffer_size)
  return dataset
Example #2
0
def make_csv_dataset_v2(
    file_pattern,
    batch_size,
    column_names=None,
    column_defaults=None,
    label_name=None,
    select_columns=None,
    field_delim=",",
    use_quote_delim=True,
    na_value="",
    header=True,
    num_epochs=None,
    shuffle=True,
    shuffle_buffer_size=10000,
    shuffle_seed=None,
    prefetch_buffer_size=dataset_ops.AUTOTUNE,
    num_parallel_reads=1,
    sloppy=False,
    num_rows_for_inference=100,
    compression_type=None,
    ignore_errors=False,
):
  """Reads CSV files into a dataset.

  Reads CSV files into a dataset, where each element is a (features, labels)
  tuple that corresponds to a batch of CSV rows. The features dictionary
  maps feature column names to `Tensor`s containing the corresponding
  feature data, and labels is a `Tensor` containing the batch's label data.

  Args:
    file_pattern: List of files or patterns of file paths containing CSV
      records. See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    column_names: An optional list of strings that corresponds to the CSV
      columns, in order. One per column of the input record. If this is not
      provided, infers the column names from the first row of the records.
      These names will be the keys of the features dict of each dataset element.
    column_defaults: An optional list of default values for the CSV fields. One
      item per selected column of the input record. Each item in the list is
      either a valid CSV dtype (float32, float64, int32, int64, or string), or a
      `Tensor` with one of the aforementioned types. The tensor can either be
      a scalar default value (if the column is optional), or an empty tensor (if
      the column is required). If a dtype is provided instead of a tensor, the
      column is also treated as required. If this list is not provided, tries
      to infer types based on reading the first num_rows_for_inference rows of
      files specified, and assumes all columns are optional, defaulting to `0`
      for numeric values and `""` for string values. If both this and
      `select_columns` are specified, these must have the same lengths, and
      `column_defaults` is assumed to be sorted in order of increasing column
      index. If `select_columns` is not specified, this must have the same
      length as `column_names`.
    label_name: An optional string corresponding to the label column. If
      provided, the data for this column is returned as a separate `Tensor` from
      the features dictionary, so that the dataset complies with the format
      expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input
      function.
    select_columns: An optional list of integer indices or string column
      names, that specifies a subset of columns of CSV data to select. If
      column names are provided, these must correspond to names provided in
      `column_names` or inferred from the file header lines. When this argument
      is specified, only a subset of CSV columns will be parsed and returned,
      corresponding to the columns specified. Using this results in faster
      parsing and lower memory usage. If both this and `column_defaults` are
      specified, these must have the same lengths, and `column_defaults` is
      assumed to be sorted in order of increasing column index.
    field_delim: An optional `string`. Defaults to `","`. Char delimiter to
      separate fields in a record.
    use_quote_delim: An optional bool. Defaults to `True`. If false, treats
      double quotation marks as regular characters inside of the string fields.
    na_value: Additional string to recognize as NA/NaN.
    header: A bool that indicates whether the first rows of provided CSV files
      correspond to header lines with column names, and should not be included
      in the data.
    num_epochs: An int specifying the number of times this dataset is repeated.
      If None, cycles through the dataset forever.
    shuffle: A bool that indicates whether the input should be shuffled.
    shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
      ensures better shuffling, but increases memory usage and startup time.
    shuffle_seed: Randomization seed to use for shuffling.
    prefetch_buffer_size: An int specifying the number of feature
      batches to prefetch for performance improvement. Recommended value is the
      number of batches consumed per training step. Defaults to auto-tune.
    num_parallel_reads: Number of threads used to read CSV records from files.
      If >1, the results will be interleaved.
    sloppy: If `True`, reading performance will be improved at
      the cost of non-deterministic ordering. If `False`, the order of elements
      produced is deterministic prior to shuffling (elements are still
      randomized if `shuffle=True`. Note that if the seed is set, then order
      of elements after shuffling is deterministic). Defaults to `False`.
    num_rows_for_inference: Number of rows of a file to use for type inference
      if `column_defaults` is not provided. If None, reads all the rows of all
      the files. Defaults to 100.
    compression_type: (Optional.) A `tf.string` scalar evaluating to one of
      `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression.
    ignore_errors: (Optional.) If `True`, ignores errors with CSV file parsing,
      such as malformed data or empty lines, and moves on to the next valid
      CSV record. Otherwise, the dataset raises an error and stops processing
      when encountering any invalid records. Defaults to `False`.

  Returns:
    A dataset, where each element is a (features, labels) tuple that corresponds
    to a batch of `batch_size` CSV rows. The features dictionary maps feature
    column names to `Tensor`s containing the corresponding column data, and
    labels is a `Tensor` containing the column data for the label column
    specified by `label_name`. If `label_name` is None, each element is just
    the features dictionary.

  Raises:
    ValueError: If any of the arguments is malformed.
  """
  # Create dataset of all matching filenames
  filenames = _get_file_names(file_pattern, False)
  dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
  if shuffle:
    dataset = dataset.shuffle(len(filenames), shuffle_seed)

  # Clean arguments; figure out column names and defaults

  if column_names is None:
    if not header:
      raise ValueError("Cannot infer column names without a header line.")
    # If column names are not provided, infer from the header lines
    column_names = _infer_column_names(filenames, field_delim, use_quote_delim)
  if len(column_names) != len(set(column_names)):
    raise ValueError("Cannot have duplicate column names.")

  if select_columns is not None:
    select_columns = _get_sorted_col_indices(select_columns, column_names)

  if column_defaults is not None:
    # A bare dtype means "required column": represent it as an empty tensor.
    column_defaults = [
        constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
        for x in column_defaults
    ]
  else:
    # If column defaults are not provided, infer from records at graph
    # construction time
    column_defaults = _infer_column_defaults(
        filenames, len(column_names), field_delim, use_quote_delim, na_value,
        header, num_rows_for_inference, select_columns)

  if select_columns is not None and len(column_defaults) != len(select_columns):
    raise ValueError(
        "If specified, column_defaults and select_columns must have same "
        "length."
    )
  if select_columns is None and len(column_defaults) != len(column_names):
    # Without `select_columns` every column is parsed, so there must be exactly
    # one default per column name; otherwise `zip` in `map_fn` below would
    # silently drop the surplus names or columns.
    raise ValueError(
        "If select_columns is not specified, column_defaults and column_names "
        "must have same length."
    )
  if select_columns is not None and len(column_names) > len(select_columns):
    # Pick the relevant subset of column names
    column_names = [column_names[i] for i in select_columns]

  if label_name is not None and label_name not in column_names:
    raise ValueError("`label_name` provided must be one of the columns.")

  def filename_to_dataset(filename):
    """Parses one CSV file into a dataset of per-column tensors."""
    dataset = CsvDataset(
        filename,
        record_defaults=column_defaults,
        field_delim=field_delim,
        use_quote_delim=use_quote_delim,
        na_value=na_value,
        select_cols=select_columns,
        header=header,
        compression_type=compression_type
    )
    if ignore_errors:
      dataset = dataset.apply(error_ops.ignore_errors())
    return dataset

  def map_fn(*columns):
    """Organizes columns into a features dictionary.

    Args:
      *columns: list of `Tensor`s corresponding to one csv record.
    Returns:
      An OrderedDict of feature names to values for that particular record. If
      label_name is provided, extracts the label feature to be returned as the
      second element of the tuple.
    """
    features = collections.OrderedDict(zip(column_names, columns))
    if label_name is not None:
      label = features.pop(label_name)
      return features, label
    return features

  # Read files sequentially (if num_parallel_reads=1) or in parallel
  dataset = dataset.apply(
      interleave_ops.parallel_interleave(
          filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))

  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # Apply batch before map for perf, because map has high overhead relative
  # to the size of the computation in each map.
  # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  dataset = dataset.batch(batch_size=batch_size,
                          drop_remainder=num_epochs is None)
  dataset = dataset_ops.MapDataset(
      dataset, map_fn, use_inter_op_parallelism=False)
  dataset = dataset.prefetch(prefetch_buffer_size)

  return dataset
Example #3
0
class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
    def _buildMapDataset(self, components, count):
        """Builds slice -> map(square each component) -> repeat(count).

        Also asserts that each output component's static shape equals the
        matching input component's shape with its leading (slice) axis removed.
        """
        def _square_each(x, y, z):
            return (math_ops.square(x), math_ops.square(y), math_ops.square(z))

        sliced = dataset_ops.Dataset.from_tensor_slices(components)
        dataset = sliced.map(_square_each).repeat(count)

        expected_shapes = [component.shape[1:] for component in components]
        self.assertEqual(expected_shapes, list(dataset.output_shapes))
        return dataset

    def testMapDataset(self):
        """Test a dataset that maps a TF function across its input elements."""
        # Pipeline: TensorSliceDataset -> MapDataset(square_3) ->
        # RepeatDataset(count).
        num_rows = 7
        num_repeats = 14
        components = (np.arange(num_rows),
                      np.array([[1, 2, 3]]) * np.arange(num_rows)[:, np.newaxis],
                      np.array(37.0) * np.arange(num_rows))

        # Single-threaded access to the iterator.
        next_element = self.getNext(
            self._buildMapDataset(components, num_repeats))
        for _ in range(num_repeats):
            for row in range(num_rows):
                produced = self.evaluate(next_element())
                for expected, actual in zip(components, produced):
                    self.assertAllEqual(expected[row]**2, actual)
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    # TODO(b/117581999): add eager coverage, different threads run in graph
    # context.
    @test_util.run_v1_only("b/120545219")
    def testSkipEagerMapDatasetMultithreaded(self):
        """Drains a single map iterator concurrently from 8 threads."""
        num_rows, num_repeats, num_threads = 7, 18, 8
        components = (np.arange(num_rows),
                      np.array([[1, 2, 3]]) * np.arange(num_rows)[:, np.newaxis],
                      np.array(37.0) * np.arange(num_rows))
        get_next = self.getNext(self._buildMapDataset(components, num_repeats))
        results = []
        with self.cached_session() as sess:

            def drain():
                # Pull elements until the shared iterator is exhausted.
                try:
                    while True:
                        results.append(sess.run(get_next()))
                except errors.OutOfRangeError:
                    return

            workers = [
                self.checkedThread(target=drain) for _ in range(num_threads)
            ]
            for worker in workers:
                worker.start()
            for worker in workers:
                worker.join()

            # Each element of components**2 arrives num_repeats times in a
            # non-deterministic order; sorting by the first component groups
            # identical elements together so they can be checked in bulk.
            results.sort(key=lambda element: element[0])
            for row in range(num_rows):
                for occurrence in range(num_repeats):
                    produced = results[row * num_repeats + occurrence]
                    for expected, actual in zip(components, produced):
                        self.assertAllEqual(expected[row]**2, actual)

    def _buildParallelMapDataset(self, components, count, num_parallel_calls,
                                 output_buffer_size):
        """Builds slice -> parallel map(square) -> prefetch -> repeat(count).

        Also asserts that each output component's static shape equals the
        matching input component's shape with its leading (slice) axis removed.
        """
        def _square_each(x, y, z):
            return (math_ops.square(x), math_ops.square(y), math_ops.square(z))

        dataset = dataset_ops.Dataset.from_tensor_slices(components)
        dataset = dataset.map(_square_each,
                              num_parallel_calls=num_parallel_calls)
        dataset = dataset.prefetch(output_buffer_size)
        dataset = dataset.repeat(count)

        expected_shapes = [component.shape[1:] for component in components]
        self.assertEqual(expected_shapes, list(dataset.output_shapes))
        return dataset

    def testParallelMapDataset(self):
        """Test a dataset that maps a TF function across its input elements."""
        # Pipeline: TensorSliceDataset -> ParallelMapDataset(square_3) ->
        # RepeatDataset(count).
        num_rows = 7
        num_repeats = 14
        configs = [(1, 1), (1, 2), (2, 2), (2, 4), (8, 8), (8, 16)]

        def check(num_parallel_calls, output_buffer_size):
            components = (
                np.arange(num_rows),
                np.array([[1, 2, 3]]) * np.arange(num_rows)[:, np.newaxis],
                np.array(37.0) * np.arange(num_rows))
            # Single-threaded access to the iterator.
            dataset = self._buildParallelMapDataset(
                components, num_repeats, num_parallel_calls,
                output_buffer_size)
            next_element = self.getNext(dataset)
            for _ in range(num_repeats):
                for row in range(num_rows):
                    produced = self.evaluate(next_element())
                    for expected, actual in zip(components, produced):
                        self.assertAllEqual(expected[row]**2, actual)
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(next_element())

        for parallelism, buffer_size in configs:
            check(parallelism, buffer_size)

    # TODO(b/117581999): add eager coverage, different threads run in graph
    # context.
    @test_util.run_v1_only("b/120545219")
    def testSkipEagerParallelMapDatasetMultithreaded(self):
        """Drains one parallel-map iterator concurrently from 64 threads.

        Runs once per (num_parallel_calls, output_buffer_size) configuration
        listed at the bottom of the method.
        """
        def do_test(num_parallel_calls, output_buffer_size):
            # Test multi-threaded access to the same iterator.
            components = (np.arange(7),
                          np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                          np.array(37.0) * np.arange(7))
            get_next = self.getNext(
                self._buildParallelMapDataset(components, 18,
                                              num_parallel_calls,
                                              output_buffer_size))
            results = []
            with self.cached_session() as sess:

                def iterator_thread():
                    while True:
                        try:
                            results.append(sess.run(get_next()))
                        except errors.OutOfRangeError:
                            return

                threads = [
                    self.checkedThread(target=iterator_thread)
                    for _ in range(64)
                ]
                for t in threads:
                    t.start()
                for t in threads:
                    t.join()

                # `results` will contain the same elements components**2
                # repeated 18 times, but in a non-deterministic order. Sort the
                # results, and assert that each element of components**2 is
                # produced 18 times.
                results.sort(key=lambda x: x[0])
                for i in range(7):
                    for j in range(18):
                        for component, result_component in zip(
                                components, results[i * 18 + j]):
                            self.assertAllEqual(component[i]**2,
                                                result_component)

        # This loop must sit at method level, outside `do_test`: nested inside
        # `do_test` it is never reached, so no configuration would be tested.
        for num_parallel_calls_val, output_buffer_size_val in [(1, 1),
                                                               (1, 2),
                                                               (2, 2),
                                                               (2, 4),
                                                               (8, 8),
                                                               (8, 16)]:
            do_test(num_parallel_calls_val, output_buffer_size_val)

    def testImplicitDisposeParallelMapDataset(self):
        """A parallel map pipeline abandoned before exhaustion is cleaned up."""
        # Pipeline: TensorSliceDataset -> ParallelMapDataset(square_3) ->
        # PrefetchDataset(100) -> RepeatDataset(1000) -> PrefetchDataset(100).
        num_rows = 1000
        components = (np.arange(num_rows),
                      np.array([[1, 2, 3]]) * np.arange(num_rows)[:, np.newaxis],
                      np.array(37.0) * np.arange(num_rows))

        mapped = self._buildParallelMapDataset(components, 1000, 100, 100)
        # NOTE(mrry): Also test that the prefetching thread is cancelled correctly.
        get_next = self.getNext(mapped.prefetch(100))

        # Consume only a few elements, then let the iterator go out of scope.
        for _ in range(3):
            self.evaluate(get_next())

    def testParallelMapUnspecifiedOutputSize(self):
        """A parallel map with no explicit output buffer yields its elements."""
        values = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

        def _checked(x):
            return array_ops.check_numerics(x, "message")

        source = dataset_ops.Dataset.from_tensor_slices(values)
        get_next = self.getNext(source.map(_checked, num_parallel_calls=2))

        # Only the three leading finite elements are consumed.
        for _ in range(3):
            self.evaluate(get_next())

    def testParallelMapError(self):
        """An error on one parallel-map element does not end the iteration."""
        values = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

        def _checked(x):
            return array_ops.check_numerics(x, "message")

        source = dataset_ops.Dataset.from_tensor_slices(values)
        get_next = self.getNext(source.map(_checked, num_parallel_calls=2))

        # Elements 0-2 are finite and pass through unchanged.
        for _ in range(3):
            self.evaluate(get_next())
        # The 4th element is NaN, so `array_ops.check_numerics()` should fail.
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(get_next())
        # Iteration then resumes with the 5th element and ends normally.
        self.evaluate(get_next())
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    def testPrefetchError(self):
        """An error produced upstream of prefetch does not end the iteration."""
        values = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

        def _checked(x):
            return array_ops.check_numerics(x, "message")

        source = dataset_ops.Dataset.from_tensor_slices(values)
        get_next = self.getNext(source.map(_checked).prefetch(2))

        # Elements 0-2 are finite and pass through unchanged.
        for _ in range(3):
            self.evaluate(get_next())
        # The 4th element is NaN, so `array_ops.check_numerics()` should fail.
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(get_next())
        # Iteration then resumes with the 5th element and ends normally.
        self.evaluate(get_next())
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    def testCaptureIterator(self):
        """A map function may pull values from a captured iterator.

        Each element `x` of `range(10)` is multiplied by the next value of a
        separate `range(10)` iterator. Both advance in lockstep, so the i-th
        output is `i * i`.
        """
        def _build_ds(iterator):
            def _map_fn(x):
                get_next = iterator.get_next()
                return x * get_next

            return dataset_ops.Dataset.range(10).map(_map_fn)

        def _build_graph():
            # Eager mode uses a Python iterator; graph mode needs an
            # initializable iterator whose initializer is run below.
            if context.executing_eagerly():
                captured_iterator = iter(dataset_ops.Dataset.range(10))
            else:
                captured_iterator = dataset_ops.Dataset.range(
                    10).make_initializable_iterator()
            ds = _build_ds(captured_iterator)
            return captured_iterator, ds

        captured_iter, ds = _build_graph()
        if not context.executing_eagerly():
            self.evaluate(captured_iter.initializer)
        get_next = self.getNext(ds, requires_initialization=True)
        for i in range(10):
            self.assertEqual(i * i, self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    def testCaptureHashTable(self):
        """A map function may capture a lookup-table resource.

        Each sentence is split into words, and every word is looked up in
        `table`; words missing from the table use `default_val`. The test only
        checks that both elements are produced before end of input.
        """
        # NOTE(mrry): We must use the V2 variants of `HashTable`
        # etc. because these produce a `tf.resource`-typed output that is
        # compatible with the in-graph function implementation.
        default_val = -1
        keys = constant_op.constant(["brain", "salad", "surgery"])
        values = constant_op.constant([0, 1, 2], dtypes.int64)
        table = lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer(keys, values), default_val)

        input_sentences = dataset_ops.Dataset.from_tensor_slices(
            ["brain brain tank salad surgery", "surgery brain"])

        dataset = input_sentences.map(
            lambda x: string_ops.string_split([x]).values).map(table.lookup)

        get_next = self.getNext(dataset, requires_initialization=True)

        # The table must be initialized before any lookup runs.
        self.evaluate(table.initializer)
        self.evaluate(get_next())
        self.evaluate(get_next())
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    def testCaptureQueue(self):
        """A map function may dequeue from a captured FIFO queue.

        The queue is filled with 200 random ints and closed. The infinitely
        repeated dataset yields them in FIFO order, and iteration ends with
        OutOfRangeError once the closed queue is empty.
        """
        elements = np.random.randint(100, size=[200])
        queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])
        enqueue_op = queue.enqueue_many(elements)
        close_op = queue.close()
        dataset = dataset_ops.Dataset.from_tensors(0).repeat(-1).map(
            lambda _: queue.dequeue())

        get_next = self.getNext(dataset, requires_initialization=True)
        self.evaluate(enqueue_op)
        self.evaluate(close_op)

        for element in elements:
            self.assertEqual(element, self.evaluate(get_next()))
        # The dataset is endless, so termination comes from the queue running dry.
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    # TODO(b/117581999): Possible deadlock in eager mode, debug.
    @test_util.run_v1_only("b/120545219")
    def testSkipEagerCaptureSameResourceMultipleTimes(self):
        """Two handles to one shared queue can both be captured by a map fn.

        `queue` and `queue_2` share `shared_name`, and the assertions rely on
        them referring to the same underlying queue. Each output element holds
        two consecutive enqueued values (in either order), so 200 values yield
        100 elements.
        """
        elements = np.random.randint(100, size=[200])
        queue = data_flow_ops.FIFOQueue(200,
                                        dtypes.int64,
                                        shapes=[],
                                        shared_name="shared_queue")
        queue_2 = data_flow_ops.FIFOQueue(200,
                                          dtypes.int64,
                                          shapes=[],
                                          shared_name="shared_queue")

        enqueue_op = queue.enqueue_many(elements)
        close_op = queue.close()

        dataset = dataset_ops.Dataset.from_tensors(0).repeat(-1).map(
            lambda _: (queue.dequeue(), queue_2.dequeue()))

        self.evaluate(enqueue_op)
        self.evaluate(close_op)
        get_next = self.getNext(dataset, requires_initialization=True)
        for i in range(100):
            # The two dequeues within one element may run in either order.
            self.assertCountEqual([elements[i * 2], elements[i * 2 + 1]],
                                  self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    def testCaptureVariable(self):
        """A map function may mutate a captured resource variable.

        Each element is the result of `counter_var.assign_add(1)`, so the
        counter advances by one per produced element and stays at 10 after
        the 10-element input is exhausted.
        """
        counter_var = variable_scope.get_variable("counter", (),
                                                  dtypes.int32,
                                                  use_resource=True)
        dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).map(
            lambda _: counter_var.assign_add(1))
        get_next = self.getNext(dataset, requires_initialization=True)

        self.evaluate(counter_var.initializer)

        for i in range(10):
            # Read the counter before pulling the next element, which bumps it.
            self.assertEqual(i, self.evaluate(counter_var))
            self.assertEqual(i + 1, self.evaluate(get_next()))
        self.assertEqual(10, self.evaluate(counter_var))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
        # Hitting end of input must not run the map function again.
        self.assertEqual(10, self.evaluate(counter_var))

    # TODO(b/117581999): error not captured for eager mode, debug.
    @test_util.run_v1_only("b/120545219")
    def testSkipEagerCaptureUninitializedVariableError(self):
        """Reading an uninitialized captured variable raises NotFoundError."""
        counter = variable_scope.get_variable(
            "counter", (), dtypes.int32, use_resource=True)

        def _bump(_):
            return counter.assign_add(1)

        dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).map(_bump)
        get_next = self.getNext(dataset, requires_initialization=True)

        # `counter.initializer` is deliberately never evaluated.
        with self.assertRaises(errors.NotFoundError):
            self.evaluate(get_next())

    def testSeededStatefulOperatorIsProperlyStateful(self):
        """A seeded random op varies within a pass but repeats across passes."""
        dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).map(
            lambda _: random_ops.random_uniform((), seed=11)).batch(2)

        def _drain():
            """Exhausts a fresh iterator over `dataset`; returns all values."""
            next_batch = self.getNext(dataset, requires_initialization=True)
            collected = []
            with self.assertRaises(errors.OutOfRangeError):
                while True:
                    collected.extend(self.evaluate(next_batch()))
            return collected

        first_pass = _drain()
        self.assertLen(first_pass, 10)
        # Within one pass the values must not all be (nearly) equal.
        self.assertGreater(np.abs(np.diff(first_pass)).max(), 1e-6)

        second_pass = _drain()
        # Randomness is repeatable given same seed
        self.assertAllClose(first_pass, second_pass)

    def testStatefulMapKeepsStateAcrossIterators(self):
        """Random state in `map` is not reset between `repeat()` epochs.

        The inner dataset has 10 mapped elements and batches are of size 10,
        so each batch corresponds to one `repeat()` epoch.
        """
        dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).map(
            lambda _: random_ops.random_uniform((), seed=11)).repeat(
                1000).batch(10)

        get_next = self.getNext(dataset)
        first_batch = self.evaluate(get_next())

        # At least one of the next 99 batches must differ from the first;
        # `any` stops pulling batches as soon as a difference is seen.
        saw_difference = any(
            np.any(first_batch != self.evaluate(get_next()))
            for _ in range(99))
        self.assertTrue(saw_difference)

    def testStatefulOperationInShortCircuit(self):
        """A side-effecting op whose result is discarded still runs per element.

        `increment_fn` returns its input unchanged. The counter assertions check
        that its `assign_add` nevertheless executes once per element.
        """
        counter = variable_scope.get_variable(
            "counter", (), dtypes.int32, use_resource=True)

        def increment_fn(x):
            counter.assign_add(1)
            return x

        get_next = self.getNext(
            dataset_ops.Dataset.range(10).map(increment_fn),
            requires_initialization=True)

        self.evaluate(counter.initializer)
        for step in range(10):
            self.assertEqual(step, self.evaluate(counter))
            self.assertEqual(step, self.evaluate(get_next()))
        self.assertEqual(10, self.evaluate(counter))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
        self.assertEqual(10, self.evaluate(counter))

    def testMapDict(self):
        """Dict-structured elements flow through consecutive `map` calls."""
        def _to_dict(x):
            return {"foo": x * 2, "bar": x**2}

        def _sum_entries(d):
            return d["foo"] + d["bar"]

        dataset = dataset_ops.Dataset.range(10).map(_to_dict).map(_sum_entries)
        expected = [i * 2 + i**2 for i in range(10)]
        self.assertDatasetProduces(dataset, expected_output=expected)

    def testMapNamedtuple(self, count=10):
        """Mapping over namedtuples matches mapping over plain tuples.

        Args:
          count: number of (label, image) elements in each dataset.
        """
        # construct dataset of tuples
        labels = dataset_ops.Dataset.range(count)
        images = labels.map(lambda label: -label)
        dataset_tuple = dataset_ops.Dataset.zip((labels, images))

        # convert dataset of tuples to dataset of namedtuples
        example_type = namedtuple("Example", ["label", "image"])
        dataset_namedtuple = dataset_tuple.map(example_type)

        def preprocess_tuple(label, image):
            image = 2 * image
            return label, image

        def preprocess_namedtuple(example):
            return example._replace(image=2 * example.image)

        # preprocess both datasets
        dataset_tuple = dataset_tuple.map(preprocess_tuple)
        dataset_namedtuple = dataset_namedtuple.map(preprocess_namedtuple)

        next_tuple = self.getNext(dataset_tuple)
        next_namedtuple = self.getNext(dataset_namedtuple)

        # make sure both datasets contain the same data
        for i in range(count):
            tuple_, namedtuple_ = self.evaluate(
                [next_tuple(), next_namedtuple()])
            self.assertEqual(tuple_, namedtuple_)
            self.assertEqual(tuple_, (i, -2 * i))

        # Both datasets must be exhausted after `count` elements; previously
        # only the namedtuple pipeline was checked.
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_tuple())
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_namedtuple())

    def testUseStepContainerInMap(self):
        """`functional_ops.map_fn` can run inside a `Dataset.map` function."""
        row = np.arange(6)

        def _square_elementwise(elems):
            return functional_ops.map_fn(lambda x: x * x, elems)

        dataset = dataset_ops.Dataset.from_tensors(row).map(_square_elementwise)
        self.assertDatasetProduces(dataset, expected_output=[row**2])

    def testCaseAndCondInMap(self):
        """`case` with a nested `cond` branch inside `Dataset.map`.

        For `y` in {2, 3}, even elements are doubled and odd elements are
        floor-halved. For any other `y`, every element is doubled.
        """
        def control_map_fn(x, y):
            def multiply():
                return x * 2

            def divide():
                return x // 2

            def defaults_two():
                is_even = math_ops.equal(math_ops.mod(x, 2), 0)
                return control_flow_ops.cond(
                    is_even, multiply, divide, name="cond_mult")

            use_cond = math_ops.logical_or(
                math_ops.equal(y, 2), math_ops.equal(y, 3))
            return control_flow_ops.case(
                {use_cond: defaults_two}, default=multiply, exclusive=True)

        def build_dataset(row, num):
            mapped = dataset_ops.Dataset.from_tensor_slices(row).map(
                lambda x: control_map_fn(x, num))
            return self.getNext(mapped)

        def expected_value(i, num):
            if num in (2, 3) and i % 2:
                return i // 2
            return i * 2

        row = np.arange(6)
        for num in [2, 3, 4]:
            get_next = build_dataset(row, num)
            for i in range(6):
                self.assertEqual(expected_value(i, num),
                                 self.evaluate(get_next()))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(get_next())

    def testCaseInWhileInMap(self):
        """`case` nested in `functional_ops.map_fn` inside `Dataset.map`."""
        def control_map_fn(x, y):
            def multiply():
                return x * 2

            def divide():
                return x // 2

            halve = math_ops.logical_or(
                math_ops.equal(y, 2), math_ops.equal(y, 3))
            return control_flow_ops.case(
                {halve: divide}, default=multiply, exclusive=True)

        def build_dataset(row, num):
            def _apply_per_element(elems):
                return functional_ops.map_fn(
                    lambda x: control_map_fn(x, num), elems)

            mapped = dataset_ops.Dataset.from_tensors(row).map(
                _apply_per_element)
            return self.getNext(mapped)

        row = np.arange(6)
        for num in [2, 3, 4]:
            get_next = build_dataset(row, num)
            if num in (2, 3):
                expected = [x // 2 for x in row]
            else:
                expected = [x * 2 for x in row]
            self.assertAllEqual(expected, self.evaluate(get_next()))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(get_next())

    def testCaseAndCondInWhileInMap(self):
        """`case` with a nested `cond`, inside `map_fn`, inside `Dataset.map`."""
        def control_map_fn(x, y):
            def multiply():
                return x * 2

            def divide():
                return x // 2

            def defaults_two():
                is_even = math_ops.equal(math_ops.mod(x, 2), 0)
                return control_flow_ops.cond(
                    is_even, multiply, divide, name="cond_mult")

            use_cond = math_ops.logical_or(
                math_ops.equal(y, 2), math_ops.equal(y, 3))
            return control_flow_ops.case(
                {use_cond: defaults_two}, default=multiply, exclusive=True)

        row = np.arange(6)
        num = 2

        def _apply_per_element(elems):
            return functional_ops.map_fn(
                lambda x: control_map_fn(x, num), elems)

        dataset = dataset_ops.Dataset.from_tensors(row).map(_apply_per_element)
        get_next = self.getNext(dataset)

        expected = [(x // 2 if x % 2 else x * 2) if num in (2, 3) else x * 2
                    for x in row]
        self.assertAllEqual(expected, self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    def testPrefetch(self):
        """Checks `prefetch` output order and how far ahead it drives `map`.

        `do_test` checks values and ordering for several buffer sizes.
        `do_test_ev` uses a `threading.Event`, set by the py_func when it maps
        input 5, to check how many elements must be consumed before prefetching
        reaches that input for a given `buffer_size`.
        """
        # We will use this event to test that `_map_py_func()` has been
        # invoked a certain number of times (6 times, to be exact) after
        # consuming fewer elements from the iterator.
        ev = threading.Event()

        # Input value whose mapping sets `ev`. Inputs come from
        # `Dataset.range(100)`, so this is the 6th value mapped.
        set_event_during_invocation = 5

        def _map_py_func(x):
            # Runs in Python via `py_func`: signals `ev` on the trigger value,
            # then squares the input.
            if x == set_event_during_invocation:
                ev.set()
            return x * x

        def _map_fn(x):
            return script_ops.py_func(_map_py_func, [x], x.dtype)

        def do_test(buffer_size):
            dataset = dataset_ops.Dataset.range(100).map(_map_fn).prefetch(
                buffer_size)

            get_next = self.getNext(dataset)
            # Simple test that prefetch yields the expected values in the
            # expected order.
            for i in range(100):
                self.assertEqual(i * i, self.evaluate(get_next()))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(get_next())

        # Buffer sizes from smaller than to larger than the 100-element input.
        for buffer_size in [1, 10, 100, 1000]:
            do_test(buffer_size)

        # We can indirectly observe that varying the buffer size has the
        # intended effect by observing when `ev` is set (on the 6th
        # invocation of `_map_py_func()`).
        # NOTE(mrry): We do not test with `buffer_size ==
        # set_event_during_invocation`, because we must consume at least
        # one element to start the prefetching.
        def do_test_ev(buffer_size):
            dataset = dataset_ops.Dataset.range(100).map(_map_fn).prefetch(
                buffer_size)

            get_next = self.getNext(dataset)

            # After consuming k elements, prefetching may have mapped up to
            # `buffer_size` inputs beyond them (through input
            # k - 1 + buffer_size). Solving for the trigger input gives k.
            event_will_be_set_after_consuming = (set_event_during_invocation -
                                                 buffer_size + 1)

            # `ev` may still be set from `do_test` or an earlier buffer size.
            ev.clear()
            for i in range(event_will_be_set_after_consuming):
                # Prefetching must not yet have reached the trigger input.
                self.assertFalse(ev.is_set())
                self.assertEqual(i * i, self.evaluate(get_next()))
            # NOTE(review): no timeout here; a regression hangs the test
            # instead of failing it.
            ev.wait()
            for i in range(event_will_be_set_after_consuming, 100):
                self.assertEqual(i * i, self.evaluate(get_next()))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(get_next())

        for buffer_size in range(1, set_event_during_invocation):
            do_test_ev(buffer_size)

    def testReturnList(self):
        """A list returned from the map function yields tuple elements."""
        def _pair_with_constant(x):
            return [x, constant_op.constant(37.0)]

        dataset = dataset_ops.Dataset.range(10).map(_pair_with_constant)
        expected = [(i, 37.0) for i in range(10)]
        self.assertDatasetProduces(dataset, expected_output=expected)

    def testMultiOutputPyFunc(self):
        """A `py_func` declaring two output dtypes yields 2-tuple elements."""
        # The `tf.py_func()` op returns a list of tensors for its outputs.
        def _py_pair(x):
            return x, np.array(37.0, dtype=np.float64)

        def _map_fn(x_tensor):
            return script_ops.py_func(
                _py_pair, [x_tensor], [dtypes.int64, dtypes.float64])

        dataset = dataset_ops.Dataset.range(10).map(_map_fn)
        self.assertDatasetProduces(
            dataset, expected_output=[(i, 37.0) for i in range(10)])

    def testSparse(self):
        """A map function may return a `SparseTensorValue`."""
        def _sparse(i):
            # 1x1 sparse value with a single entry equal to `i`.
            return sparse_tensor.SparseTensorValue(
                indices=np.array([[0, 0]]),
                values=(i * np.array([1])),
                dense_shape=np.array([1, 1]))

        dataset = dataset_ops.Dataset.range(10).map(_sparse)
        expected = [_sparse(i) for i in range(10)]
        self.assertDatasetProduces(dataset, expected_output=expected)

    def testSparseChain(self):
        """Sparse elements produced by one `map` arrive sparse in the next."""
        def _sparse(i):
            return sparse_tensor.SparseTensorValue(
                indices=np.array([[0, 0]]),
                values=(i * np.array([1])),
                dense_shape=np.array([1, 1]))

        def _check(i):
            # Asserts the incoming element is sparse, then concatenates it
            # with itself along axis 0.
            self.assertTrue(sparse_tensor.is_sparse(i))
            return sparse_ops.sparse_concat(0, [i, i])

        dataset = dataset_ops.Dataset.range(10).map(_sparse).map(_check)

        expected = [self.evaluate(_check(_sparse(i))) for i in range(10)]
        self.assertDatasetProduces(dataset, expected_output=expected)

    def testParallelMapOutOfRangeError(self):
        """StopIteration in a parallel map's py_func ends the sequence."""
        def raising_py_func(i):
            if i == 100:
                raise StopIteration()
            return i

        def _map_fn(x):
            return script_ops.py_func(raising_py_func, [x], dtypes.int64)

        dataset = dataset_ops.Dataset.range(105).map(
            _map_fn, num_parallel_calls=2)
        get_next = self.getNext(dataset)
        for expected in range(100):
            self.assertEqual(expected, self.evaluate(get_next()))
        # Input 100 raises, so inputs 101..104 are never produced.
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    def testConstantOutput(self):
        """Python constants returned next to a tensor become element parts."""
        def _with_constants(x):
            return [x, "hello", 10]

        dataset = dataset_ops.Dataset.range(10).map(_with_constants)
        expected = [(i, b"hello", 10) for i in range(10)]
        self.assertDatasetProduces(dataset, expected)

    def testWarnOnLookupTable(self):
        """Creating a lookup table inside a map function emits a warning."""
        def collecting_function(x):
            _ = lookup_ops.HashTable(
                lookup_ops.KeyValueTensorInitializer([], []),
                0.0,
                name="t1")
            return x

        with warnings.catch_warnings(record=True) as w:
            # Install the "always" filter *inside* the context manager. The
            # filter is then discarded on exit instead of permanently changing
            # the process-wide warning filters seen by other tests.
            warnings.simplefilter("always")
            _ = dataset_ops.Dataset.range(10).map(collecting_function)
        # NOTE(mrry): Python 3 prints other warnings in addition to the one we are
        # testing, so we search for the expected warning.
        self.assertGreaterEqual(len(w), 1)
        expected_message = (
            "Creating lookup tables inside a function passed to Dataset.map() is "
            "not supported.")
        found_warning = any(expected_message in str(warning) for warning in w)
        self.assertTrue(found_warning)

    def testNestedDatasetMap(self):
        """A map may return a `Dataset`, which later stages can transform."""
        # TODO(b/110122868): When iterators can yield a `tf.data.Dataset`, remove
        # the `get_single_element()` call.
        # NOTE(review): no `get_single_element()` call remains here; the TODO
        # above looks stale.
        nested = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]).map(
            dataset_ops.Dataset.from_tensor_slices)
        batched = nested.map(lambda ds: ds.batch(3))
        dataset = batched.flat_map(lambda x: x)

        self.assertDatasetProduces(dataset, expected_output=[[1.0, 2.0, 3.0]])

    def testReturnValueError(self):
        """Returning `None` from a map function raises a descriptive TypeError."""
        dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
        # `assertRaisesRegexp` is a deprecated alias that was removed from
        # `unittest` in Python 3.12; `assertRaisesRegex` is the supported name.
        with self.assertRaisesRegex(
                TypeError, r"Unsupported return value from function passed to "
                r"Dataset.map\(\): None."):
            _ = dataset.map(lambda x: None)

    def testBrokenFunctionErrorOnInitialization(self):
        """An invalid op in the map function surfaces as InvalidArgumentError.

        The map function builds a `Const` op whose `dtype` attr (int32)
        disagrees with its `value` attr (a float32 tensor). The expected error
        mentions the op name `BrokenConst`.
        """
        dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0])

        def broken_function(_):
            """A function deliberately designed to fail on instantiation."""
            value_attr = attr_value_pb2.AttrValue()
            value_attr.tensor.CopyFrom(
                tensor_util.make_tensor_proto(
                    [], dtype=dtypes.float32, shape=[0], verify_shape=False))
            dtype_attr = attr_value_pb2.AttrValue(
                type=dtypes.int32.as_datatype_enum)

            # Create a "Const" op with a `tf.float32` value and a `tf.int32` type
            # attr.
            broken_op = ops.get_default_graph().create_op(
                "Const", [], [dtypes.int32],
                attrs={"value": value_attr, "dtype": dtype_attr},
                name="BrokenConst")
            return broken_op.outputs[0]

        dataset = dataset.map(broken_function)
        self.assertDatasetProduces(
            dataset,
            expected_error=(errors.InvalidArgumentError, "BrokenConst"))

# pylint: disable=g-long-lambda

    @parameterized.named_parameters(
        ("Map", lambda dataset, func: dataset_ops.MapDataset(
            dataset, func, use_inter_op_parallelism=False)),
        ("ParallelMap", lambda dataset, func: dataset_ops.ParallelMapDataset(
            dataset,
            func,
            num_parallel_calls=1,
            use_inter_op_parallelism=False)),
    )
    def testNoInterOpParallelism(self, make_dataset_fn):
        """With inter-op parallelism disabled, all py_funcs share one thread.

        Args:
          make_dataset_fn: callable `(dataset, func) -> Dataset` that builds
            the map variant under test.
        """
        source = dataset_ops.Dataset.from_tensors(0)

        def _get_tid():
            return np.int64(threading.current_thread().ident)

        def _map_fn(_):
            # Ten independent py_func ops, each reporting its calling thread.
            return [script_ops.py_func(_get_tid, [], dtypes.int64)
                    for _ in range(10)]

        get_next = self.getNext(make_dataset_fn(source, _map_fn))

        thread_ids = self.evaluate(get_next())
        self.assertTrue(all(tid == thread_ids[0] for tid in thread_ids))


# pylint: enable=g-long-lambda

    @parameterized.named_parameters(
        ("SequentialIdentity", None, lambda x: x, None),
        ("SequentialReplicate", None, lambda x: (x, x), None),
        ("SequentialSwap", (None, None), lambda x, y: (y, x), None),
        ("SequentialProject", (None, None), lambda x, y: x, None),
        ("ParallelIdentity", None, lambda x: x, 10),
        ("ParallelReplicate", None, lambda x: (x, x), 10),
        ("ParallelSwap", (None, None), lambda x, y: (y, x), 10),
        ("ParallelProject", (None, None), lambda x, y: x, 10),
    )
    def testShortCircuit(self, structure, map_fn, num_parallel_calls):
        """Trivial map functions produce the expected element.

        Covers identity, replicate, swap and project functions, each with
        sequential and parallel map.
        """
        dataset = self.structuredDataset(structure).repeat().map(
            map_fn, num_parallel_calls=num_parallel_calls)
        get_next = self.getNext(dataset)

        element = self.evaluate(self.structuredElement(structure))
        # Tuple structures are passed to `map_fn` unpacked, as `map` does.
        expected = (map_fn(*element) if isinstance(structure, tuple)
                    else map_fn(element))
        self.assertEqual(expected, self.evaluate(get_next()))

    @parameterized.named_parameters(
        ("Sequential", None),
        ("Parallel", 10),
    )
    def testShortCircuitCapturedInput(self, num_parallel_calls):
        """A map that ignores its input and returns a captured variable."""
        captured_t = variables.Variable(42)

        def _return_captured(_):
            return captured_t

        dataset = self.structuredDataset(None).repeat().map(
            _return_captured, num_parallel_calls=num_parallel_calls)
        self.evaluate(variables.global_variables_initializer())
        get_next = self.getNext(dataset, requires_initialization=True)

        self.assertEqual(42, self.evaluate(get_next()))

    @parameterized.named_parameters(
        ("1", 1, 1),
        ("2", 10, 1),
        ("3", 10, 10),
        ("4", 100, 1),
        ("5", 100, 10),
        ("6", 100, 100),
    )
    def testSloppyInterleaveInOrder(self, num_elements, num_parallel_calls):
        """Elements released in order are produced in order.

        Args:
          num_elements: number of elements in the coordinated dataset.
          num_parallel_calls: parallelism passed to the dataset helper; the
            private threadpool gets one extra thread.
        """
        dataset, coordination_events = _make_coordinated_sloppy_dataset(
            num_elements, num_parallel_calls)
        options = dataset_ops.Options()
        options.experimental_threading = threading_options.ThreadingOptions()
        options.experimental_threading.private_threadpool_size = (
            num_parallel_calls + 1)
        get_next = self.getNext(dataset.with_options(options),
                                requires_initialization=True)
        # Presumably setting event `idx` lets element `idx` finish mapping
        # (see `_make_coordinated_sloppy_dataset`); values are squared.
        for idx in range(num_elements):
            coordination_events[idx].set()
            self.assertEqual(idx * idx, self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @parameterized.named_parameters(
        ("1", 10, 10),
        ("2", 100, 10),
        ("3", 100, 100),
    )
    def testSloppyInterleaveOutOfOrder(self, num_elements, num_parallel_calls):
        """Elements released out of order are produced in release order.

        Args:
          num_elements: number of elements in the coordinated dataset.
          num_parallel_calls: parallelism passed to the dataset helper; the
            private threadpool gets one extra thread.
        """
        dataset, coordination_events = _make_coordinated_sloppy_dataset(
            num_elements, num_parallel_calls)
        options = dataset_ops.Options()
        options.experimental_threading = threading_options.ThreadingOptions()
        options.experimental_threading.private_threadpool_size = (
            num_parallel_calls + 1)
        get_next = self.getNext(dataset.with_options(options),
                                requires_initialization=True)

        # Release order: ascending, except three adjacent pairs are swapped.
        release_order = list(range(num_elements))
        for pos in (1, 4, 7):
            release_order[pos], release_order[pos + 1] = (
                release_order[pos + 1], release_order[pos])

        for element in release_order:
            coordination_events[element].set()
            self.assertEqual(element * element, self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @parameterized.named_parameters(
        ("Map", None),
        ("ParallelMap", 12),
    )
    def testPreserveCardinality(self, num_parallel_calls):
        """With `DatasetV2`, StopIteration in a py_func is an error, not EOF.

        Args:
          num_parallel_calls: `None` for sequential map, else parallel map.
        """
        def py_fn(_):
            raise StopIteration()

        def _map_fn(x):
            return script_ops.py_func(py_fn, [x], dtypes.int64)

        dataset = dataset_ops.DatasetV2.from_tensors(0).map(
            _map_fn, num_parallel_calls=num_parallel_calls)
        get_next = self.getNext(dataset)
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(get_next())
Example #4
0
class MapDatasetTest(test.TestCase, parameterized.TestCase):
    def _buildMapDataset(self, components, count):
        """Squares each slice of the three `components`, then repeats.

        Args:
          components: three arrays, sliced along their first dimension.
          count: repeat count passed to `Dataset.repeat`.

        Returns:
          The resulting `Dataset`.
        """
        def _map_fn(x, y, z):
            return tuple(math_ops.square(t) for t in (x, y, z))

        sliced = dataset_ops.Dataset.from_tensor_slices(components)
        return sliced.map(_map_fn).repeat(count)

    def testMapDataset(self):
        """Test a dataset that maps a TF function across its input elements."""
        # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
        # RepeatDataset(count).
        components = (np.arange(7),
                      np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                      np.array(37.0) * np.arange(7))
        count = array_ops.placeholder(dtypes.int64, shape=[])

        dataset = self._buildMapDataset(components, count)
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.initializer
        get_next = iterator.get_next()

        # Each output component has the per-slice shape of its input.
        self.assertEqual([c.shape[1:] for c in components],
                         [t.shape for t in get_next])

        with self.test_session() as sess:
            # Test single-threaded access to the iterator.
            sess.run(init_op, feed_dict={count: 14})
            for _ in range(14):
                for i in range(7):
                    result = sess.run(get_next)
                    for component, result_component in zip(components, result):
                        self.assertAllEqual(component[i]**2, result_component)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

            # Test multi-threaded access to the same iterator.
            sess.run(init_op, feed_dict={count: 18})
            results = []

            def iterator_thread():
                while True:
                    try:
                        results.append(sess.run(get_next))
                    except errors.OutOfRangeError:
                        return

            threads = [
                self.checkedThread(target=iterator_thread) for _ in range(8)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            # All 7 * 18 elements must be delivered across the threads. Check
            # the total first: a missing element would otherwise fail with an
            # IndexError below, and a surplus one could go unchecked.
            self.assertEqual(7 * 18, len(results))

            # `results` will contain the same elements components**2
            # repeated 18 times, but in a non-deterministic order. Sort the
            # results, and assert that each element of components**2 is
            # produced 18 times.
            results.sort(key=lambda x: x[0])
            for i in range(7):
                for j in range(18):
                    for component, result_component in zip(
                            components, results[i * 18 + j]):
                        self.assertAllEqual(component[i]**2, result_component)

    def _buildParallelMapDataset(self, components, count, num_parallel_calls,
                                 output_buffer_size):
        """Squares each slice in parallel, prefetches, then repeats.

        Pipeline: from_tensor_slices -> map(square, num_parallel_calls)
        -> prefetch(output_buffer_size) -> repeat(count).

        Returns:
          The resulting `Dataset`.
        """
        def _map_fn(x, y, z):
            return tuple(math_ops.square(t) for t in (x, y, z))

        mapped = dataset_ops.Dataset.from_tensor_slices(components).map(
            _map_fn, num_parallel_calls=num_parallel_calls)
        return mapped.prefetch(output_buffer_size).repeat(count)

    def testParallelMapDataset(self):
        """Test a dataset that maps a TF function across its input elements."""
        # The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) ->
        # PrefetchDataset(output_buffer_size) -> RepeatDataset(count).
        components = (np.arange(7),
                      np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                      np.array(37.0) * np.arange(7))
        count = array_ops.placeholder(dtypes.int64, shape=[])
        num_parallel_calls = array_ops.placeholder(dtypes.int32, shape=[])
        output_buffer_size = array_ops.placeholder(dtypes.int64, shape=[])

        dataset = self._buildParallelMapDataset(components, count,
                                                num_parallel_calls,
                                                output_buffer_size)
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.initializer
        get_next = iterator.get_next()

        # Each output component has the per-slice shape of its input.
        self.assertEqual([c.shape[1:] for c in components],
                         [t.shape for t in get_next])

        with self.test_session() as sess:

            def do_test(num_parallel_calls_val, output_buffer_size_val):
                """Runs both access patterns for one parallelism/buffer pair."""
                # Test single-threaded access to the iterator.
                sess.run(init_op,
                         feed_dict={
                             count: 14,
                             num_parallel_calls: num_parallel_calls_val,
                             output_buffer_size: output_buffer_size_val
                         })
                for _ in range(14):
                    for i in range(7):
                        result = sess.run(get_next)
                        for component, result_component in zip(
                                components, result):
                            self.assertAllEqual(component[i]**2,
                                                result_component)
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

                # Test multi-threaded access to the same iterator.
                sess.run(init_op,
                         feed_dict={
                             count: 18,
                             num_parallel_calls: num_parallel_calls_val,
                             output_buffer_size: output_buffer_size_val
                         })
                results = []

                def iterator_thread():
                    while True:
                        try:
                            results.append(sess.run(get_next))
                        except errors.OutOfRangeError:
                            return

                threads = [
                    self.checkedThread(target=iterator_thread)
                    for _ in range(64)
                ]
                for t in threads:
                    t.start()
                for t in threads:
                    t.join()

                # All 7 * 18 elements must be delivered across the threads.
                # Check the total first: a missing element would otherwise
                # fail with an IndexError below, and a surplus one could go
                # unchecked.
                self.assertEqual(7 * 18, len(results))

                # `results` will contain the same elements components**2
                # repeated 18 times, but in a non-deterministic order. Sort the
                # results, and assert that each element of components**2 is
                # produced 18 times.
                results.sort(key=lambda x: x[0])
                for i in range(7):
                    for j in range(18):
                        for component, result_component in zip(
                                components, results[i * 18 + j]):
                            self.assertAllEqual(component[i]**2,
                                                result_component)

            for num_parallel_calls_val, output_buffer_size_val in [(1, 1),
                                                                   (1, 2),
                                                                   (2, 2),
                                                                   (2, 4),
                                                                   (8, 8),
                                                                   (8, 16)]:
                do_test(num_parallel_calls_val, output_buffer_size_val)

    def testImplicitDisposeParallelMapDataset(self):
        """A parallel-map pipeline abandoned mid-iteration is cleaned up.

        Only three elements are consumed before the session closes. Per the
        test name and the NOTE below, the check is that the in-flight parallel
        map and prefetch work is disposed of without hanging or crashing.
        """
        # Pipeline: TensorSliceDataset -> ParallelMapDataset(square_3) ->
        # PrefetchDataset(100) -> RepeatDataset(1000) -> PrefetchDataset(100).
        size = 1000
        components = (np.arange(size),
                      np.array([[1, 2, 3]]) * np.arange(size)[:, np.newaxis],
                      np.array(37.0) * np.arange(size))

        dataset = self._buildParallelMapDataset(
            components, count=1000, num_parallel_calls=100,
            output_buffer_size=100)
        # NOTE(mrry): Also test that the prefetching thread is cancelled correctly.
        dataset = dataset.prefetch(100)
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(init_op)
            for _ in range(3):
                sess.run(next_element)

    def testParallelMapUnspecifiedOutputSize(self):
        """Parallel map runs without an explicit output buffer size."""
        components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

        def _checked(x):
            return array_ops.check_numerics(x, "message")

        dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
            _checked, num_parallel_calls=2)
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.test_session() as sess:
            sess.run(init_op)
            # Only the three leading finite elements are requested; the NaN
            # (4th element) is never fetched.
            for _ in range(3):
                sess.run(get_next)

    def testParallelMapError(self):
        """An error in one parallel-map element does not end the iteration.

        The 4th element is NaN, so `check_numerics` fails for that element
        only. The elements before and after it must still be produced, in
        order, with their original values.
        """
        components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

        dataset = (dataset_ops.Dataset.from_tensor_slices(components).map(
            lambda x: array_ops.check_numerics(x, "message"),
            num_parallel_calls=2))
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.test_session() as sess:
            sess.run(init_op)
            # `check_numerics` passes finite inputs through unchanged. Check
            # the values themselves, not just that something was produced.
            for expected in components[:3]:
                self.assertEqual(expected, sess.run(get_next))
            # The 4th element is NaN, so `array_ops.check_numerics()` should fail.
            with self.assertRaises(errors.InvalidArgumentError):
                sess.run(get_next)
            self.assertEqual(components[4], sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    def testPrefetchError(self):
        """An error produced upstream of a prefetch is surfaced per element."""
        values = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
        source = dataset_ops.Dataset.from_tensor_slices(values)
        checked = source.map(lambda x: array_ops.check_numerics(x, "message"))
        iterator = checked.prefetch(2).make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(initializer)
            for _ in range(3):
                sess.run(next_element)
            # The 4th element is NaN, so `array_ops.check_numerics()` should fail.
            with self.assertRaises(errors.InvalidArgumentError):
                sess.run(next_element)
            # The 5th element is still delivered after the failure.
            sess.run(next_element)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    def testCaptureHashTable(self):
        """A map function can capture and query a hash table.

        Each sentence is split into words and every word is mapped to its id
        through the captured table; unknown words map to `default_val`.
        """
        # NOTE(mrry): We must use the V2 variants of `HashTable`
        # etc. because these produce a `tf.resource`-typed output that is
        # compatible with the in-graph function implementation.
        default_val = -1
        keys = constant_op.constant(["brain", "salad", "surgery"])
        values = constant_op.constant([0, 1, 2], dtypes.int64)
        table = lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer(keys, values), default_val)

        input_sentences = dataset_ops.Dataset.from_tensor_slices(
            ["brain brain tank salad surgery", "surgery brain"])

        iterator = (input_sentences.map(
            lambda x: string_ops.string_split([x]).values).map(
                table.lookup).make_initializable_iterator())
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.test_session() as sess:
            sess.run(table.init)
            sess.run(init_op)
            # Check the looked-up ids, not just that the pipeline runs.
            # "tank" is absent from the table, so it maps to `default_val`.
            self.assertAllEqual([0, 0, default_val, 1, 2], sess.run(get_next))
            self.assertAllEqual([2, 0], sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    def testCaptureQueue(self):
        """A map function can dequeue from a captured FIFO queue until it closes."""
        values = np.random.randint(100, size=[200])
        queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])
        enqueue_op = queue.enqueue_many(values)
        close_op = queue.close()

        endless = dataset_ops.Dataset.from_tensors(0).repeat(-1)
        dequeued = endless.map(lambda _: queue.dequeue())
        iterator = dequeued.make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(enqueue_op)
            sess.run(close_op)
            sess.run(initializer)
            for expected in values:
                self.assertEqual(expected, sess.run(next_element))
            # Once the closed queue is drained, the dataset ends.
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    def testCaptureSameResourceMultipleTimes(self):
        """Two handles to one shared queue can both be captured by a map function."""
        values = np.random.randint(100, size=[200])

        def _shared_queue():
            return data_flow_ops.FIFOQueue(200,
                                           dtypes.int64,
                                           shapes=[],
                                           shared_name="shared_queue")

        queue = _shared_queue()
        queue_2 = _shared_queue()

        enqueue_op = queue.enqueue_many(values)
        close_op = queue.close()

        endless = dataset_ops.Dataset.from_tensors(0).repeat(-1)
        pairs = endless.map(lambda _: (queue.dequeue(), queue_2.dequeue()))
        iterator = pairs.make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(enqueue_op)
            sess.run(close_op)
            sess.run(initializer)
            # Both handles drain the same queue, so each output holds two
            # consecutive inputs in unspecified order.
            for first, second in zip(values[0::2], values[1::2]):
                self.assertEqual(sorted([first, second]),
                                 sorted(sess.run(next_element)))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    def testCaptureVariable(self):
        """A map function can increment a captured resource variable."""
        counter_var = variable_scope.get_variable("counter", (),
                                                  dtypes.int32,
                                                  use_resource=True)
        ten_zeros = dataset_ops.Dataset.from_tensors(0).repeat(10)
        incremented = ten_zeros.map(lambda _: counter_var.assign_add(1))
        iterator = incremented.make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(counter_var.initializer)
            sess.run(initializer)
            for count in range(10):
                # The variable only advances when an element is produced.
                self.assertEqual(count, sess.run(counter_var))
                self.assertEqual(count + 1, sess.run(next_element))
            self.assertEqual(10, sess.run(counter_var))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)
            # Reaching the end of input must not trigger an extra increment.
            self.assertEqual(10, sess.run(counter_var))

    def testCaptureUninitializedVariableError(self):
        """Using a captured variable that was never initialized raises NotFoundError."""
        counter_var = variable_scope.get_variable("counter", (),
                                                  dtypes.int32,
                                                  use_resource=True)
        ten_zeros = dataset_ops.Dataset.from_tensors(0).repeat(10)
        incremented = ten_zeros.map(lambda _: counter_var.assign_add(1))
        iterator = incremented.make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            # `counter_var.initializer` is deliberately never run.
            sess.run(initializer)
            with self.assertRaises(errors.NotFoundError):
                sess.run(next_element)

    def testSeededStatefulOperatorIsProperlyStateful(self):
        """A seeded random op varies per element yet repeats across re-initialization."""
        ten_zeros = dataset_ops.Dataset.from_tensors(0).repeat(10)
        randoms = ten_zeros.map(
            lambda _: random_ops.random_uniform((), seed=11))
        iterator = randoms.batch(2).make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        def _drain(sess):
            # Collect every value until the iterator signals end of input.
            collected = []
            with self.assertRaises(errors.OutOfRangeError):
                while True:
                    collected.extend(sess.run(next_element))
            return collected

        with self.test_session() as sess:
            sess.run(initializer)
            first_pass = _drain(sess)
            self.assertEqual(10, len(first_pass))
            self.assertGreater(np.abs(np.diff(first_pass)).max(), 1e-6)

            sess.run(initializer)
            second_pass = _drain(sess)

            # Randomness is repeatable given same seed
            self.assertAllClose(first_pass, second_pass)

    def testMapDict(self):
        """Dictionary-valued elements can flow between chained map functions."""
        def _to_dict(x):
            return {"foo": x * 2, "bar": x**2}

        def _combine(d):
            return d["foo"] + d["bar"]

        combined = dataset_ops.Dataset.range(10).map(_to_dict).map(_combine)
        iterator = combined.make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(initializer)
            for i in range(10):
                self.assertEqual(i * 2 + i**2, sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    def testMapNamedtuple(self, count=10):
        """Mapping namedtuple elements gives the same data as mapping plain tuples."""
        labels = dataset_ops.Dataset.range(count)
        images = labels.map(lambda label: -label)
        tuples = dataset_ops.Dataset.zip((labels, images))

        # Re-wrap every (label, image) pair as an `Example` namedtuple.
        Example = namedtuple("Example", ["label", "image"])
        namedtuples = tuples.map(Example)

        def double_image_in_tuple(label, image):
            return label, 2 * image

        def double_image_in_namedtuple(example):
            return example._replace(image=2 * example.image)

        tuples = tuples.map(double_image_in_tuple)
        namedtuples = namedtuples.map(double_image_in_namedtuple)

        next_tuple = tuples.make_one_shot_iterator().get_next()
        next_namedtuple = namedtuples.make_one_shot_iterator().get_next()

        # Both pipelines must yield identical (label, 2 * image) pairs.
        with self.test_session() as sess:
            for i in range(count):
                tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple])
                self.assertEqual(tuple_, namedtuple_)
                self.assertEqual(tuple_, (i, -2 * i))

            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_namedtuple)

    def testUseStepContainerInMap(self):
        """A `map_fn` nested inside a dataset map function works."""
        row = np.arange(6)

        def _square_elementwise(elems):
            return functional_ops.map_fn(lambda x: x * x, elems)

        squared = dataset_ops.Dataset.from_tensors(row).map(_square_elementwise)
        iterator = squared.make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(initializer)
            self.assertAllEqual(row**2, sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    def testPrefetch(self):
        """Checks prefetch ordering and that the buffer size bounds read-ahead.

        `_map_py_func()` sets `ev` when it sees the value
        `set_event_during_invocation`, i.e. on its 6th invocation. Observing
        when `ev` becomes set shows how far ahead of the consumer the
        prefetcher has run.
        """
        ev = threading.Event()
        set_event_during_invocation = 5

        def _map_py_func(x):
            if x == set_event_during_invocation:
                ev.set()
            return x * x

        def _map_fn(x):
            return script_ops.py_func(_map_py_func, [x], x.dtype)

        buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
        squares = dataset_ops.Dataset.range(100).map(_map_fn)
        iterator = squares.prefetch(
            buffer_size_placeholder).make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        def _restart(sess, buffer_size):
            sess.run(initializer,
                     feed_dict={buffer_size_placeholder: buffer_size})

        def _expect_squares(sess, start, stop):
            for i in range(start, stop):
                self.assertEqual(i * i, sess.run(next_element))

        def _expect_end(sess):
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

        with self.test_session() as sess:
            # Simple test that prefetch yields the expected values in the
            # expected order.
            for buffer_size in [1, 10, 100, 1000]:
                _restart(sess, buffer_size)
                _expect_squares(sess, 0, 100)
                _expect_end(sess)

            # NOTE(mrry): We do not test with `buffer_size ==
            # set_event_during_invocation`, because we must consume at least
            # one element to start the prefetching.
            for buffer_size in range(1, set_event_during_invocation):
                consumed_before_event = (
                    set_event_during_invocation - buffer_size + 1)

                ev.clear()
                _restart(sess, buffer_size)
                for i in range(consumed_before_event):
                    self.assertFalse(ev.is_set())
                    self.assertEqual(i * i, sess.run(next_element))
                ev.wait()
                _expect_squares(sess, consumed_before_event, 100)
                _expect_end(sess)

    def testReturnList(self):
        """A map function returning a list yields tuple elements."""
        def _pair_with_constant(x):
            return [x, constant_op.constant(37.0)]

        pairs = dataset_ops.Dataset.range(10).map(_pair_with_constant)
        iterator = pairs.make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(initializer)
            for i in range(10):
                self.assertEqual((i, 37.0), sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    def testMultiOutputPyFunc(self):
        """A `py_func` with several outputs can serve as a map function."""
        # The `tf.py_func()` op returns a list of tensors for its outputs.
        def _py_pair(x):
            return x, np.array(37.0, dtype=np.float64)

        def _map_fn(x_tensor):
            return script_ops.py_func(_py_pair, [x_tensor],
                                      [dtypes.int64, dtypes.float64])

        pairs = dataset_ops.Dataset.range(10).map(_map_fn)
        iterator = pairs.make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(initializer)
            for i in range(10):
                self.assertEqual((i, 37.0), sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    def assertSparseValuesEqual(self, a, b):
        """Asserts that two sparse values have equal indices, values and shape."""
        for component in ("indices", "values", "dense_shape"):
            self.assertAllEqual(getattr(a, component), getattr(b, component))

    def testSparse(self):
        """A map function may return a `SparseTensorValue`."""
        def _sparse(i):
            return sparse_tensor.SparseTensorValue(
                indices=np.array([[0, 0]]),
                values=(i * np.array([1])),
                dense_shape=np.array([1, 1]))

        sparse_ds = dataset_ops.Dataset.range(10).map(_sparse)
        iterator = sparse_ds.make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(initializer)
            for i in range(10):
                produced = sess.run(next_element)
                self.assertIsInstance(produced,
                                      sparse_tensor.SparseTensorValue)
                self.assertSparseValuesEqual(produced, _sparse(i))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    def testSparseChain(self):
        """Sparse elements produced by one map reach the next map as sparse tensors."""
        def _sparse(i):
            return sparse_tensor.SparseTensorValue(
                indices=np.array([[0, 0]]),
                values=(i * np.array([1])),
                dense_shape=np.array([1, 1]))

        def _check(i):
            self.assertTrue(sparse_tensor.is_sparse(i))
            return sparse_ops.sparse_concat(0, [i, i])

        chained = dataset_ops.Dataset.range(10).map(_sparse).map(_check)
        iterator = chained.make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(initializer)
            for i in range(10):
                produced = sess.run(next_element)
                self.assertIsInstance(produced,
                                      sparse_tensor.SparseTensorValue)
                expected = _check(_sparse(i)).eval()
                self.assertSparseValuesEqual(produced, expected)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    def testParallelMapOutOfRangeError(self):
        """`StopIteration` inside a parallel `py_func` ends the dataset early."""
        def _stop_at_100(i):
            if i == 100:
                raise StopIteration()
            return i

        def _map_fn(x):
            return script_ops.py_func(_stop_at_100, [x], dtypes.int64)

        mapped = dataset_ops.Dataset.range(105).map(_map_fn,
                                                    num_parallel_calls=2)
        iterator = mapped.make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(initializer)
            for expected in range(100):
                self.assertEqual(expected, sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    def testConstantOutput(self):
        """Python constants returned from a map function become tensor components."""
        def _with_constants(x):
            return [x, "hello", 10]

        mapped = dataset_ops.Dataset.range(10).map(_with_constants)
        iterator = mapped.make_initializable_iterator()
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(initializer)
            for i in range(10):
                self.assertEqual((i, b"hello", 10), sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    def testWarnOnLookupTable(self):
        """Creating a lookup table inside a map function emits a warning."""
        def collecting_function(x):
            _ = lookup_ops.HashTable(lookup_ops.KeyValueTensorInitializer([],
                                                                          []),
                                     0.0,
                                     name="t1")
            return x

        expected_warning = ("Creating lookup tables inside a function passed "
                            "to Dataset.map() is not supported.")

        with warnings.catch_warnings(record=True) as w:
            # Install the filter inside the context manager so the previous
            # filter state is restored on exit. Calling it outside would
            # permanently change warning filters for every later test.
            warnings.simplefilter("always")
            _ = dataset_ops.Dataset.range(10).map(collecting_function)
        # NOTE(mrry): Python 3 prints other warnings in addition to the one we are
        # testing, so we search for the expected warning.
        self.assertGreaterEqual(len(w), 1)
        self.assertTrue(
            any(expected_warning in str(warning) for warning in w))

    def testNestedDatasetError(self):
        """Returning a dataset from a map function is rejected."""
        expected_message = (r"The Dataset.map\(\) transformation does not "
                            "currently support nested datasets as outputs.")
        dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
        with self.assertRaisesRegexp(NotImplementedError, expected_message):
            _ = dataset.map(dataset_ops.Dataset.from_tensor_slices)

    def testReturnValueError(self):
        """A map function returning `None` raises a `TypeError`."""
        expected_message = (r"Unsupported return value from function passed to "
                            r"Dataset.map\(\): None.")
        dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
        with self.assertRaisesRegexp(TypeError, expected_message):
            _ = dataset.map(lambda x: None)

    def testBrokenFunctionErrorOnInitialization(self):
        """A map function that cannot be instantiated fails at iterator init."""
        def broken_function(_):
            """A function deliberately designed to fail on instantiation."""
            # The "value" attr holds an empty float32 tensor...
            tensor_value = attr_value_pb2.AttrValue()
            tensor_value.tensor.CopyFrom(
                tensor_util.make_tensor_proto([],
                                              dtype=dtypes.float32,
                                              shape=[0],
                                              verify_shape=False))
            # ...while the "dtype" attr claims int32.
            dtype_value = attr_value_pb2.AttrValue(
                type=dtypes.int32.as_datatype_enum)

            graph = ops.get_default_graph()
            broken_op = graph.create_op(
                "Const", [], [dtypes.int32],
                attrs={
                    "value": tensor_value,
                    "dtype": dtype_value
                },
                name="BrokenConst")
            return broken_op.outputs[0]

        dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0])
        iterator = dataset.map(broken_function).make_initializable_iterator()

        with self.test_session() as sess:
            with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                         "BrokenConst"):
                sess.run(iterator.initializer)

# pylint: disable=g-long-lambda

    @parameterized.named_parameters(
        ("Map", lambda dataset, func: dataset_ops.MapDataset(
            dataset, func, use_inter_op_parallelism=False)),
        ("ParallelMap", lambda dataset, func: dataset_ops.ParallelMapDataset(
            dataset,
            func,
            num_parallel_calls=1,
            use_inter_op_parallelism=False)),
    )
    def testNoInterOpParallelism(self, make_dataset_fn):
        """With inter-op parallelism off, one map call's ops share a single thread."""
        def _get_tid():
            return np.int64(threading.current_thread().ident)

        def _map_fn(_):
            # Ten independent py_funcs, each reporting the thread it ran on.
            return [script_ops.py_func(_get_tid, [], dtypes.int64)
                    for _ in range(10)]

        source = dataset_ops.Dataset.from_tensors(0)
        next_element = make_dataset_fn(
            source, _map_fn).make_one_shot_iterator().get_next()

        with self.test_session() as sess:
            tids = sess.run(next_element)
            first_tid = tids[0]
            self.assertTrue(all(first_tid == tid for tid in tids))
Example #5
0
def make_batched_features_dataset_multi_task(file_pattern,
                                             batch_size,
                                             features,
                                             reader=core_readers.TFRecordDataset,
                                             label_key=None,
                                             weight_key=None,
                                             reader_args=None,
                                             num_epochs=None,
                                             shuffle=True,
                                             shuffle_buffer_size=10000,
                                             shuffle_seed=None,
                                             prefetch_buffer_size=optimization.AUTOTUNE,
                                             reader_num_threads=32,
                                             parser_num_threads=32,
                                             sloppy_ordering=True,
                                             drop_final_batch=False,
                                             num_tasks=5):
    """Returns a `Dataset` of feature dictionaries from `Example` protos.

    When `label_key` is given, the label is removed from the feature dict and
    replicated `num_tasks` times, e.g. to feed every head of a multi-task
    model with the same target.

    Args:
      file_pattern: Glob pattern of the files to read.
      batch_size: Number of consecutive records combined into one batch.
      features: `dict` mapping feature keys to `FixedLenFeature` or
        `VarLenFeature` values; passed to `parse_example_dataset`.
      reader: Callable invoked as `reader(filename, *reader_args)` that returns
        a `Dataset` of serialized `Example` records.
      label_key: (Optional.) Key of the label feature. When set, the label is
        popped from the feature dict and returned as a `num_tasks`-tuple.
      weight_key: (Optional.) Key of the weight feature. Requires `label_key`.
        When set, the weight is popped from the feature dict and returned as
        the third element of each dataset element.
      reader_args: (Optional.) Extra positional arguments passed to `reader`.
      num_epochs: Number of passes over the data; `None` repeats forever.
      shuffle: Whether to shuffle the file order and the records.
      shuffle_buffer_size: Buffer size used for record shuffling.
      shuffle_seed: (Optional.) Shuffle seed; defaults to the current time in
        whole seconds.
      prefetch_buffer_size: Number of batches to prefetch.
      reader_num_threads: Number of files read concurrently (interleave
        cycle length).
      parser_num_threads: Parallelism used for `Example` parsing.
      sloppy_ordering: If `True`, interleaved reading may produce elements in
        a non-deterministic order.
      drop_final_batch: If `True`, drop a final batch smaller than
        `batch_size`. Remainders are also dropped when `num_epochs` is `None`.
      num_tasks: Number of times the label is replicated. Defaults to 5.

    Returns:
      A dataset whose elements are a `dict` of features; a
      `(features, labels)` tuple when `label_key` is set; or a
      `(features, labels, weights)` tuple when `weight_key` is also set.
      Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.

    Raises:
      ValueError: If `label_key` or `weight_key` is not a key of `features`,
        if `weight_key` is given without `label_key`, if both keys are equal,
        if `num_tasks < 1`, or if no file matches `file_pattern`.
    """
    # Validate up front so misconfiguration fails before building the pipeline.
    # (Explicit errors instead of `assert`, which is stripped under `-O`.)
    if label_key and label_key not in features:
        raise ValueError(
            "The `label_key` provided (%r) must be one of the `features` keys."
            % label_key)
    if weight_key:
        if not label_key:
            raise ValueError(
                "A `weight_key` (%r) requires a `label_key`." % weight_key)
        if weight_key not in features:
            raise ValueError(
                "The `weight_key` provided (%r) must be one of the `features` "
                "keys." % weight_key)
        if label_key == weight_key:
            raise ValueError(
                "`label_key` and `weight_key` must differ (both are %r)."
                % label_key)
    if num_tasks < 1:
        raise ValueError("`num_tasks` must be at least 1, got %r." % num_tasks)

    if shuffle_seed is None:
        shuffle_seed = int(time.time())

    filenames = list(gfile.Glob(file_pattern))
    if not filenames:
        raise ValueError("No files match %s." % file_pattern)
    dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
    if shuffle:
        dataset = dataset.shuffle(len(filenames), shuffle_seed)

    # Read `Example` records from files as tensor objects.
    if reader_args is None:
        reader_args = []

    # Read files sequentially (if reader_num_threads=1) or in parallel.
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(
            lambda filename: reader(filename, *reader_args),
            cycle_length=reader_num_threads,
            block_length=200,
            sloppy=sloppy_ordering))

    # Extract values if the `Example` tensors are stored as key-value tuples.
    if dataset_ops.get_legacy_output_types(dataset) == (
            dtypes.string, dtypes.string):
        dataset = dataset_ops.MapDataset(
            dataset, lambda _, v: v, use_inter_op_parallelism=True)

    # Record-level shuffling honours `shuffle`; repetition applies either way.
    if shuffle:
        dataset = dataset.apply(
            shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
                                           shuffle_seed))
    elif num_epochs != 1:
        dataset = dataset.repeat(num_epochs)

    # With infinite repetition there is never a partial batch; dropping the
    # remainder additionally gives the batch dimension a static size.
    dataset = dataset.batch(
        batch_size, drop_remainder=drop_final_batch or num_epochs is None)

    # Parse `Example` tensors to a dictionary of `Feature` tensors.
    dataset = dataset.apply(
        parsing_ops.parse_example_dataset(
            features, num_parallel_calls=parser_num_threads))

    if weight_key:
        def _split_label_and_weight(parsed):
            # `parsed` is returned after both keys have been popped from it.
            labels = tuple([parsed.pop(label_key)] * num_tasks)
            weights = parsed.pop(weight_key)
            return parsed, labels, weights

        dataset = dataset.map(_split_label_and_weight)
    elif label_key:
        def _split_label(parsed):
            return parsed, tuple([parsed.pop(label_key)] * num_tasks)

        dataset = dataset.map(_split_label)

    return dataset.prefetch(prefetch_buffer_size)
Example #6
0
def sample_from_datasets_v2(datasets,
                            weights=None,
                            seed=None,
                            stop_on_empty_dataset=False):
  """Samples elements at random from the datasets in `datasets`.

  Creates a dataset by interleaving elements of `datasets` with `weight[i]`
  probability of picking an element from dataset `i`. Sampling is done without
  replacement. For example, suppose we have 2 datasets:

  ```python
  dataset1 = tf.data.Dataset.range(0, 3)
  dataset2 = tf.data.Dataset.range(100, 103)
  ```

  Suppose also that we sample from these 2 datasets with the following weights:

  ```python
  sample_dataset = tf.data.experimental.sample_from_datasets(
      [dataset1, dataset2], weights=[0.5, 0.5])
  ```

  One possible outcome of elements in sample_dataset is:

  ```
  print(list(sample_dataset.as_numpy_iterator()))
  # [100, 0, 1, 101, 2, 102]
  ```

  Args:
    datasets: A non-empty list of `tf.data.Dataset` objects with compatible
      structure.
    weights: (Optional.) A list of `len(datasets)` floating-point values where
      `weights[i]` represents the probability with which an element should be
      sampled from `datasets[i]`, or a `tf.data.Dataset` object where each
      element is such a list. Defaults to a uniform distribution across
      `datasets`.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
      seed that will be used to create the distribution. See
      `tf.random.set_seed` for behavior.
    stop_on_empty_dataset: If `True`, sampling stops if it encounters an empty
      dataset. If `False`, it skips empty datasets. It is recommended to set it
      to `True`. Otherwise, the distribution of samples starts off as the user
      intends, but may change as input datasets become empty. This can be
      difficult to detect since the dataset starts off looking correct. Default
      to `False` for backward compatibility.

  Returns:
    A dataset that interleaves elements from `datasets` at random, according to
    `weights` if provided, otherwise with uniform probability.

  Raises:
    TypeError: If the `datasets` or `weights` arguments have the wrong type.
    ValueError: If `datasets` is empty, or if the `weights` argument is
      specified and does not match the length of the `datasets` element.
  """
  num_datasets = len(datasets)
  # Fail fast: with no inputs the multinomial below would be built over zero
  # categories and only fail later, far from the caller.
  if num_datasets == 0:
    raise ValueError("`datasets` must be a non-empty list of datasets.")

  if not isinstance(weights, dataset_ops.DatasetV2):
    if weights is None:
      # Select inputs with uniform probability.
      logits = [[1.0] * num_datasets]

    else:
      # Use the given `weights` as the probability of choosing the respective
      # input.
      weights = ops.convert_to_tensor(weights, name="weights")
      if weights.dtype not in (dtypes.float32, dtypes.float64):
        raise TypeError("`weights` must be convertible to a tensor of "
                        "`tf.float32` or `tf.float64` elements.")
      if not weights.shape.is_compatible_with([num_datasets]):
        raise ValueError(
            "`weights` must be a vector of length `len(datasets)`.")

      # The `stateless_multinomial()` op expects log-probabilities, as opposed
      # to weights.
      logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)

    # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
    # is a `Dataset`, it is possible that evaluating it has a side effect the
    # user depends on.
    if num_datasets == 1:
      return datasets[0]

    def select_dataset_constant_logits(seed):
      return array_ops.squeeze(
          gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
          axis=[0, 1])

    # Pairs of consecutive random int64 values form the stateless seed.
    selector_input = dataset_ops.MapDataset(
        random_ops.RandomDataset(seed).batch(2),
        select_dataset_constant_logits,
        use_inter_op_parallelism=False)

  else:
    # Use each element of the given `weights` dataset as the probability of
    # choosing the respective input.
    #
    # The `stateless_multinomial()` op expects log-probabilities, as opposed to
    # weights.
    logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))

    def select_dataset_varying_logits(logits, seed):
      return array_ops.squeeze(
          gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
          axis=[0, 1])

    logits_and_seeds = dataset_ops.Dataset.zip(
        (logits_ds, random_ops.RandomDataset(seed).batch(2)))
    selector_input = dataset_ops.MapDataset(
        logits_and_seeds,
        select_dataset_varying_logits,
        use_inter_op_parallelism=False)

  return _DirectedInterleaveDataset(selector_input, datasets,
                                    stop_on_empty_dataset)
Example #7
0
def sample_from_datasets_v2(datasets, weights=None, seed=None):
    """Samples elements at random from the datasets in `datasets`.

    Args:
      datasets: A non-empty list of `tf.data.Dataset` objects with compatible
        structure.
      weights: (Optional.) A list of `len(datasets)` floating-point values where
        `weights[i]` represents the probability with which an element should be
        sampled from `datasets[i]`, or a `tf.data.Dataset` object where each
        element is such a list. Defaults to a uniform distribution across
        `datasets`.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        random seed that will be used to create the distribution. See
        `tf.random.set_seed` for behavior.

    Returns:
      A dataset that interleaves elements from `datasets` at random, according
      to `weights` if provided, otherwise with uniform probability. When
      `datasets` has exactly one element and `weights` is not a
      `tf.data.Dataset`, that single dataset is returned as-is.

    Raises:
      TypeError: If the `datasets` or `weights` arguments have the wrong type.
      ValueError: If `datasets` is empty, or if the `weights` argument is
        specified and does not match the length of the `datasets` element.
    """
    num_datasets = len(datasets)
    if not num_datasets:
        # Fail fast: with no inputs there is nothing to sample from, and the
        # multinomial selector below would be built with zero classes.
        raise ValueError("`datasets` must contain at least one dataset.")

    if not isinstance(weights, dataset_ops.Dataset):
        if weights is None:
            # Equal logits yield a uniform distribution over the inputs.
            logits = [[1.0] * num_datasets]

        else:
            # Use the given `weights` as the probability of choosing the
            # respective input.
            weights = ops.convert_to_tensor(weights, name="weights")
            if weights.dtype not in (dtypes.float32, dtypes.float64):
                raise TypeError("`weights` must be convertible to a tensor of "
                                "`tf.float32` or `tf.float64` elements.")
            if not weights.shape.is_compatible_with([num_datasets]):
                raise ValueError(
                    "`weights` must be a vector of length `len(datasets)`.")

            # The `stateless_multinomial()` op expects log-probabilities, as
            # opposed to weights; add a leading batch dimension of 1.
            logits = array_ops.expand_dims(
                math_ops.log(weights, name="logits"), 0)

        # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When
        # it is a `Dataset`, it is possible that evaluating it has a side effect
        # the user depends on.
        if num_datasets == 1:
            return datasets[0]

        def select_dataset_constant_logits(seed):
            # Draw one index from a single-row logits matrix; squeeze the [1, 1]
            # result down to a scalar index.
            return array_ops.squeeze(
                gen_stateless_random_ops.stateless_multinomial(logits,
                                                               1,
                                                               seed=seed),
                axis=[0, 1])

        # Stateless random ops take a shape-[2] seed, so consecutive values
        # from the seeded `RandomDataset` are paired up with `batch(2)`.
        selector_input = dataset_ops.MapDataset(
            random_ops.RandomDataset(seed).batch(2),
            select_dataset_constant_logits,
            use_inter_op_parallelism=False)

    else:
        # Use each element of the given `weights` dataset as the probability of
        # choosing the respective input.

        # The `stateless_multinomial()` op expects log-probabilities, as opposed
        # to weights. `*p` packs the element's components into a tuple, which
        # presumably supplies the leading batch dimension of 1 the op needs.
        logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))

        def select_dataset_varying_logits(logits, seed):
            # Same selection as the constant-logits case, but with per-element
            # logits drawn from `weights`.
            return array_ops.squeeze(
                gen_stateless_random_ops.stateless_multinomial(logits,
                                                               1,
                                                               seed=seed),
                axis=[0, 1])

        # Pair each logits element with a shape-[2] stateless seed.
        logits_and_seeds = dataset_ops.Dataset.zip(
            (logits_ds, random_ops.RandomDataset(seed).batch(2)))
        selector_input = dataset_ops.MapDataset(logits_and_seeds,
                                                select_dataset_varying_logits,
                                                use_inter_op_parallelism=False)

    return _DirectedInterleaveDataset(selector_input, datasets)