def testMultipleVariantTensors(self):
  """Checks the variant-tensor ops reachable from a wrapped range dataset."""
  ds = dataset_ops.Dataset.range(10)
  # NOTE(review): `_TestDataset` is defined elsewhere; the expected set below
  # implies it layers ModelDataset and PrefetchDataset ops on top of `ds`.
  ds = _TestDataset(ds)
  variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
  self.assertSetEqual(
      {"RangeDataset", "ModelDataset", "PrefetchDataset"},
      {x.name for x in variant_tensor_ops})
def _clone_dataset(dataset):
  """Returns a cloned version of `dataset`.

  Collects every variant-tensor op reachable from `dataset`, passes them to
  `_clone_helper` together with the op producing `dataset`'s variant tensor,
  and wraps the first output of the remapped op in a `_VariantDataset`.

  Args:
    dataset: the dataset to clone.

  Returns:
    A `dataset_ops._VariantDataset` built from the cloned variant tensor and
    `dataset.element_spec`.
  """
  variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(dataset)
  # NOTE(review): `_clone_helper` is defined elsewhere; it is assumed to return
  # a dict mapping each original op to its clone. Confirm against its source.
  remap_dict = _clone_helper(dataset._variant_tensor.op, variant_tensor_ops)
  new_variant_tensor = remap_dict[dataset._variant_tensor.op].outputs[0]
  # Use the public `element_spec` property, as the other revision of this
  # helper does, instead of the private `_element_structure` attribute.
  return dataset_ops._VariantDataset(new_variant_tensor, dataset.element_spec)
def testMultipleVariantTensors(self):
  """Checks the variant-tensor ops reachable from a wrapped range dataset."""
  ds = dataset_ops.Dataset.range(10)
  # NOTE(review): `_TestDataset` is defined elsewhere; the expected set below
  # implies it layers ModelDataset and PrefetchDataset ops on top of `ds`.
  ds = _TestDataset(ds)
  variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
  # A set comprehension avoids building a throwaway list just to feed set().
  self.assertSetEqual(
      {"RangeDataset", "ModelDataset", "PrefetchDataset"},
      {x.name for x in variant_tensor_ops})
def _clone_dataset(dataset):
  """Builds a copy of `dataset` from cloned variant-tensor ops.

  Every variant-tensor op reachable from `dataset` is gathered and handed to
  `_clone_helper` along with the op that produces `dataset`'s variant tensor.
  The clone of that op supplies the variant tensor for the returned dataset.

  Args:
    dataset: the dataset to clone.

  Returns:
    A `dataset_ops._VariantDataset` over the cloned variant tensor, carrying
    the original `dataset.element_spec`.
  """
  all_variant_ops = traverse.obtain_all_variant_tensor_ops(dataset)
  source_op = dataset._variant_tensor.op
  # NOTE(review): assumed to map each original op to its clone; confirm in
  # `_clone_helper`, which is defined elsewhere.
  op_mapping = _clone_helper(source_op, all_variant_ops)
  cloned_variant = op_mapping[source_op].outputs[0]
  return dataset_ops._VariantDataset(cloned_variant, dataset.element_spec)
def testZip(self):
  """Traversal of a zipped dataset finds the zip op and both range inputs."""
  ds1 = dataset_ops.Dataset.range(10)
  ds2 = dataset_ops.Dataset.range(10)
  ds = dataset_ops.Dataset.zip((ds1, ds2))
  variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
  # The second range op is expected as "RangeDataset_1", presumably because
  # the graph uniquifies op names.
  self.assertSetEqual(
      {"ZipDataset", "RangeDataset", "RangeDataset_1"},
      {x.name for x in variant_tensor_ops})
def testConcat(self):
  """Traversal of a concatenation finds the concat op and both range inputs."""
  ds1 = dataset_ops.Dataset.range(10)
  ds2 = dataset_ops.Dataset.range(10)
  ds = ds1.concatenate(ds2)
  variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
  # The second range op is expected as "RangeDataset_1", presumably because
  # the graph uniquifies op names.
  self.assertSetEqual(
      {"ConcatenateDataset", "RangeDataset", "RangeDataset_1"},
      {x.name for x in variant_tensor_ops})
def testFlatMap(self):
  """Traversal reaches a dataset captured by a `flat_map` function."""
  ds1 = dataset_ops.Dataset.range(10).repeat(10)

  def map_fn(ds):
    def _map(x):
      return ds.batch(x)
    return _map

  ds2 = dataset_ops.Dataset.range(20).prefetch(1)
  ds2 = ds2.flat_map(map_fn(ds1))
  variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds2)
  # `ds1` (RepeatDataset over a RangeDataset) is referenced only through the
  # `_map` closure, so its ops appearing here shows traversal follows inputs
  # captured by the flat_map function. The BatchDataset created inside `_map`
  # is not part of the expected set.
  self.assertSetEqual(
      {
          "FlatMapDataset", "PrefetchDataset", "RepeatDataset",
          "RangeDataset", "RangeDataset_1"
      },
      {x.name for x in variant_tensor_ops})
def testSimplePipeline(self):
  """Traversal of a range -> map pipeline finds exactly those two ops."""
  ds = dataset_ops.Dataset.range(10).map(math_ops.square)
  variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
  self.assertSetEqual(
      {"MapDataset", "RangeDataset"},
      {x.name for x in variant_tensor_ops})
def testOnlySource(self):
  """A bare range dataset yields a single RangeDataset variant-tensor op."""
  source = dataset_ops.Dataset.range(10)
  found_ops = traverse.obtain_all_variant_tensor_ops(source)
  # Compared as a list (not a set) so a duplicated op would also fail.
  found_names = [op.name for op in found_ops]
  self.assertAllEqual(["RangeDataset"], found_names)