    def testAllowStatefulOp(self):
        with compat.forward_compatibility_horizon(2019, 9, 12):
            with ops.device(self._device0):
                dataset0 = dataset_ops.Dataset.range(100).map(
                    lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
                        [],
                        minval=1,
                        maxval=10,
                        dtype=dtypes.float32))
                opt = dataset_ops.Options()
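                # Allow stateful ops (here random_uniform) when the dataset
                # is replicated, instead of rejecting them.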
                opt.experimental_allow_stateful = True
                dataset0 = dataset0.with_options(opt)
            replicated_ds = distribute.replicate(
                dataset0, [self._device1, self._device2])
            dataset1 = replicated_ds[self._device1]
            dataset2 = replicated_ds[self._device2]

            with ops.device(self._device0):
                get_next0 = self.getNext(dataset0)
            with ops.device(self._device1):
                get_next1 = self.getNext(dataset1)
            with ops.device(self._device2):
                get_next2 = self.getNext(dataset2)

            for _ in range(100):
                get_next0()
                get_next1()
                get_next2()
Example 2
    def testWhitelistStatefulOp(self):
        with compat.forward_compatibility_horizon(2019, 9, 12):
            with ops.device(self._device0):
                dataset0 = dataset_ops.Dataset.range(100).map(
                    lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
                        [],
                        minval=1,
                        maxval=10,
                        dtype=dtypes.float32))
                opt = dataset_ops.Options()
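                # Whitelist the RandomUniform op so replication succeeds
                # despite the op being stateful.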
                opt.experimental_stateful_whitelist = ["RandomUniform"]
                dataset0 = dataset0.with_options(opt)
            replicated_ds = distribute.replicate(
                dataset0, [self._device1, self._device2])
            dataset1 = replicated_ds[self._device1]
            dataset2 = replicated_ds[self._device2]

            with ops.device(self._device0):
                get_next0 = self.getNext(dataset0)
            with ops.device(self._device1):
                get_next1 = self.getNext(dataset1)
            with ops.device(self._device2):
                get_next2 = self.getNext(dataset2)

            with session.Session(self._target) as sess:
                for _ in range(100):
                    sess.run(get_next0())
                    sess.run(get_next1())
                    sess.run(get_next2())
Example 3
    def testExternalStatePolicyIgnore(self):
        with ops.device(self._device0):
            dataset0 = dataset_ops.Dataset.range(100).map(
                lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
                    [],
                    minval=1,
                    maxval=10,
                    dtype=dtypes.float32))
            opt = dataset_ops.Options()
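            # IGNORE: replicate the pipeline and ignore its external state
            # (no error or warning is raised).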
            opt.experimental_external_state_policy = (
                distribute_options.ExternalStatePolicy.IGNORE)
            dataset0 = dataset0.with_options(opt)
        replicated_ds = distribute.replicate(dataset0,
                                             [self._device1, self._device2])
        dataset1 = replicated_ds[self._device1]
        dataset2 = replicated_ds[self._device2]

        with ops.device(self._device0):
            get_next0 = self.getNext(dataset0)
        with ops.device(self._device1):
            get_next1 = self.getNext(dataset1)
        with ops.device(self._device2):
            get_next2 = self.getNext(dataset2)

        for _ in range(100):
            self.evaluate(get_next0())
            self.evaluate(get_next1())
            self.evaluate(get_next2())
Example 4
  def testExternalStatePolicyFail(self):
    with compat.forward_compatibility_horizon(2019, 11, 30):
      with ops.device(self._device0):
        dataset0 = dataset_ops.Dataset.range(100).map(
            lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
                [],
                minval=1,
                maxval=10,
                dtype=dtypes.float32))
        opt = dataset_ops.Options()
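        # FAIL: replicating a pipeline with external state raises an error,
        # which the assertRaises block below checks for.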
        opt.experimental_external_state_policy = (
            dataset_ops.ExternalStatePolicy.FAIL)
        dataset0 = dataset0.with_options(opt)
      with self.assertRaises(errors.FailedPreconditionError):
        replicated_ds = distribute.replicate(dataset0,
                                             [self._device1, self._device2])
        dataset1 = replicated_ds[self._device1]
        dataset2 = replicated_ds[self._device2]

        with ops.device(self._device0):
          get_next0 = self.getNext(dataset0)
        with ops.device(self._device1):
          get_next1 = self.getNext(dataset1)
        with ops.device(self._device2):
          get_next2 = self.getNext(dataset2)

        for _ in range(100):
          self.evaluate(get_next0())
          self.evaluate(get_next1())
          self.evaluate(get_next2())
Example 5
 def testVariableInput(self):
     with ops.device(self._device0):
         counter_var = variable_scope.get_variable("counter", (),
                                                   dtypes.int32,
                                                   use_resource=True)
         dataset0 = dataset_ops.Dataset.range(100).map(
             lambda _: counter_var.assign_add(1))
     replicated_ds = distribute.replicate(dataset0,
                                          [self._device1, self._device2])
     dataset1 = replicated_ds[self._device1]
     dataset2 = replicated_ds[self._device2]
     self.evaluate(counter_var.initializer)
     with ops.device(self._device1):
         self.assertDatasetProduces(dataset1,
                                    range(1, 101),
                                    requires_initialization=True)
     with ops.device(self._device2):
         self.assertDatasetProduces(dataset2,
                                    range(1, 101),
                                    requires_initialization=True)
     # Iterate through the original device last so that replication happens
     # before counter_var is modified. The order only matters in graph mode.
     with ops.device(self._device0):
         self.assertDatasetProduces(dataset0,
                                    range(1, 101),
                                    requires_initialization=True)
Example 6
 def testVariableInput(self):
   with ops.device(self._device0):
     counter_var = variable_scope.get_variable(
         "counter", (), dtypes.int32, use_resource=True)
     dataset0 = dataset_ops.Dataset.range(100).map(
         lambda _: counter_var.assign_add(1))
   # We don't support stateful ops in functions as of now.
   with self.assertRaises(errors.FailedPreconditionError):
     replicated_ds = distribute.replicate(dataset0,
                                          [self._device1, self._device2])
     self.evaluate(replicated_ds[self._device1]._variant_tensor)
Example 7
 def testMap(self):
     with ops.device(self._device0):
         dataset0 = dataset_ops.Dataset.range(100).map(lambda x: x * 2)
     replicated_ds = distribute.replicate(dataset0,
                                          [self._device1, self._device2])
     dataset1 = replicated_ds[self._device1]
     dataset2 = replicated_ds[self._device2]
     with ops.device(self._device0):
         self.assertDatasetProduces(dataset0, range(0, 200, 2))
     with ops.device(self._device1):
         self.assertDatasetProduces(dataset1, range(0, 200, 2))
     with ops.device(self._device2):
         self.assertDatasetProduces(dataset2, range(0, 200, 2))
Example 8
 def testBasic(self):
     with ops.device(self._device0):
         dataset0 = dataset_ops.Dataset.range(100)
     replicated_ds = distribute.replicate(dataset0,
                                          [self._device1, self._device2])
     dataset1 = replicated_ds[self._device1]
     dataset2 = replicated_ds[self._device2]
     with ops.device(self._device0):
         self.assertDatasetProduces(dataset0, range(100))
     with ops.device(self._device1):
         self.assertDatasetProduces(dataset1, range(100))
     with ops.device(self._device2):
         self.assertDatasetProduces(dataset2, range(100))
Example 9
 def testVariableInput(self):
   with ops.device(self._device0):
     counter_var = variable_scope.get_variable(
         "counter", (), dtypes.int32, use_resource=True)
     dataset0 = dataset_ops.Dataset.range(100).map(
         lambda _: counter_var.assign_add(1))
   # We don't support stateful ops in functions as of now.
   with self.assertRaises(errors.InvalidArgumentError):
     replicated_ds = distribute.replicate(dataset0,
                                          [self._device1, self._device2])
     dataset1 = replicated_ds[self._device1]
     with ops.device(self._device1):
       self.assertDatasetProduces(
           dataset1, range(100), requires_initialization=True)
Example 10
 def testVariableInput(self):
   with ops.device(self._device0):
     counter_var = variable_scope.get_variable(
         "counter", (), dtypes.int32, use_resource=True)
     dataset0 = dataset_ops.Dataset.range(100).map(
         lambda _: counter_var.assign_add(1))
   replicated_ds = distribute.replicate(dataset0,
                                        [self._device1, self._device2])
   dataset1 = replicated_ds[self._device1]
   with ops.device(self._device1):
     it1 = dataset_ops.make_initializable_iterator(dataset1)
   # We don't support stateful ops in functions as of now.
   with session.Session(self._target) as sess:
     with self.assertRaises(errors.FailedPreconditionError):
       sess.run(it1.initializer)
Example 11
    def testFromTensorSlicesWithDataset(self):
        with ops.device(self._device0):
            dataset0 = dataset_ops.Dataset.range(100)
            dataset0 = dataset_ops.Dataset.from_tensor_slices([dataset0])
            dataset0 = dataset0.flat_map(lambda x: x)
        replicated_ds = distribute.replicate(dataset0,
                                             [self._device1, self._device2])
        dataset1 = replicated_ds[self._device1]
        dataset2 = replicated_ds[self._device2]

        with ops.device(self._device0):
            self.assertDatasetProduces(dataset0, range(100))
        with ops.device(self._device1):
            self.assertDatasetProduces(dataset1, range(100))
        with ops.device(self._device2):
            self.assertDatasetProduces(dataset2, range(100))
Example 12
 def testVariableInput(self):
   with ops.device(self._device0):
     counter_var = variable_scope.get_variable(
         "counter", (), dtypes.int32, use_resource=True)
     dataset0 = dataset_ops.Dataset.range(100).map(
         lambda _: counter_var.assign_add(1))
   # We don't support stateful ops across processes in functions as of now.
   with self.assertRaises(errors.InvalidArgumentError):
     replicated_ds = distribute.replicate(dataset0,
                                          [self._device1, self._device2])
     dataset1 = replicated_ds[self._device1]
     with ops.device(self._device0):
       get_next0 = self.getNext(dataset0)
     with ops.device(self._device1):
       get_next1 = self.getNext(dataset1)
     for _ in range(100):
       self.evaluate(get_next0())
       self.evaluate(get_next1())
Example 13
  def testMap(self):
    with ops.device(self._device0):
      dataset0 = dataset_ops.Dataset.range(100).map(lambda x: x * 2)
    replicated_ds = distribute.replicate(dataset0,
                                         [self._device1, self._device2])
    dataset1 = replicated_ds[self._device1]
    dataset2 = replicated_ds[self._device2]
    with ops.device(self._device0):
      get_next = self.getNext(dataset0)
    with ops.device(self._device1):
      get_next1 = self.getNext(dataset1)
    with ops.device(self._device2):
      get_next2 = self.getNext(dataset2)

    with session.Session(self._target) as sess:
      for i in range(100):
        self.assertEqual(i * 2, sess.run(get_next()))
        self.assertEqual(i * 2, sess.run(get_next1()))
        self.assertEqual(i * 2, sess.run(get_next2()))
Example 14
 def testVariableInput(self):
   with ops.device(self._device0):
     counter_var = variable_scope.get_variable(
         "counter", (), dtypes.int32, use_resource=True)
     dataset0 = dataset_ops.Dataset.range(100).map(
         lambda _: counter_var.assign_add(1))
   replicated_ds = distribute.replicate(dataset0,
                                        [self._device1, self._device2])
   dataset1 = replicated_ds[self._device1]
   dataset2 = replicated_ds[self._device2]
   self.evaluate(counter_var.initializer)
   with ops.device(self._device0):
     self.assertDatasetProduces(
         dataset0, range(1, 101), requires_initialization=True)
   with ops.device(self._device1):
     self.assertDatasetProduces(
         dataset1, range(101, 201), requires_initialization=True)
   with ops.device(self._device2):
     self.assertDatasetProduces(
         dataset2, range(201, 301), requires_initialization=True)
Example 15
 def _testVariableInput(self):
     with ops.device(self._device0):
         counter_var = variable_scope.get_variable("counter", (),
                                                   dtypes.int32,
                                                   use_resource=True)
         dataset0 = dataset_ops.Dataset.range(100).map(
             lambda _: counter_var.assign_add(1))
     with self.assertRaises(errors.InvalidArgumentError):
         replicated_ds = distribute.replicate(
             dataset0, [self._device1, self._device2])
         dataset1 = replicated_ds[self._device1]
         dataset2 = replicated_ds[self._device2]
         with ops.device(self._device0):
             get_next0 = self.getNext(dataset0)
         with ops.device(self._device1):
             get_next1 = self.getNext(dataset1)
         with ops.device(self._device2):
             get_next2 = self.getNext(dataset2)
         for _ in range(100):
             self.evaluate(get_next0())
             self.evaluate(get_next1())
             self.evaluate(get_next2())
Example 16
  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None):
    """Distribute the dataset on all workers.

    If `split_batch_by` is not None, each batch of the dataset is split into
    `split_batch_by` smaller batches.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      split_batch_by: Optional integer. If present, each batch of the dataset
        is split into `split_batch_by` smaller batches.
      input_context: `InputContext` for sharding. Only pass this in for
        between-graph multi-worker cases where there is only one
        `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
    """
    super(DistributedDataset, self).__init__(input_workers=input_workers)

    # We clone and shard the dataset on each worker. The current setup tries to
    # shard the dataset by files if possible, so that each worker sees a
    # different subset of files. If that is not possible, we attempt to shard
    # the final input such that each worker runs the entire preprocessing
    # pipeline but receives only its own shard of the dataset.
    if split_batch_by:
      try:
        # pylint: disable=protected-access
        with ops.colocate_with(dataset._variant_tensor):
          dataset = distribute._RebatchDataset(dataset, split_batch_by)
          # Add a prefetch to pipeline rebatching for performance.
          # TODO(rachelim): Instead of inserting an extra prefetch stage here,
          # leverage static graph rewrites to insert _RebatchDataset before
          # the final `prefetch` if it exists.
          dataset = dataset.prefetch(split_batch_by)
      except errors.InvalidArgumentError as e:
        if "without encountering a batch" in str(e):
          six.reraise(
              ValueError,
              ValueError(
                  "Call the `batch` method on the input Dataset in order to be "
                  "able to split your input across {} replicas.\n Please "
                  "the tf.distribute.Strategy guide. {}".format(
                      split_batch_by, e)),
              sys.exc_info()[2])
        else:
          raise

    # TODO(b/138745411): Remove once stateful transformations are supported.
    options = dataset_ops.Options()
    options.experimental_distribute._make_stateless = True  # pylint: disable=protected-access
    dataset = dataset.with_options(options)

    self._cloned_datasets = []
    if input_context:
      # Between-graph case: rely on the input_context for sharding.
      assert input_workers.num_workers == 1
      dataset = input_ops.auto_shard_dataset(dataset,
                                             input_context.num_input_pipelines,
                                             input_context.input_pipeline_id)
      self._cloned_datasets.append(dataset)
    else:
      replicated_ds = distribute.replicate(dataset,
                                           input_workers.worker_devices)
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = replicated_ds[worker]
          cloned_dataset = cloned_dataset.with_options(dataset.options())
          cloned_dataset = input_ops.auto_shard_dataset(
              cloned_dataset, len(input_workers.worker_devices), i)
          self._cloned_datasets.append(cloned_dataset)

    self._input_workers = input_workers
    self._strategy = strategy
    self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                         dataset.element_spec)  # pylint: disable=protected-access