Exemplo n.º 1
0
def _clone_dataset(dataset):
    """Build a copy of `dataset` backed by duplicated variant-tensor ops.

    Args:
        dataset: the dataset to copy. The op behind `dataset._variant_tensor`,
            along with the variant-tensor ops gathered by
            `traverse.obtain_all_variant_tensor_ops(dataset)`, is handed to
            `_clone_helper` for duplication.

    Returns:
        A `dataset_ops._VariantDataset` over output 0 of the cloned root op,
        reusing the original dataset's `element_spec`.
    """
    gathered_ops = traverse.obtain_all_variant_tensor_ops(dataset)
    root_op = dataset._variant_tensor.op
    # NOTE(review): `_clone_helper` presumably returns a mapping from original
    # ops to their clones; only the root op's entry is consumed here.
    clone_by_op = _clone_helper(root_op, gathered_ops)
    cloned_root = clone_by_op[root_op]
    cloned_variant = cloned_root.outputs[0]
    spec = dataset.element_spec
    return dataset_ops._VariantDataset(cloned_variant, spec)
Exemplo n.º 2
0
def _clone_dataset(dataset):
  """Produces a copy of `dataset` backed by duplicated variant-tensor ops.

  Args:
    dataset: the dataset to copy. The op behind `dataset._variant_tensor`,
      together with the variant-tensor ops collected by
      `traverse.obtain_all_variant_tensor_ops(dataset)`, is passed to
      `_clone_helper` for duplication.

  Returns:
    A `dataset_ops._VariantDataset` over output 0 of the cloned root op,
    typed with the original dataset's `_element_structure`.
  """
  collected_ops = traverse.obtain_all_variant_tensor_ops(dataset)
  root_op = dataset._variant_tensor.op
  # NOTE(review): `_clone_helper` presumably maps original ops to their
  # clones; only the entry for the root op is used here.
  op_to_clone = _clone_helper(root_op, collected_ops)
  return dataset_ops._VariantDataset(op_to_clone[root_op].outputs[0],
                                     dataset._element_structure)
Exemplo n.º 3
0
    def testAsFunctionWithMap(self):
        """Reviving a traced map-dataset variant yields the doubled range."""
        with ops.device("CPU"):
            doubled_ds = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
            # Trace the function that recreates the dataset's variant tensor,
            # then invoke it to get a fresh variant.
            creation_fn = doubled_ds._trace_variant_creation()
            traced_variant = creation_fn()

            revived_ds = dataset_ops._VariantDataset(
                traced_variant, doubled_ds.element_spec)
            # range(5) doubled -> 0, 2, 4, 6, 8
            expected = range(0, 10, 2)
            self.assertDatasetProduces(revived_ds, expected)
Exemplo n.º 4
0
  def testAsFunctionWithMap(self):
    """Reviving a traced map-dataset variant yields the doubled range (eager only)."""
    if context.executing_eagerly():
      with ops.device("CPU"):
        doubled_ds = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
        traced_variant = doubled_ds._trace_variant_creation()()

        revived_ds = dataset_ops._VariantDataset(
            traced_variant, doubled_ds.element_spec)
        # range(5) doubled -> 0, 2, 4, 6, 8
        self.assertDatasetProduces(revived_ds, range(0, 10, 2))
    else:
      # Graph mode is not supported by this test.
      self.skipTest("Only works executing eagerly")
Exemplo n.º 5
0
  def testBasic(self):
    """Round-trips a range dataset's variant through wrap/unwrap and reads it."""
    source_ds = dataset_ops.Dataset.range(100)
    raw_variant = source_ds._variant_tensor  # pylint: disable=protected-access

    round_tripped = gen_dataset_ops.unwrap_dataset_variant(
        gen_dataset_ops.wrap_dataset_variant(raw_variant))

    restored_ds = dataset_ops._VariantDataset(round_tripped,
                                              source_ds._element_structure)
    next_element = self.getNext(restored_ds, requires_initialization=True)
    # The restored dataset must yield 0..99 in order.
    for expected in range(100):
      self.assertEqual(expected, self.evaluate(next_element()))

  def DISABLED_testBasic(self):
    """Wrap/unwrap round trip of a range dataset's variant, read via getNext.

    The method name does not start with `test`, so the default unittest
    loader does not collect it.
    """
    range_ds = dataset_ops.Dataset.range(100)
    original = range_ds._variant_tensor  # pylint: disable=protected-access

    wrapped = gen_dataset_ops.wrap_dataset_variant(original)
    unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrapped)

    rebuilt_ds = dataset_ops._VariantDataset(unwrapped, range_ds.element_spec)
    fetch = self.getNext(rebuilt_ds, requires_initialization=True)
    for value in range(100):
      self.assertEqual(value, self.evaluate(fetch()))
