Example #1
0
 def testMultipleVariantTensors(self):
     """Checks traversal of a range dataset wrapped in `_TestDataset`.

     The test expects exactly three variant-tensor ops to be reported:
     `RangeDataset`, `ModelDataset` and `PrefetchDataset`.
     """
     wrapped = _TestDataset(dataset_ops.Dataset.range(10))
     found_ops = traverse.obtain_all_variant_tensor_ops(wrapped)
     expected_names = {"RangeDataset", "ModelDataset", "PrefetchDataset"}
     self.assertSetEqual(expected_names, {op.name for op in found_ops})
Example #2
0
def _clone_dataset(dataset):
  """Returns a cloned version of `dataset`.

  Collects every variant-tensor op reachable from `dataset` and passes them,
  together with the dataset's own root variant op, to `_clone_helper`. The
  root op is then looked up in the mapping that helper returns. The first
  output of the looked-up op is wrapped in a new `_VariantDataset` that
  reuses the original dataset's `_element_structure`.

  Args:
    dataset: the dataset to clone.

  Returns:
    A `dataset_ops._VariantDataset` backed by the cloned variant tensor.
  """
  reachable_ops = traverse.obtain_all_variant_tensor_ops(dataset)
  root_op = dataset._variant_tensor.op
  # NOTE(review): `_clone_helper` is defined elsewhere; it presumably maps each
  # original op to its clone -- confirm against its definition.
  cloned_root = _clone_helper(root_op, reachable_ops)[root_op]
  return dataset_ops._VariantDataset(
      cloned_root.outputs[0], dataset._element_structure)
Example #3
0
 def testMultipleVariantTensors(self):
   """Checks traversal of a range dataset wrapped in `_TestDataset`.

   The test expects exactly three variant-tensor ops to be reported:
   `RangeDataset`, `ModelDataset` and `PrefetchDataset`.
   """
   ds = dataset_ops.Dataset.range(10)
   ds = _TestDataset(ds)
   variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
   # Set literal / set comprehension instead of set([...]): avoids building
   # throwaway intermediate lists.
   self.assertSetEqual(
       {"RangeDataset", "ModelDataset", "PrefetchDataset"},
       {x.name for x in variant_tensor_ops})
Example #4
0
def _clone_dataset(dataset):
    """Returns a cloned version of `dataset`.

    Gathers all variant-tensor ops reachable from `dataset` and hands them,
    along with the dataset's root variant op, to `_clone_helper`. The entry
    for the root op in the returned mapping supplies the new variant tensor
    (its first output). That tensor is wrapped in a `_VariantDataset`
    carrying the original `element_spec`.

    Args:
        dataset: the dataset to clone.

    Returns:
        A `dataset_ops._VariantDataset` backed by the cloned variant tensor.
    """
    reachable_ops = traverse.obtain_all_variant_tensor_ops(dataset)
    root_op = dataset._variant_tensor.op
    # NOTE(review): `_clone_helper` lives elsewhere; it presumably returns an
    # {original_op: cloned_op} mapping -- confirm.
    op_mapping = _clone_helper(root_op, reachable_ops)
    cloned_tensor = op_mapping[root_op].outputs[0]
    return dataset_ops._VariantDataset(cloned_tensor, dataset.element_spec)
Example #5
0
 def testZip(self):
     """Checks that zipping two range datasets reports all three variant ops.

     The test expects the second range op to be named `RangeDataset_1`.
     """
     first = dataset_ops.Dataset.range(10)
     second = dataset_ops.Dataset.range(10)
     zipped = dataset_ops.Dataset.zip((first, second))
     found_ops = traverse.obtain_all_variant_tensor_ops(zipped)
     expected_names = {"ZipDataset", "RangeDataset", "RangeDataset_1"}
     self.assertSetEqual(expected_names, {op.name for op in found_ops})
Example #6
0
 def testConcat(self):
     """Checks that concatenating two range datasets reports all variant ops.

     The test expects the second range op to be named `RangeDataset_1`.
     """
     head = dataset_ops.Dataset.range(10)
     tail = dataset_ops.Dataset.range(10)
     joined = head.concatenate(tail)
     found_ops = traverse.obtain_all_variant_tensor_ops(joined)
     expected_names = {"ConcatenateDataset", "RangeDataset", "RangeDataset_1"}
     self.assertSetEqual(expected_names, {op.name for op in found_ops})
Example #7
0
    def testFlatMap(self):
        """Checks traversal through a dataset captured by a `flat_map` function.

        The flat-map function batches a separately built range+repeat dataset.
        The test expects that dataset's ops to be reported alongside the outer
        pipeline's flat-map, prefetch and range ops.
        """
        repeated = dataset_ops.Dataset.range(10).repeat(10)

        def make_batcher(source):
            # Returns a function that batches the captured `source` dataset.
            def _map(batch_size):
                return source.batch(batch_size)

            return _map

        outer = dataset_ops.Dataset.range(20).prefetch(1)
        outer = outer.flat_map(make_batcher(repeated))
        found_ops = traverse.obtain_all_variant_tensor_ops(outer)
        expected_names = {
            "FlatMapDataset",
            "PrefetchDataset",
            "RepeatDataset",
            "RangeDataset",
            "RangeDataset_1",
        }
        self.assertSetEqual(expected_names, {op.name for op in found_ops})
Example #8
0
 def testSimplePipeline(self):
     """Checks a two-stage pipeline: a range source followed by a map."""
     pipeline = dataset_ops.Dataset.range(10).map(math_ops.square)
     found_ops = traverse.obtain_all_variant_tensor_ops(pipeline)
     op_names = {op.name for op in found_ops}
     self.assertSetEqual({"MapDataset", "RangeDataset"}, op_names)
Example #9
0
 def testOnlySource(self):
     """Checks that a bare range dataset reports only its own variant op."""
     source = dataset_ops.Dataset.range(10)
     found_names = [
         op.name for op in traverse.obtain_all_variant_tensor_ops(source)
     ]
     self.assertAllEqual(["RangeDataset"], found_names)