Example #1
0
    def _init_func():
      """Sets up an anonymous iterator over the unwrapped dataset variant.

      The iterator resource is bound to the dataset with `make_iterator`, and
      the string handle is created under a control dependency on that op.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      structure_kwargs = self._input_dataset._flat_structure  # pylint: disable=protected-access
      iterator_resource = gen_dataset_ops.anonymous_iterator(**structure_kwargs)
      bind_iterator = gen_dataset_ops.make_iterator(unwrapped, iterator_resource)
      with ops.control_dependencies([bind_iterator]):
        return gen_dataset_ops.iterator_to_string_handle(iterator_resource)
  def testBasic(self):
    """Checks that a wrap/unwrap round trip preserves a range dataset.

    The variant tensor of `Dataset.range(100)` is passed through
    `wrap_dataset_variant` and then `unwrap_dataset_variant`; the dataset
    rebuilt from the result must yield 0..99 in order.
    """
    source = dataset_ops.Dataset.range(100)
    variant = source._variant_tensor  # pylint: disable=protected-access

    round_tripped = gen_dataset_ops.unwrap_dataset_variant(
        gen_dataset_ops.wrap_dataset_variant(variant))

    # pylint: disable=protected-access
    restored = dataset_ops._VariantDataset(round_tripped,
                                           source._element_structure)
    # pylint: enable=protected-access
    next_element = self.getNext(restored, requires_initialization=True)
    for expected in range(100):
      self.assertEqual(expected, self.evaluate(next_element()))
Example #3
0
    def _init_func():
      """Sets up an anonymous iterator over the unwrapped dataset variant.

      The iterator's structure arguments come from
      `dataset_ops.flat_structure(self._input_dataset)`. The string handle is
      created under a control dependency on the `make_iterator` op.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      structure_kwargs = dataset_ops.flat_structure(self._input_dataset)
      iterator_resource = gen_dataset_ops.anonymous_iterator(**structure_kwargs)
      bind_iterator = gen_dataset_ops.make_iterator(unwrapped, iterator_resource)
      with ops.control_dependencies([bind_iterator]):
        return gen_dataset_ops.iterator_to_string_handle(iterator_resource)
  def DISABLED_testBasic(self):
    """Checks that a wrap/unwrap round trip preserves a range dataset.

    The variant tensor of `Dataset.range(100)` is passed through
    `wrap_dataset_variant` and then `unwrap_dataset_variant`; the dataset
    rebuilt from the result (using the source's `element_spec`) must yield
    0..99 in order.
    """
    source = dataset_ops.Dataset.range(100)
    variant = source._variant_tensor  # pylint: disable=protected-access

    round_tripped = gen_dataset_ops.unwrap_dataset_variant(
        gen_dataset_ops.wrap_dataset_variant(variant))

    restored = dataset_ops._VariantDataset(  # pylint: disable=protected-access
        round_tripped, source.element_spec)
    next_element = self.getNext(restored, requires_initialization=True)
    for expected in range(100):
      self.assertEqual(expected, self.evaluate(next_element()))
Example #5
0
        def _init_func():
            """Sets up an anonymous iterator over the unwrapped dataset variant.

            The iterator is created with this object's flat output types and
            shapes. The string handle is created under a control dependency
            on the `make_iterator` op.

            Returns:
              A `string` tensor that encapsulates the iterator created.
            """
            unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
            iterator_resource = gen_dataset_ops.anonymous_iterator(
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            bind_iterator = gen_dataset_ops.make_iterator(
                unwrapped, iterator_resource)
            with ops.control_dependencies([bind_iterator]):
                return gen_dataset_ops.iterator_to_string_handle(
                    iterator_resource)
    def testBasic(self):
        """Checks that a wrap/unwrap round trip preserves a range dataset.

        The variant tensor of `Dataset.range(100)` is passed through
        `wrap_dataset_variant` and then `unwrap_dataset_variant`. An
        initializable iterator over the rebuilt dataset must yield 0..99
        in order.
        """
        source = dataset_ops.Dataset.range(100)
        variant = source._as_variant_tensor()  # pylint: disable=protected-access

        round_tripped = gen_dataset_ops.unwrap_dataset_variant(
            gen_dataset_ops.wrap_dataset_variant(variant))

        # pylint: disable=protected-access
        restored = dataset_ops._VariantDataset(round_tripped,
                                               source._element_structure)
        # pylint: enable=protected-access
        iterator = dataset_ops.make_initializable_iterator(restored)
        next_element = iterator.get_next()

        with self.cached_session():
            # The iterator must be initialized before the first element is read.
            self.evaluate(iterator.initializer)
            for expected in range(100):
                self.assertEqual(expected, self.evaluate(next_element))
  def testSkipEagerGPU(self):
    """Checks wrap/unwrap when the wrapped variant passes through "/gpu:0".

    The wrapped variant is copied with an identity op created under
    `ops.device("/gpu:0")` before being unwrapped. An initializable iterator
    over the rebuilt dataset must yield 0..99 in order.
    """
    source = dataset_ops.Dataset.range(100)
    variant = source._variant_tensor  # pylint: disable=protected-access
    wrapped = gen_dataset_ops.wrap_dataset_variant(variant)

    with ops.device("/gpu:0"):
      wrapped_on_gpu = array_ops.identity(wrapped)

    unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrapped_on_gpu)
    # pylint: disable=protected-access
    restored = dataset_ops._VariantDataset(unwrapped,
                                           source._element_structure)
    # pylint: enable=protected-access
    iterator = dataset_ops.make_initializable_iterator(restored)
    next_element = iterator.get_next()

    with self.cached_session():
      # The iterator must be initialized before the first element is read.
      self.evaluate(iterator.initializer)
      for expected in range(100):
        self.assertEqual(expected, self.evaluate(next_element))
Example #8
0
    def testGPU(self):
        """Checks wrap/unwrap when the wrapped variant passes through "/gpu:0".

        The wrapped variant is copied with an identity op created under
        `ops.device("/gpu:0")` before being unwrapped. An initializable
        iterator over the rebuilt dataset (using the source's `element_spec`)
        must yield 0..99 in order.
        """
        source = dataset_ops.Dataset.range(100)
        variant = source._variant_tensor  # pylint: disable=protected-access
        wrapped = gen_dataset_ops.wrap_dataset_variant(variant)

        with ops.device("/gpu:0"):
            wrapped_on_gpu = array_ops.identity(wrapped)

        unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrapped_on_gpu)
        restored = dataset_ops._VariantDataset(  # pylint: disable=protected-access
            unwrapped, source.element_spec)
        iterator = dataset_ops.make_initializable_iterator(restored)
        next_element = iterator.get_next()

        with self.cached_session():
            # The iterator must be initialized before the first element is read.
            self.evaluate(iterator.initializer)
            for expected in range(100):
                self.assertEqual(expected, self.evaluate(next_element))