def _clone_dataset(dataset):
  """Creates a copy of `dataset` built on cloned variant tensor ops.

  Args:
    dataset: The dataset whose variant tensor op is cloned.

  Returns:
    A `dataset_ops._VariantDataset` wrapping the first output of the cloned
    op, constructed with `dataset.element_spec`.
  """
  ops_to_clone = traverse.obtain_all_variant_tensor_ops(dataset)
  root_op = dataset._variant_tensor.op  # pylint: disable=protected-access
  # The helper's result is indexed by the original op to find its counterpart.
  cloned_ops = _clone_helper(root_op, ops_to_clone)
  cloned_variant = cloned_ops[root_op].outputs[0]
  return dataset_ops._VariantDataset(  # pylint: disable=protected-access
      cloned_variant, dataset.element_spec)
def _clone_dataset(dataset):
  """Returns a cloned version of `dataset`.

  The op producing the dataset's variant tensor is cloned through
  `_clone_helper`, and the first output of the resulting op is wrapped in a
  new `_VariantDataset`.

  Args:
    dataset: The dataset to clone.

  Returns:
    A `dataset_ops._VariantDataset` built from the cloned variant tensor and
    the element spec of `dataset`.
  """
  variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(dataset)
  original_op = dataset._variant_tensor.op  # pylint: disable=protected-access
  remap_dict = _clone_helper(original_op, variant_tensor_ops)
  new_variant_tensor = remap_dict[original_op].outputs[0]
  # Read the element structure through the public `element_spec` property
  # rather than the private `_element_structure` attribute.
  return dataset_ops._VariantDataset(  # pylint: disable=protected-access
      new_variant_tensor, dataset.element_spec)
def testAsFunctionWithMap(self):
  """Revives a mapped range dataset from its traced variant function.

  The dataset `range(5).map(x * 2)` is traced, the traced function is called
  to obtain a variant, and a `_VariantDataset` rebuilt from that variant is
  expected to produce 0, 2, 4, 6, 8.
  """
  with ops.device("CPU"):
    source = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    # pylint: disable=protected-access
    make_variant = source._trace_variant_creation()
    revived = dataset_ops._VariantDataset(make_variant(), source.element_spec)
    # pylint: enable=protected-access
    self.assertDatasetProduces(revived, range(0, 10, 2))
def testAsFunctionWithMap(self):
  """Revives a mapped range dataset from its traced variant function.

  Skipped unless executing eagerly. Otherwise the dataset
  `range(5).map(x * 2)` is traced, the traced function is called to obtain a
  variant, and a `_VariantDataset` rebuilt from it must produce 0, 2, 4, 6, 8.
  """
  running_eagerly = context.executing_eagerly()
  if not running_eagerly:
    self.skipTest("Only works executing eagerly")

  with ops.device("CPU"):
    source = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    # pylint: disable=protected-access
    make_variant = source._trace_variant_creation()
    revived = dataset_ops._VariantDataset(make_variant(), source.element_spec)
    # pylint: enable=protected-access
    self.assertDatasetProduces(revived, range(0, 10, 2))
def testBasic(self):
  """Checks that wrapping then unwrapping a dataset variant is lossless.

  The variant tensor of `Dataset.range(100)` is passed through
  `wrap_dataset_variant` and `unwrap_dataset_variant`; a `_VariantDataset`
  rebuilt from the result must yield 0 through 99 in order.
  """
  ds = dataset_ops.Dataset.range(100)
  ds_variant = ds._variant_tensor  # pylint: disable=protected-access
  wrapped_variant = gen_dataset_ops.wrap_dataset_variant(ds_variant)
  unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(wrapped_variant)
  # Use the public `element_spec` property instead of the private
  # `_element_structure` attribute.
  variant_ds = dataset_ops._VariantDataset(  # pylint: disable=protected-access
      unwrapped_variant, ds.element_spec)
  get_next = self.getNext(variant_ds, requires_initialization=True)
  for i in range(100):
    self.assertEqual(i, self.evaluate(get_next()))
def DISABLED_testBasic(self):
  """Verifies a wrap/unwrap variant round trip still yields 0 through 99."""
  source = dataset_ops.Dataset.range(100)
  # pylint: disable=protected-access
  round_tripped = gen_dataset_ops.unwrap_dataset_variant(
      gen_dataset_ops.wrap_dataset_variant(source._variant_tensor))
  revived = dataset_ops._VariantDataset(round_tripped, source.element_spec)
  # pylint: enable=protected-access
  next_element = self.getNext(revived, requires_initialization=True)
  for expected in range(100):
    self.assertEqual(expected, self.evaluate(next_element()))
