def _init_func():
  """Sets up an anonymous iterator over the unwrapped input dataset.

  `self` and `wrap_ds_variant` are taken from the enclosing scope.

  Returns:
    The string handle produced for the iterator resource. It is created under
    a control dependency on the op that binds the dataset to the iterator.
  """
  unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
  iterator_resource = gen_dataset_ops.anonymous_iterator(
      **self._input_dataset._flat_structure)  # pylint: disable=protected-access
  # Bind the dataset to the iterator; the handle below must wait on this op.
  make_iterator_op = gen_dataset_ops.make_iterator(
      unwrapped_variant, iterator_resource)
  with ops.control_dependencies([make_iterator_op]):
    return gen_dataset_ops.iterator_to_string_handle(iterator_resource)
def testBasic(self):
  """Wrapping then unwrapping a range dataset variant still yields 0..99."""
  source_ds = dataset_ops.Dataset.range(100)
  original_variant = source_ds._variant_tensor  # pylint: disable=protected-access

  # Round-trip the variant through the wrap / unwrap op pair.
  wrapped = gen_dataset_ops.wrap_dataset_variant(original_variant)
  unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrapped)

  round_tripped_ds = dataset_ops._VariantDataset(
      unwrapped, source_ds._element_structure)
  next_element = self.getNext(round_tripped_ds, requires_initialization=True)

  for expected in range(100):
    actual = self.evaluate(next_element())
    self.assertEqual(expected, actual)
def _init_func():
  """Sets up an anonymous iterator over the unwrapped input dataset.

  `self` and `wrap_ds_variant` are taken from the enclosing scope.

  Returns:
    The string handle produced for the iterator resource. It is created under
    a control dependency on the op that binds the dataset to the iterator.
  """
  unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
  structure_kwargs = dataset_ops.flat_structure(self._input_dataset)
  iterator_resource = gen_dataset_ops.anonymous_iterator(**structure_kwargs)
  # Bind the dataset to the iterator; the handle below must wait on this op.
  make_iterator_op = gen_dataset_ops.make_iterator(
      unwrapped_variant, iterator_resource)
  with ops.control_dependencies([make_iterator_op]):
    return gen_dataset_ops.iterator_to_string_handle(iterator_resource)
def DISABLED_testBasic(self):
  """Wrapping then unwrapping a range dataset variant still yields 0..99."""
  source_ds = dataset_ops.Dataset.range(100)
  original_variant = source_ds._variant_tensor  # pylint: disable=protected-access

  # Round-trip the variant through the wrap / unwrap op pair.
  wrapped = gen_dataset_ops.wrap_dataset_variant(original_variant)
  unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrapped)

  round_tripped_ds = dataset_ops._VariantDataset(
      unwrapped, source_ds.element_spec)
  next_element = self.getNext(round_tripped_ds, requires_initialization=True)

  for expected in range(100):
    actual = self.evaluate(next_element())
    self.assertEqual(expected, actual)
def _init_func():
  """Sets up an anonymous iterator over the unwrapped input dataset.

  `self` and `wrap_ds_variant` are taken from the enclosing scope.

  Returns:
    The string handle produced for the iterator resource. It is created under
    a control dependency on the op that binds the dataset to the iterator.
  """
  unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
  flat_types = self._flat_output_types
  flat_shapes = self._flat_output_shapes
  iterator_resource = gen_dataset_ops.anonymous_iterator(
      output_types=flat_types, output_shapes=flat_shapes)
  # Bind the dataset to the iterator; the handle below must wait on this op.
  make_iterator_op = gen_dataset_ops.make_iterator(
      unwrapped_variant, iterator_resource)
  with ops.control_dependencies([make_iterator_op]):
    return gen_dataset_ops.iterator_to_string_handle(iterator_resource)
def testBasic(self):
  """Wrapping then unwrapping a range dataset variant still yields 0..99."""
  source_ds = dataset_ops.Dataset.range(100)
  original_variant = source_ds._as_variant_tensor()  # pylint: disable=protected-access

  # Round-trip the variant through the wrap / unwrap op pair.
  wrapped = gen_dataset_ops.wrap_dataset_variant(original_variant)
  unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrapped)

  round_tripped_ds = dataset_ops._VariantDataset(
      unwrapped, source_ds._element_structure)
  ds_iterator = dataset_ops.make_initializable_iterator(round_tripped_ds)
  next_element = ds_iterator.get_next()

  with self.cached_session():
    self.evaluate(ds_iterator.initializer)
    for expected in range(100):
      actual = self.evaluate(next_element)
      self.assertEqual(expected, actual)
def testSkipEagerGPU(self):
  """A wrapped variant passed through a `/gpu:0` identity still yields 0..99."""
  source_ds = dataset_ops.Dataset.range(100)
  original_variant = source_ds._variant_tensor  # pylint: disable=protected-access
  wrapped = gen_dataset_ops.wrap_dataset_variant(original_variant)

  # Only the identity op is created under the GPU device scope.
  with ops.device("/gpu:0"):
    wrapped_on_gpu = array_ops.identity(wrapped)

  unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrapped_on_gpu)
  round_tripped_ds = dataset_ops._VariantDataset(
      unwrapped, source_ds._element_structure)
  ds_iterator = dataset_ops.make_initializable_iterator(round_tripped_ds)
  next_element = ds_iterator.get_next()

  with self.cached_session():
    self.evaluate(ds_iterator.initializer)
    for expected in range(100):
      actual = self.evaluate(next_element)
      self.assertEqual(expected, actual)
def testGPU(self):
  """A wrapped variant passed through a `/gpu:0` identity still yields 0..99."""
  source_ds = dataset_ops.Dataset.range(100)
  original_variant = source_ds._variant_tensor  # pylint: disable=protected-access
  wrapped = gen_dataset_ops.wrap_dataset_variant(original_variant)

  # Only the identity op is created under the GPU device scope.
  with ops.device("/gpu:0"):
    wrapped_on_gpu = array_ops.identity(wrapped)

  unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrapped_on_gpu)
  round_tripped_ds = dataset_ops._VariantDataset(
      unwrapped, source_ds.element_spec)
  ds_iterator = dataset_ops.make_initializable_iterator(round_tripped_ds)
  next_element = ds_iterator.get_next()

  with self.cached_session():
    self.evaluate(ds_iterator.initializer)
    for expected in range(100):
      actual = self.evaluate(next_element)
      self.assertEqual(expected, actual)