Exemplo n.º 7
0
    def testAsFunctionFromReader(self):
        """Reviving a traced GZIP TFRecordDataset variant yields the written records."""
        records = ["a", "aa", "aaa"]
        with ops.device("CPU"):
            record_path = os.path.join(
                self.get_temp_dir(), "{}.tfrecord.gz".format("tf_record_asset"))
            # Write the fixture records to a GZIP-compressed TFRecord file.
            with tf_record.TFRecordWriter(record_path, "GZIP") as writer:
                for record in records:
                    writer.write(str(record))

            reader_ds = readers.TFRecordDataset(
                [record_path], compression_type="GZIP")
            creation_fn = reader_ds._trace_variant_creation()
            traced_variant = creation_fn()

            revived_ds = dataset_ops._VariantDataset(
                traced_variant, reader_ds.element_spec)
            self.assertDatasetProduces(revived_ds, records)
Exemplo n.º 8
0
    def testBasic(self):
        """Round-trips a range dataset's variant through wrap/unwrap in a session."""
        source_ds = dataset_ops.Dataset.range(100)
        variant = source_ds._as_variant_tensor()  # pylint: disable=protected-access

        round_tripped = gen_dataset_ops.unwrap_dataset_variant(
            gen_dataset_ops.wrap_dataset_variant(variant))

        restored_ds = dataset_ops._VariantDataset(
            round_tripped, source_ds._element_structure)
        init_iter = dataset_ops.make_initializable_iterator(restored_ds)
        next_op = init_iter.get_next()

        with self.cached_session():
            # The iterator must be initialized before get_next can be evaluated.
            self.evaluate(init_iter.initializer)
            for expected in range(100):
                self.assertEqual(expected, self.evaluate(next_op))
Exemplo n.º 9
0
  def testSkipEagerGPU(self):
    """Wrapped dataset variant survives an identity under /gpu:0 and unwraps."""
    source_ds = dataset_ops.Dataset.range(100)
    raw = source_ds._variant_tensor  # pylint: disable=protected-access
    wrapped = gen_dataset_ops.wrap_dataset_variant(raw)

    # Pass the wrapped variant through an identity op placed under /gpu:0.
    with ops.device("/gpu:0"):
      on_gpu = array_ops.identity(wrapped)

    restored_ds = dataset_ops._VariantDataset(
        gen_dataset_ops.unwrap_dataset_variant(on_gpu),
        source_ds._element_structure)
    init_iter = dataset_ops.make_initializable_iterator(restored_ds)
    next_op = init_iter.get_next()

    with self.cached_session():
      self.evaluate(init_iter.initializer)
      for expected in range(100):
        self.assertEqual(expected, self.evaluate(next_op))
Exemplo n.º 10
0
    def testGPU(self):
        """Wrapped dataset variant survives an identity under /gpu:0 and unwraps."""
        range_ds = dataset_ops.Dataset.range(100)
        wrapped = gen_dataset_ops.wrap_dataset_variant(
            range_ds._variant_tensor)  # pylint: disable=protected-access

        # Route the wrapped variant through an identity op placed under /gpu:0.
        with ops.device("/gpu:0"):
            gpu_copy = array_ops.identity(wrapped)

        unwrapped = gen_dataset_ops.unwrap_dataset_variant(gpu_copy)
        rebuilt_ds = dataset_ops._VariantDataset(unwrapped,
                                                 range_ds.element_spec)
        init_iter = dataset_ops.make_initializable_iterator(rebuilt_ds)
        fetch = init_iter.get_next()

        with self.cached_session():
            self.evaluate(init_iter.initializer)
            for value in range(100):
                self.assertEqual(value, self.evaluate(fetch))