def testAsFunctionFromReader(self):
  """Revives a GZIP TFRecord reader dataset from its traced variant function.

  Three records are written to a GZIP-compressed TFRecord file in the test's
  temp directory. A `TFRecordDataset` over that file is traced, and the
  `_VariantDataset` rebuilt from the traced variant must produce the records.
  """
  payloads = ("a", "aa", "aaa")
  with ops.device("CPU"):
    file_name = "{}.tfrecord.gz".format("tf_record_asset")
    file_path = os.path.join(self.get_temp_dir(), file_name)
    # The writer is closed on leaving the `with` block, before any reading.
    with tf_record.TFRecordWriter(file_path, "GZIP") as writer:
      for payload in payloads:
        writer.write(str(payload))

    source = readers.TFRecordDataset([file_path], compression_type="GZIP")
    # pylint: disable=protected-access
    make_variant = source._trace_variant_creation()
    revived = dataset_ops._VariantDataset(make_variant(), source.element_spec)
    # pylint: enable=protected-access
    self.assertDatasetProduces(revived, list(payloads))
def testBasic(self):
  """Verifies a wrap/unwrap variant round trip still yields 0 through 99.

  Elements are read through an initializable iterator inside a cached
  session.
  """
  source = dataset_ops.Dataset.range(100)
  # pylint: disable=protected-access
  round_tripped = gen_dataset_ops.unwrap_dataset_variant(
      gen_dataset_ops.wrap_dataset_variant(source._as_variant_tensor()))
  revived = dataset_ops._VariantDataset(round_tripped,
                                        source._element_structure)
  # pylint: enable=protected-access
  iterator = dataset_ops.make_initializable_iterator(revived)
  next_element = iterator.get_next()

  with self.cached_session():
    self.evaluate(iterator.initializer)
    for expected in range(100):
      self.assertEqual(expected, self.evaluate(next_element))
def testSkipEagerGPU(self):
  """Checks a wrapped dataset variant survives an identity op on the GPU.

  The variant tensor of `Dataset.range(100)` is wrapped, copied with
  `array_ops.identity` under a `/gpu:0` device scope, then unwrapped. A
  `_VariantDataset` rebuilt from the result must yield 0 through 99 in order.
  """
  ds = dataset_ops.Dataset.range(100)
  ds_variant = ds._variant_tensor  # pylint: disable=protected-access
  wrapped_variant = gen_dataset_ops.wrap_dataset_variant(ds_variant)
  # Only the identity copy is created under the GPU device scope.
  with ops.device("/gpu:0"):
    gpu_wrapped_variant = array_ops.identity(wrapped_variant)
  unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(
      gpu_wrapped_variant)
  # Use the public `element_spec` property instead of the private
  # `_element_structure` attribute.
  variant_ds = dataset_ops._VariantDataset(  # pylint: disable=protected-access
      unwrapped_variant, ds.element_spec)
  iterator = dataset_ops.make_initializable_iterator(variant_ds)
  get_next = iterator.get_next()

  with self.cached_session():
    self.evaluate(iterator.initializer)
    for i in range(100):
      self.assertEqual(i, self.evaluate(get_next))
def testGPU(self):
  """Verifies a wrapped variant copied under a GPU scope still yields 0..99.

  The wrapped variant is passed through `array_ops.identity` inside a
  `/gpu:0` device scope before being unwrapped and read back through an
  initializable iterator.
  """
  source = dataset_ops.Dataset.range(100)
  wrapped = gen_dataset_ops.wrap_dataset_variant(
      source._variant_tensor)  # pylint: disable=protected-access
  # Only the identity copy is created under the GPU device scope.
  with ops.device("/gpu:0"):
    wrapped_on_gpu = array_ops.identity(wrapped)
  unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrapped_on_gpu)
  revived = dataset_ops._VariantDataset(  # pylint: disable=protected-access
      unwrapped, source.element_spec)
  iterator = dataset_ops.make_initializable_iterator(revived)
  next_element = iterator.get_next()

  with self.cached_session():
    self.evaluate(iterator.initializer)
    for expected in range(100):
      self.assertEqual(expected, self.evaluate(next_element))